diff --git a/model/train/yoco_moe/.gitignore b/model/train/yoco_moe/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..ce476ed083154ad05372b58380b1b5c3fec8de99
--- /dev/null
+++ b/model/train/yoco_moe/.gitignore
@@ -0,0 +1,5 @@
+*.pyc
+.idea/
+.vscode/
+.pytest_cache/
+*.egg-info/
diff --git a/model/train/yoco_moe/LICENSE b/model/train/yoco_moe/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..762244cf2c63c01efe635c591d7cfe7b102307ba
--- /dev/null
+++ b/model/train/yoco_moe/LICENSE
@@ -0,0 +1,290 @@
+The following applies to all files unless otherwise noted;
+Copyright (c) 2024, Bytedance Inc.
+Copyright (c) 2023, Huawei Technologies Co., Ltd
+Copyright (c) 2022, NVIDIA CORPORATION.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+--
+
+This repository also contains code from Microsoft (from their DeepSpeed
+project). Files from these organization(s) have notices at the top of each file.
+Below are licenses used in those files, as indicated.
+
+
+----------------------------- LICENSE FOR Microsoft code, Facebook, huggingface and Google Research code -----------------------------
+
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+------------- LICENSE FOR various code from Facebook --------------
+
+MIT License
+
+Copyright (c) Facebook, Inc. and its affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+------------- LICENSE FOR Mircrosoft Swin transformer code --------------
+
+MIT License
+
+Copyright (c) Microsoft Corporation.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE
diff --git a/model/train/yoco_moe/OWNERS b/model/train/yoco_moe/OWNERS
new file mode 100644
index 0000000000000000000000000000000000000000..d8969ab0706e619f098e6017e45009aa8dc00ca1
--- /dev/null
+++ b/model/train/yoco_moe/OWNERS
@@ -0,0 +1,36 @@
+approvers:
+- gp513
+- qingfenxiaochong
+- clinglai
+- bingobb
+- jializheng
+reviewers:
+- gp513
+- qingfenxiaochong
+- pxc911
+- sanjaychan
+- jializheng
+- zhizaidicengshehua
+- li-baokui
+- arch-devil
+- fishhhqi
+- ryanyeee
+- glhyy
+- bingobb
+- mingzhenwang1
+- wangzw1022
+- aijgnem
+- clinglai
+- ckirchhoff
+- liu-tianlong
+- akstt
+- yu-xinjie62
+- yangyuan667
+- veronica-lam
+- kong-deyuan
+- zhao-yifan27
+- wang-xiaochao23
+- harygo2
+- ChenDonYY
+- wenjiang2357
+- z0020952
diff --git a/model/train/yoco_moe/README.md b/model/train/yoco_moe/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b6213b8f95cb916d266aa5c756500b743750f1b8
--- /dev/null
+++ b/model/train/yoco_moe/README.md
@@ -0,0 +1,948 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# 简介
+
+MindSpeed 是针对华为[昇腾设备](https://www.hiascend.com/)的大模型加速库。
+
+大模型训练是一种非常复杂的过程,涉及到许多技术和挑战,其中大模型训练需要大量的显存资源是一个难题,对计算卡提出了不小的挑战。
+为了在单个计算卡显存资源不足时,可以通过多张计算卡进行计算,业界出现了类似 Megatron、DeepSpeed 等第三方大模型加速库,对模型、输入数据等进行切分并分配到不同的计算卡上,最后再通过集合通信对结果进行汇总。
+
+昇腾提供 MindSpeed 加速库,使能客户大模型业务快速迁移至昇腾设备,并且支持昇腾专有算法,确保开箱可用。
+
+---
+# 安装
+
+### 1. 安装依赖
+
+☀️ 在安装 **MindSpeed** 之前,请参考[版本配套表](#版本配套表),安装最新昇腾软件栈。
+
+
+
+
+
+
+### 2. 安装 MindSpeed
+
+☀️ 下载源码安装:
+
+ ```shell
+ git clone -b core_r0.8.0 https://gitee.com/ascend/MindSpeed.git
+ pip install -e MindSpeed
+ ```
+
+如需使用Ascend Transformer Boost(ATB)加速库算子,请先安装 CANN-NNAL 并初始化添加环境,例如:
+ ```shell
+# CANN-NNAL默认安装路径为:/usr/local/Ascend/nnal
+# 运行CANN-NNAL默认安装路径下atb文件夹中的环境配置脚本set_env.sh
+source /usr/local/Ascend/nnal/atb/set_env.sh
+ ```
+
+### 3. 获取 Megatron-LM 并指定分支
+
+☀️ 获取并切换 Megatron-LM 版本至 core_r0.8.0 的release版本,如下所示:
+ ```shell
+ git clone https://github.com/NVIDIA/Megatron-LM.git
+ cd Megatron-LM
+ git checkout core_r0.8.0
+ ```
+
+# 快速上手
+
+1. 仅仅一行代码就可以轻松使能 MindSpeed 的各项功能。以 GPT 模型为例:在 Megatron-LM 目录下修改`pretrain_gpt.py`文件,在`import torch`下新增一行:`import mindspeed.megatron_adaptor`,即如下修改:
+
+ ```diff
+ import os
+ import torch
+ +import mindspeed.megatron_adaptor
+ from functools import partial
+ from typing import Union
+ ```
+
+2. 在 Megatron-LM 目录下,准备好训练数据,并在示例脚本中填写对应路径,然后执行。
+ ```shell
+ bash examples/gpt3/train_gpt3_175b_distributed.sh
+ ```
+---
+# 自定义优化级别
+MindSpeed 提供了多层次的优化解决方案,并划分为三个层级,用户可根据实际需求灵活启用任意层级。高层级兼容低层级的能力,确保了整个系统的稳定性和扩展性。
+用户可以通过设置启动脚本中的 `--optimization-level {层级}` 参数来自定义开启的优化层级。该参数支持以下配置:
+
+
+
+ 层级 |
+ 层级名称 |
+ 介绍 |
+
+
+
+ 0 |
+ 基础兼容层 |
+ 提供Megatron-LM框架对NPU的支持,确保无缝集成。该层包含基础功能集patch,保证可靠性和稳定性,为高级优化奠定基础。 |
+
+
+
+ 1 |
+ 亲和性增强层🔥 |
+ 兼容L0能力,集成高性能融合算子库,结合昇腾亲和的计算优化,充分释放昇腾算力,显著提升计算效率。 |
+
+
+
+ 2 |
+ 自研加速算法层🔥🔥 |
+ 默认值。该模式兼容了L1, L0能力,并集成了昇腾多项自主研发核心技术成果,可提供全面的性能优化。 |
+
+
+
+
+# 特性介绍
+MindSpeed 特性由七大模块组成,分别为:megetron特性支持、并行策略特性、内存优化特性、亲和计算特性、通信优化特性、关键场景特性以及多模态特性。其中【Released】表示是否商用发布,原型特性为非商用发布。
+
+- 特性的介绍中说明了对应特性的应用场景及使用说明。一般而言,在脚本中加入相关参数即可轻松使用对应特性。🛰️
+
+- 当前大模型训练主要使用bf16数据类型,以下特性若无特殊声明原则上兼容fp16, 如遇到fp16场景下相关问题请联系 MindSpeed 团队或提交issue, 我们会快速响应。🛰️
+## Megatron特性支持
+
+
+
+ 特性名称 |
+ 介绍 |
+ Mcore |
+ Legacy |
+
+
+
+ Megatron 数据并行 |
+ link |
+ ✅ |
+ ✅ |
+
+
+
+ Megatron 张量并行 |
+ link |
+ ✅ |
+ ✅ |
+
+
+
+ Megatron 流水并行 |
+ link |
+ ✅ |
+ ✅ |
+
+
+
+ Megatron 虚拟流水并行 |
+ link |
+ ✅ |
+ ✅ |
+
+
+
+ Megatron 分布式优化器 |
+ link |
+ ✅ |
+ ✅ |
+
+
+
+ Megatron 序列并行 |
+ link |
+ ✅ |
+ ✅ |
+
+
+
+ Megatron 异步DDP |
+ link |
+ ✅ |
+ ✅ |
+
+
+
+ Megatron 权重更新通信隐藏 |
+ link |
+ ✅ |
+ ✅ |
+
+
+
+ Megatron 重计算 |
+ link |
+ ✅ |
+ ✅ |
+
+
+
+
+
+## 并行策略特性
+
+
+
+ 特性名称 |
+ 介绍 |
+ Mcore |
+ Legacy |
+ Released |
+
+
+
+ Ascend Ulysses 长序列并行 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend Ring Attention 长序列并行 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend 混合长序列并行 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend 自定义空操作层 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+
+
+## 内存优化特性
+
+
+
+ 特性名称 |
+ 介绍 |
+ Mcore |
+ Legacy |
+ Released |
+
+
+
+ Ascend 自适应选择重计算 |
+ link |
+ ❌ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend 激活函数重计算 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend 重计算流水线独立调度 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend Mask归一 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend BF16 参数副本复用 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend swap_attention |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend Norm重计算 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend Hccl Buffer 自适应 |
+ link |
+ ✅ |
+ ✅ |
+ ❌ |
+
+
+
+
+## 亲和计算特性
+
+
+
+ 特性名称 |
+ 介绍 |
+ Mcore |
+ Legacy |
+ Released |
+
+
+
+ Ascend rms_norm 融合算子 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend swiglu 融合算子 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend rotary_embedding 融合算子 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend flash attention |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend Moe Token Permute and Unpermute 融合算子 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend npu_matmul_add_fp32 梯度累加融合算子 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend Moe BMM通算融合算子 |
+ link |
+ ✅ |
+ ✅ |
+ ❌ |
+
+
+
+ Ascend 计算通信并行优化 |
+ link |
+ ✅ |
+ ✅ |
+ ❌ |
+
+
+
+ Ascend MC2 |
+ link |
+ ✅ |
+ ✅ |
+ ❌ |
+
+
+
+ Ascend fusion_attention_v2 |
+ link |
+ ✅ |
+ ✅ |
+ ❌ |
+
+
+
+
+## 通信优化特性
+
+
+
+ 特性名称 |
+ 介绍 |
+ Mcore |
+ Legacy |
+ Released |
+
+
+
+ Ascend nano-pipe流水线并行 |
+ link |
+ ❌ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend Gloo 存档落盘优化 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend 高维张量并行 |
+ link |
+ ✅ |
+ ✅ |
+ ❌ |
+
+
+
+## Mcore MoE特性
+
+
+
+ 特性名称 |
+ 介绍 |
+ Mcore |
+ Legacy |
+ Released |
+
+
+
+ Ascend Megatron MoE GMM |
+ link |
+ ✅ |
+ ❌ |
+ ✅ |
+
+
+
+ Ascend Megatron MoE Allgather Dispatcher 性能优化 |
+ link |
+ ✅ |
+ ❌ |
+ ✅ |
+
+
+
+ Ascend Megatron MoE Alltoall Dispatcher 性能优化 |
+ link |
+ ✅ |
+ ❌ |
+ ✅ |
+
+
+
+ Ascend Megatron MoE TP拓展EP |
+ link |
+ ✅ |
+ ❌ |
+ ✅ |
+
+
+
+ Ascend 共享专家 |
+ link |
+ ✅ |
+ ❌ |
+ ✅ |
+
+
+
+ Ascend Megatron MoE 负载感知内存均衡算 |
+ link |
+ ✅ |
+ ❌ |
+ ❌ |
+
+
+
+ Ascend Megatron MoE 分层通信 |
+ link |
+ ✅ |
+ ❌ |
+ ❌ |
+
+
+
+ Ascend Megatron MoE 大专家流水 |
+ link |
+ ✅ |
+ ❌ |
+ ❌ |
+
+
+
+
+## DeepSpeed MoE特性
+
+
+
+ 特性名称 |
+ 介绍 |
+ Mcore |
+ Legacy |
+ Released |
+
+
+
+ DeepSpeed MoE |
+ link |
+ ❌ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend DeepSpeed MoE token 重排性能优化 |
+ link |
+ ❌ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend DeepSpeed MoE dropless 性能优化 |
+ link |
+ ❌ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend MLP 通信隐藏 |
+ link |
+ ❌ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend Ampipe流水通信隐藏 |
+ link |
+ ❌ |
+ ✅ |
+ ✅ |
+
+
+
+## 关键场景特性
+
+
+
+ 特性名称 |
+ 介绍 |
+ Mcore |
+ Legacy |
+ Released |
+
+
+
+ Ascend EOD Reset训练场景 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend alibi |
+ link |
+ ✅ |
+ ✅ |
+ ❌ |
+
+
+
+## 多模态特性
+
+
+
+ 特性名称 |
+ 介绍 |
+ Mcore |
+ Legacy |
+ Released |
+
+
+
+ Ascend fused ema adamw优化器 |
+ link |
+ ✅ |
+ ✅ |
+ ❌ |
+
+
+
+ Ascend PP支持动态形状 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend PP支持多参数传递 |
+ link |
+ ✅ |
+ ✅ |
+ ❌ |
+
+
+
+ Ascend PP支持多参数传递和动态形状 |
+ link |
+ ✅ |
+ ✅ |
+ ❌ |
+
+
+
+## 其它特性
+
+
+
+ 特性名称 |
+ 介绍 |
+ Mcore |
+ Legacy |
+ Released |
+
+
+
+ Ascend TFLOPS计算 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend Auto Tuning 并行策略自动搜索系统 |
+ link |
+ ✅ |
+ ✅ |
+ ✅ |
+
+
+
+ Ascend 确定性计算 |
+ link |
+ ✅ |
+ ✅ |
+ ❌ |
+
+
+
+
+## 自定义算子
+
+
+部分自定义算子设置为公开接口,公开接口设置说明请参照 MindSpeed 安全声明中的[公开接口声明](SECURITYNOTE.md#公开接口声明),具体对外接口细节参照以下算子对应的手册链接。
+
+
+
+ 自定义算子名称 |
+ 介绍 |
+ Released |
+
+
+
+ npu_dropout_add_layer_norm |
+ link |
+ ✅ |
+
+
+
+ npu_rotary_position_embedding |
+ link |
+ ✅ |
+
+
+
+ fusion_attention |
+ link |
+ ✅ |
+
+
+
+ rms_norm |
+ link |
+ ✅ |
+
+
+
+ swiglu |
+ link |
+ ✅ |
+
+
+
+ npu_mm_all_reduce_add_rms_norm |
+ link |
+ ✅ |
+
+
+
+ npu_mm_all_reduce_add_rms_norm_ |
+ link |
+ ✅ |
+
+
+
+ npu_gmm |
+ link |
+ ✅ |
+
+
+
+ npu_grouped_mat_mul_all_reduce |
+ link |
+ ✅ |
+
+
+
+ lcal_coc |
+ link |
+ ❌ |
+
+
+
+ ffn |
+ link |
+ ❌ |
+
+
+
+ npu_fused_moe_token_permute |
+ link |
+ ❌ |
+
+
+
+ npu_fused_moe_token_unpermute |
+ link |
+ ❌ |
+
+
+
+ npu_ring_attention_update |
+ link |
+ ❌ |
+
+
+
+ npu_matmul_add_fp32 |
+ link |
+ ❌ |
+
+
+
+ npu_groupmatmul_add_fp32 |
+ link |
+ ❌ |
+
+
+
+ npu_all_to_all_all_gather_bmm |
+ link |
+ ❌ |
+
+
+
+ npu_bmm_reduce_scatter_all_to_all |
+ link |
+ ❌ |
+
+
+
+ quant_gmm |
+ link |
+ ❌ |
+
+
+
+ npu_apply_fused_ema_adamw |
+ link |
+ ❌ |
+
+
+
+---
+# MindSpeed 中采集Profile数据
+
+📝 MindSpeed 支持命令式开启Profile采集数据,命令配置介绍如下:
+
+| 配置命令 | 命令含义 |
+|-------------------------|-----------------------------------------------------------------------------------|
+| --profile | 打开profile开关 |
+| --profile-step-start | 配置开始采集步,未配置时默认为10, 配置举例: --profile-step-start 30 |
+| --profile-step-end | 配置结束采集步,未配置时默认为12, 配置举例: --profile-step-end 35 |
+| --profile-level | 配置采集等级,未配置时默认为level0, 可选配置: level0, level1, level2, 配置举例: --profile-level level1 |
+| --profile-with-cpu | 打开cpu信息采集开关 |
+| --profile-with-stack | 打开stack信息采集开关 |
+| --profile-with-memory | 打开memory信息采集开关,配置本开关时需打开--profile-with-cpu |
+| --profile-record-shapes | 打开shapes信息采集开关 |
+| --profile-save-path | 配置采集信息保存路径, 未配置时默认为./profile_dir, 配置举例: --profile-save-path ./result_dir |
+| --profile-ranks | 配置待采集的ranks,未配置时默认为-1,表示采集所有rank的profiling数据,配置举例: --profile-ranks 0 1 2 3, 需注意: 该配置值为每个rank在单机/集群中的全局值 |
+
+---
+# 版本配套表
+
+💡 **PyTorch Extension**版本号采用`{PyTorch版本}-{昇腾版本}`命名规则,前者为**PyTorch Extension**匹配的PyTorch版本,后者用于匹配CANN版本,详细匹配如下:
+
+| MindSpeed版本 | Megatron版本 | PyTorch版本 | torch_npu版本 | CANN版本 | Python版本 | 硬件型态 |
+|-------------------------|-----------------|------------- |-------------|---------|----------------------------------------|----------|
+| master(主线) | Core 0.8.0 | 2.1.0 | 在研版本 | 在研版本 | Python3.8.x, Python3.9.x, Python3.10.x | Atlas 200T A2 Box16, Atlas 800T A2, Atlas 900 A2 PODc |
+| core_r0.7.0(主线) | Core 0.7.0 | 2.1.0 | 在研版本 | 在研版本 | Python3.8.x, Python3.9.x, Python3.10.x | Atlas 200T A2 Box16, Atlas 800T A2, Atlas 900 A2 PODc |
+| core_r0.6.0(主线) | Core 0.6.0 | 2.1.0 | 在研版本 | 在研版本 | Python3.8.x, Python3.9.x, Python3.10.x | Atlas 200T A2 Box16, Atlas 800T A2, Atlas 900 A2 PODc |
+| 1.0.0_core_r0.7.0(商用) | Core 0.7.0 | 2.1.0 | 6.0.0 | 8.0.0 | Python3.8.x, Python3.9.x, Python3.10.x | Atlas 200T A2 Box16, Atlas 800T A2, Atlas 900 A2 PODc |
+| 1.0.0_core_r0.6.0(商用) | Core 0.6.0 | 2.1.0 | 6.0.0 | 8.0.0 | Python3.8.x, Python3.9.x, Python3.10.x | Atlas 200T A2 Box16, Atlas 800T A2, Atlas 900 A2 PODc |
+| 1.0.RC3_core_r0.7.0(商用) | Core 0.7.0 | 2.1.0 | 6.0.RC3 | 8.0.RC3 | Python3.8.x, Python3.9.x, Python3.10.x | Atlas 200T A2 Box16, Atlas 800T A2, Atlas 900 A2 PODc |
+| 1.0.RC3_core_r0.6.0(商用) | Core 0.6.0 | 2.1.0 | 6.0.RC3 | 8.0.RC3 | Python3.8.x, Python3.9.x, Python3.10.x | Atlas 200T A2 Box16, Atlas 800T A2, Atlas 900 A2 PODc |
+| 1.0.RC2(商用) | Core 0.6.0 | 2.1.0 | 6.0.RC2 | 8.0.RC2 | Python3.8.x, Python3.9.x, Python3.10.x | Atlas 200T A2 Box16, Atlas 800T A2, Atlas 900 A2 PODc |
+| 1.0.RC1(商用) | commitid bcce6f | 2.1.0 | 6.0.RC1 | 8.0.RC1 | Python3.8.x, Python3.9.x, Python3.10.x | Atlas 200T A2 Box16, Atlas 800T A2, Atlas 900 A2 PODc |
+
+[昇腾辅助软件](https://gitee.com/ascend/pytorch#%E6%98%87%E8%85%BE%E8%BE%85%E5%8A%A9%E8%BD%AF%E4%BB%B6)中有更多关于PyTorch和CANN的版本信息。
+
+# 分支维护策略
+
+🛠️ MindSpeed 版本分支的维护阶段如下:
+
+| **状态** | **时间** | **说明** |
+| ------------------- | -------- |----------------------------------------------------------------------|
+| 计划 🕐 | 1-3 个月 | 计划特性 |
+| 开发 🕔 | 3 个月 | 开发特性 |
+| 维护 🕚 | 6-12 个月| 合入所有已解决的问题并发布版本,针对不同的MindSpeed 版本采取不同的维护策略,常规版本和长期支持版本维护周期分别为6个月和12个月 |
+| 无维护 🕛 | 0-3 个月 | 合入所有已解决的问题,无专职维护人员,无版本发布 |
+| 生命周期终止(EOL)🚫 | N/A | 分支不再接受任何修改 |
+
+🛠️ MindSpeed 版本维护策略:
+
+| **MindSpeed版本** | **维护策略** | **当前状态** | **发布时间** | **后续状态** | **EOL日期** |
+|---------------------|-----------|---------|------------|--------------------|-----------|
+| 1.0.0_core_r0.7.0 | 常规版本 | 开发 | 2024/12/30 | 预计2025/6/30起无维护 | |
+| 1.0.0_core_r0.6.0 | 常规版本 | 开发 | 2024/12/30 | 预计2025/6/30起无维护 | |
+| 1.0.RC3_core_r0.7.0 | 常规版本 | 维护 | 2024/09/30 | 预计2025/3/30起无维护 | |
+| 1.0.RC3_core_r0.6.0 | 常规版本 | 维护 | 2024/09/30 | 预计2025/3/30起无维护 | |
+| 1.0.RC2 | 常规版本 | 维护 | 2024/06/30 | 预计2024/12/30起无维护 | |
+| 1.0.RC1 | 常规版本 | 停止维护 | 2024/03/30 | 2024/9/30起无维护 | |
+
+---
+
+# 常见问题
+
+| 现象 | 介绍 |
+|------------------------------------|---------------------------------------|
+| Data helpers 数据预处理出错 ❗ | [link](docs/faq/data_helpers.md) |
+| Torch extensions 编译卡住 ❗ | [link](docs/faq/torch_extensions.md) |
+| megatron0.7.0版本长稳测试出现grad norm为nan ❗| [link](docs/faq/megatron070_grad_norm_nan.md) |
+| Gloo建链失败Gloo connectFullMesh failed with ... ❗| [link](docs/features/hccl-replace-gloo.md) |
+
+# 技术文章
+- [MindSpeed 加速百万级超长序列大模型训练](https://mp.weixin.qq.com/s/8q4MxCkosLn0yoneuxzynw) 🚀🚀
+- [MindSpeed 加速万亿MoE大模型训练](https://mp.weixin.qq.com/s/HQRzYzSUNNMonv5d1AP0OQ) 🚀🚀
+- [大模型训练内存优化难?MindSpeed 帮你来支招](https://mp.weixin.qq.com/s/lwjVgM67hwsgtOKp06zYPg) 🚀🚀
+
+# 安全声明
+
+⚠️ [MindSpeed 安全声明](SECURITYNOTE.md)
+
+---
+
+# 致谢
+
+🔎 MindSpeed-Core 由华为公司的下列部门联合贡献 :
+
+华为公司:
+
+- 昇腾计算产品部
+- 计算算法部
+- 计算软件平台部
+- 计算技术开发部
+- 公共开发部:NAIE
+- 网络技术实验室
+
+此外,MindSpeed-Core 感谢以下团队对项目的贡献:
+
+- 微信基础架构团队
+
+感谢来自社区的每一个PR,欢迎贡献 MindSpeed-Core!
diff --git a/model/train/yoco_moe/SECURITYNOTE.md b/model/train/yoco_moe/SECURITYNOTE.md
new file mode 100644
index 0000000000000000000000000000000000000000..d909e54e13f8efc7fbdebb5b3f0e2a57f8598526
--- /dev/null
+++ b/model/train/yoco_moe/SECURITYNOTE.md
@@ -0,0 +1,116 @@
+# MindSpeed 安全声明
+
+## 系统安全加固
+
+- 用户可在运行系统配置时开启 ASLR(级别2)以提高系统安全性,保护系统随机化开启。
+可参考以下方式进行配置:
+
+ ```
+ echo 2 > /proc/sys/kernel/randomize_va_space
+ ```
+
+## 运行用户建议
+
+- 基于安全性考虑,建议您在执行任何命令时,不建议使用root等管理员类型账户执行,遵循权限最小化原则。
+
+## 文件权限控制
+
+- 建议用户在主机(包括宿主机)及容器中设置运行系统umask值为0027及以上,保障新增文件夹默认最高权限为750,新增文件默认最高权限为640。
+- 建议用户对训练所需文件、训练过程中保存的文件、用户个人的隐私数据、商业资产等敏感文件做好权限控制等安全措施,例如多用户共享数据集场景下的数据集文件写权限控制等,设定的权限建议参考[附录A 文件(夹)各场景权限管控推荐最大值](#A-文件(夹)各场景权限管控推荐最大值)进行设置。
+- MindSpeed 中各类融合算子通过调用 PyTorch 中的 cpp_extension 特性进行编译,编译结果会默认缓存到 `~/.cache/torch_extensions` 目录下,建议用户根据自身需要,参考[附录A 文件(夹)各场景权限管控推荐最大值](#A-文件(夹)各场景权限管控推荐最大值)对生成文件做好权限控制。
+- 原生 Megatron-LM 以及 PyTorch 框架运行中所生成的文件权限依赖系统设定,如 Megatron-LM 生成的数据集索引文件、torch.save 接口保存的文件等。建议当前执行脚本的用户根据自身需要,对生成文件做好权限控制,设定的权限可参考[附录A 文件(夹)各场景权限管控推荐最大值](#A-文件(夹)各场景权限管控推荐最大值)进行设置。
+- 运行时 CANN 可能会缓存算子编译文件,存储在运行目录下的`kernel_meta_*`文件夹内,加快后续训练的运行速度,用户可根据需要自行对生成后的相关文件进行权限控制。
+- 用户安装和使用过程需要做好权限控制,建议参考[附录A 文件(夹)各场景权限管控推荐最大值](#A-文件(夹)各场景权限管控推荐最大值)文件权限参考进行设置。如需要保存安装/卸载日志,可在安装/卸载命令后面加上参数 `--log `, 注意对``文件及目录做好权限管控。
+
+## 数据安全声明
+
+- MindSpeed 依赖 CANN 的基础能力实现 AOE 性能调优、算子 dump、日志记录等功能,用户需要关注上述功能生成文件的权限控制。
+
+## 运行安全声明
+
+- 建议用户结合运行环境资源状况编写对应训练脚本。若训练脚本与资源状况不匹配,如数据集加载内存大小超出内存容量限制、训练脚本在本地生成数据超过磁盘空间大小等情况,可能引发错误并导致进程意外退出。
+- MindSpeed 在运行异常时会退出进程并打印报错信息,建议根据报错提示定位具体错误原因,包括设定算子同步执行、查看 CANN 日志、解析生成的 Core Dump 文件等方式。
+
+## 公网地址声明
+- MindSpeed代码中包含公网地址声明如下表所示:
+
+| 类型 | 开源代码地址 | 文件名 | 公网IP地址/公网URL地址/域名/邮箱地址 | 用途说明 |
+| :------------: |:------------------------------------------------------------------------------------------:|:----------------------------------------------------------:| :----------------------------------------------------------: |:-----------------------------------------:|
+| 开源引入 | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py | mindspeed/moe/gate.py | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py | deepspeed moe源码地址 |
+| 开源引入 | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py | mindspeed/moe/gate.py | https://arxiv.org/pdf/2006.16668.pdf | 开源引入TopKGate类实现 |
+| 开源引入 | https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py | mindspeed/moe/gate.py | https://arxiv.org/pdf/2202.08906.pdf | 开源引入apply_z_loss实现 |
+| 开源引入 | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py | mindspeed/moe/moe_layer.py | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py | deepspeed moe源码地址 |
+| 开源引入 | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py | mindspeed/moe/moe_layer.py | https://arxiv.org/pdf/2006.16668.pdf | 开源引入MOELayer类实现 |
+| 开源引入 | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/layer.py | mindspeed/moe/mixtral_parallel_mlpbm.py | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/layer.py | deepspeed moe源码地址 |
+| 开源引入 | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/layer.py | mindspeed/moe/moe.py | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/layer.py | deepspeed moe源码地址 |
+| 开源引入 | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py | mindspeed/moe/utils.py | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py | deepspeed moe源码地址 |
+| 开源引入 | https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py | mindspeed/moe/utils.py | https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py | megatron moe源码地址 |
+| 开源引入 | https://github.com/pytorch/pytorch/pull/40762 | mindspeed/moe/utils.py | https://github.com/pytorch/pytorch/pull/40762 | alltoall实现源码 |
+| 开源引入 | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py | mindspeed/moe/utils.py | https://arxiv.org/pdf/2006.16668.pdf | einsum论文地址 |
+| 开源引入 | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py | mindspeed/moe/experts.py | https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py | deepspeed moe源码地址 |
+| 开源引入 | https://github.com/HazyResearch/flash-attention | docs/features/flash-attention.md | https://arxiv.org/pdf/2205.14135 | flash-attention说明文档 |
+| 开源引入 | https://github.com/nvidia/megatron-lm | docs/features/virtual-pipeline-parallel.md | https://people.eecs.berkeley.edu/~matei/papers/2021/sc_megatron_lm.pdf | virtual-pipeline-parallel说明文档 |
+| 开源引入 | https://github.com/feifeibear/long-context-attention | docs/features/hybrid-context-parallel.md | https://arxiv.org/abs/2405.07719 | hybrid-context-parallel说明文档 |
+| 开源引入 | https://github.com/feifeibear/long-context-attention | docs/features/ring-attention-context-parallel.md | https://arxiv.org/pdf/2310.01889 | ring-attention-context-parallel说明文档 |
+| 开源引入 | https://github.com/ofirpress/attention_with_linear_biases | docs/features/alibi.md | https://arxiv.org/pdf/2108.12409 | alibi说明文档 |
+| 开源引入 | https://github.com/NVIDIA/Megatron-LM | docs/features/sequence-parallel.md | https://arxiv.org/pdf/2205.05198 | sequence-parallel说明文档 |
+| 开源引入 | https://github.com/NVIDIA/Megatron-LM | docs/features/pipeline-parallel.md | https://arxiv.org/pdf/1806.03377 | pipeline-parallel说明文档 |
+| 开源引入 | https://github.com/NVIDIA/Megatron-LM/pull/598 | docs/faq/data_helpers.md | https://github.com/NVIDIA/Megatron-LM/pull/598 | data_helpers说明文档 |
+| 开源引入 | https://pytorch.org/docs/stable/distributed.html | mindspeed/core/parallel_state.py | https://pytorch.org/docs/stable/distributed.html | torch.distributed相关接口注意事项 |
+| 开源引入 | https://github.com/pytorch/pytorch/pull/40762 | mindspeed/moe/utils.py | https://github.com/pytorch/pytorch/pull/40762 | _AllToAll自动反向参考 |
+| 开源引入 | https://github.com/NVIDIA/Megatron-LM | mindspeed/optimizer/distrib_optimizer.py | https://github.com/NVIDIA/Megatron-LM/blob/main/docs/source/distrib_optimizer.md | distributed_optimizer_zero3_init文档字符串参数说明 |
+| 开源引入 | https://github.com/InternLM/InternEvo | mindspeed/docs/features/ring-attention-context-parallel.md | https://arxiv.org/pdf/2406.18485 | ring-attention-context-parallel说明文档 |
+| 开源引入 | https://github.com/sail-sg/zero-bubble-pipeline-parallelism | mindspeed/docs/features/nanopipe-pipeline-parallel.md | https://arxiv.org/abs/2401.10241 | nanopipe-pipeline-parallel说明文档 |
+| 开源引入 | https://github.com/iclr24-3434/AMPipe.git | mindspeed/docs/features/ampipe.md | https://openreview.net/pdf?id=yLgr02IsXY | ampipe说明文档 |
+| 开源引入 | https://gitee.com/ascend/pytorch | mindspeed/docs/features/adaptive-recompute.md | https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha001/apiref/envref/envref_07_0053.html | 环境变量`PYTORCH_NPU_ALLOC_CONF`说明文档 |
+| 开源引入 | https://github.com/deepseek-ai/DeepSeek-MoE | mindspeed/docs/features/shared-experts.md | https://arxiv.org/pdf/2401.06066 | 共享专家说明文档 |
+| 开源引入 | https://gitee.com/ascend/MindSpeed | mindspeed/setup.py | https://gitee.com/ascend/MindSpeed | MindSpeed源码地址 |
+| 开源引入 | https://gitee.com/ascend/MindSpeed/release | mindspeed/setup.py | https://gitee.com/ascend/MindSpeed/release | MindSpeed源码地址 |
+| 开源引入 | https://packaging.python.org/en/latest/single_source_version.html | mindspeed/setup.py | https://packaging.python.org/en/latest/single_source_version.html | MindSpeed版本管理 |
+| 开源引入 | https://github.com/NVIDIA/TransformerEngine/pull/719 | mindspeed/core/data_parallel/distributed_data_parallel.py | https://github.com/NVIDIA/TransformerEngine/pull/719 | use_distributed_optimizer实现源码 |
+
+## 公开接口声明
+
+-MindSpeed已更新其接口策略,现在除了对原生megatron在昇腾设备的无缝支持,还新增了针对融合算子的公开接口。用户在使用时,可以直接调用这些新增的融合算子接口,以充分利用MindSpeed在特定计算任务上的优化能力。
+#### 判断函数是否为公开接口:
+如果一个函数被定义在__all__中,并且在MindSpeed/tree/{分支}/docs 中进行了对外接口的文档记录,则该接口为公开接口,可以依赖其作为公共函数。该对外接口的具体使用方法以及场景请参照docs中的接口使用手册说明。如果需要依赖一个在文档中未记录的函数,请在MindSpeed主页开启Issue向我们确认该函数是否为公开接口、是否是因意外暴露、或者可能在未来被移除。
+
+## 通信安全加固
+
+[通信安全加固说明](https://gitee.com/ascend/pytorch/blob/master/SECURITYNOTE.md#%E9%80%9A%E4%BF%A1%E5%AE%89%E5%85%A8%E5%8A%A0%E5%9B%BA
+)
+
+## 通信矩阵
+[通信矩阵说明](https://gitee.com/ascend/pytorch/blob/master/SECURITYNOTE.md#%E9%80%9A%E4%BF%A1%E7%9F%A9%E9%98%B5%E4%BF%A1%E6%81%AF)
+
+### 特殊场景
+| 场景 | 使用方法 | 端口 | 可能的风险 |
+|-------------------------------------| ------------------------------------------------ | ---------- | ---------- |
+| 用户下载并使用HuggingFace的开源数据集 | 调用`load_dataset`函数,并填写目标开源数据集路径 | 随机端口 | 数据集可能包含敏感或不合法内容,导致合规问题。数据集中可能存在质量问题,如标签错误或数据偏差,影响数据预处理。|
+| 使用`from_pretrained`信任特定代码,使用相关模型的实现 | 调用`from_pretrained`函数,设置`trust_remote_code=True` | 随机端口 |如果 trust_remote_code=True,下载的代码可能包含恶意逻辑或后门,威胁系统安全。但同时已设置local_files_only=True,程序仅会运行本地的文件来规避风险。|
+| 调用auto_tuning进行训练任务时,新增端口 | torchrun拉起训练端口 auto_tuning通过此端口指定MindSpeed拉起特定配置采集Profiling信息 | [1024, 65535]内 |业务需要,无风险 |
+
+
+## 附录
+
+### A-文件(夹)各场景权限管控推荐最大值
+
+| 类型 | linux权限参考最大值 |
+| -------------- | --------------- |
+| 用户主目录 | 750(rwxr-x---) |
+| 程序文件(含脚本文件、库文件等) | 550(r-xr-x---) |
+| 程序文件目录 | 550(r-xr-x---) |
+| 配置文件 | 640(rw-r-----) |
+| 配置文件目录 | 750(rwxr-x---) |
+| 日志文件(记录完毕或者已经归档) | 440(r--r-----) |
+| 日志文件(正在记录) | 640(rw-r-----) |
+| 日志文件目录 | 750(rwxr-x---) |
+| Debug文件 | 640(rw-r-----) |
+| Debug文件目录 | 750(rwxr-x---) |
+| 临时文件目录 | 750(rwxr-x---) |
+| 维护升级文件目录 | 770(rwxrwx---) |
+| 业务数据文件 | 640(rw-r-----) |
+| 业务数据文件目录 | 750(rwxr-x---) |
+| 密钥组件、私钥、证书、密文文件目录 | 700(rwx—----) |
+| 密钥组件、私钥、证书、加密密文 | 600(rw-------) |
+| 加解密接口、加解密脚本 | 500(r-x------) |
diff --git a/model/train/yoco_moe/Third_Party_Open_Source_Software_Notice b/model/train/yoco_moe/Third_Party_Open_Source_Software_Notice
new file mode 100644
index 0000000000000000000000000000000000000000..c5d32948cfa10bd71bc6829e896bac7be97b221d
--- /dev/null
+++ b/model/train/yoco_moe/Third_Party_Open_Source_Software_Notice
@@ -0,0 +1,141 @@
+OPEN SOURCE SOFTWARE NOTICE
+
+Please note we provide an open source software notice along with this product and/or this product firmware (in the following just “this product”). The open source software licenses are granted by the respective right holders. And the open source licenses prevail all other license information with regard to the respective open source software contained in the product, including but not limited to End User Software Licensing Agreement. This notice is provided on behalf of Huawei Technologies Co. Ltd. and any of its local subsidiaries which may have provided this product to you in your local country.
+
+Warranty Disclaimer
+THE OPEN SOURCE SOFTWARE IN THIS PRODUCT IS DISTRIBUTED IN THE HOPE THAT IT WILL BE USEFUL, BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. SEE THE APPLICABLE LICENSES FOR MORE DETAILS.
+
+Copyright Notice and License Texts
+Software: pytorch v2.1.0
+Copyright notice:
+Copyright (c) 2016- Facebook, Inc
+Copyright (c) 2014- Facebook, Inc
+Copyright (c) 2011-2014 Idiap Research Institute
+Copyright (c) 2012-2014 Deepmind Technologies
+Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+Copyright (c) 2011-2013 NYU
+Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+Copyright (c) 2006 Idiap Research Institute
+Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+Copyright (c) 2016-present, Facebook Inc.
+Copyright (c) 2016 Facebook Inc.
+Copyright (c) 2015 Google Inc.
+Copyright (c) 2015 Yangqing Jia
+Copyright 2019-2020 Kakao Brain
+Copyright (c) 2022 Cruise LLC.
+Copyright (c) 2013, 2014, 2015, the respective contributors
+Copyright (c) 2015, 2016 the respective contributors
+Copyright (c) 2014, The Regents of the University of California (Regents)
+Copyright (c) 2014, the respective contributors
+Copyright (c) 2018, Steven Moshier
+Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers
+Copyright (c) 1997-2011 by Secret Labs AB
+Copyright (c) 1995-2011 by Fredrik Lundh
+Copyright (c) 2010-2022 by Alex Clark and contributors
+Copyright (c) 2006 The Android Open Source Project
+Copyright (c) Facebook, Inc. and its affiliates
+Copyright (c) Meta Platforms, Inc. and affiliates
+Copyright 2004-present Facebook
+Copyright (c) 2017 by Contributors
+Copyright (c) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura
+Copyright (c) 2022 Apple Inc.
+Copyright (c) 2023 Apple Inc.
+Copyright 2005 Robert Kern (robert.kern@gmail.com)
+copyright 2019 The TensorFlow Authors
+Copyright (c) 2018 MathInf GmbH, Thomas Viehmann
+Copyright (c) 2014 Indiana University (c)
+Copyright John Maddock 2006
+Copyright (c) 2012 Massachusetts Institute of Technology
+Copyright (c) 2012 Giovanni Garberoglio Interdisciplinary Laboratory for Computational Science (LISC) Fondazione Bruno Kessler and University of Trento
+Copyright (c) 2018 Marat Dukhan
+Copyright (c) 2017-2018 Facebook Inc.
+Copyright (c) 2017 Georgia Institute of Technology
+Copyright 2015 Google Inc.
+Copyright (c) 2011-2021, NVIDIA CORPORATION.
+Copyright (c) 2022, Tri Dao
+Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES.
+Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES.
+Copyright (c) 2017 The Android Open Source Project
+Copyright (c) 2016-present, Facebook, Inc.
+Copyright (c) 2005-2020 Rich Felker
+Copyright Malte Skarupke 2017
+Copyright 2008 Google Inc.
+Copyright (c) 2011 - 2012 Andrzej Krzemienski
+Copyright (c) 2001-2019 Free Software Foundation, Inc.
+Copyright (c) 1994 Hewlett-Packard Company
+Copyright (c) 1996-1998 Silicon Graphics Computer Systems, Inc.
+Copyright (c) Bjorn Fahller
+Copyright Michael Park, 2015-2017
+Copyright (c) 2017-present, Facebook, Inc.
+Copyright (c) 2018-present, Facebook, Inc.
+Copyright (c) 2008-2015 The Khronos Group Inc.
+Copyright 2016 Facebook
+Copyright (c) 2016, NVIDIA CORPORATION
+Copyright (c) 2008 - 2012 The Khronos Group Inc.
+Copyright (c) 2008-2013 The Khronos Group Inc.
+Copyright (c) 2008-2012 The Khronos Group Inc.
+Copyright (c) 2016-2017, ARM Limited and Contributors
+Copyright (c) 2014-2015 The Khronos Group Inc.
+Copyright (c) 2015-2017 The Khronos Group Inc.
+Copyright (c) Facebook Inc. and Microsoft Corporation
+Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+Copyright (c) 2014-2017, the respective contributors
+Copyright (c) 2017 Microsoft
+Copyright 2015 The Gemmlowp Authors
+Copyright (c) 2011-2019 Stephan Brumme
+Copyright 2006, Google Inc.
+Copyright (c) Meta Platforms, Inc. and its affiliates
+Copyright (c) 2008 - 2009 NVIDIA Corporation
+Copyright (c) 2007-2009 Scientific Computing and Imaging Institute, University of Utah
+Copyright (c) 2006, Laurent Montel, montel@kde.org
+Copyright 2013 Conrad Steenberg conrad.steenberg@gmail.com
+copyright 2022, PyTorch
+copyright 2023, PyTorch
+Copyright (c) 2005-2022 NVIDIA Corporation Built
+copyright PyTorch Contributors
+Copyright (c) 2018 Alex Rogozhnikov
+Copyright (c) 2016 Microsoft
+Copyright (c) 2014, 2015, The Regents of the University of California (Regents)
+Copyright (c) 2014, 2015, the respective contributors
+Copyright (c) 2005-2017, NumPy Developers (c) Parameter containing Float
+Copyright 2005, Google Inc.
+Copyright 2019 Kakao Brain
+Copyright 2013-2014 RAD Game
+Copyright 2010-2014 Rich Geldreich and Tenacious Software LLC
+Copyright 2016 Martin Raiber
+Copyright (c) 2003-2017 Josef Weidendorfer
+Copyright (c) 2000-2017 Julian Seward
+Copyright (c) Edward Z. Yang ezyang@mit.edu
+Copyright (c) 2005-2010 ActiveState Software Inc.
+Copyright (c) 2013 Eddy Petrisor
+Copyright (c) 2010 ActiveState Software Inc.
+Copyright (c) 2001-2014 Python Software Foundation
+Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software Foundation
+Copyright Python Software Foundation
+Copyright 2022 Cruise LLC
+Copyright (c) 2014 Matthew Rocklin
+Copyright (c) 2015 Melissa E. O'Neill
+Copyright (c) 2019 NumPy Developers
+Copyright (c) 2015-2016 Advanced Micro Devices, Inc.
+Copyright 2013 Mark Dickinson
+
+License: BSD 3-Clause License
+Copyright (c) , ,
+All rights reserved.
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+Written Offer
+This product contains software whose rights holders license it on the terms of the GNU General Public License, version 2 (GPLv2) and/or other open source software licenses. We will provide you and any third party with the source code of the software licensed under an open source software license if you send us a written request by mail or email to the following addresses:
+foss@huawei.com
+detailing the name of the product and the firmware version for which you need the source code and indicating how we can contact you.
+
+Please note you need to make a payment before you obtain the complete Corresponding Source Code from us. For how much you will pay and how we will deliver the complete Corresponding Source Code to you, we will further discuss it by mail or email.
+This offer is valid to anyone in receipt of this information.
+
+THIS OFFER IS VALID FOR THREE YEARS FROM THE MOMENT WE DISTRIBUTED THE PRODUCT OR FIRMWARE.
\ No newline at end of file
diff --git a/model/train/yoco_moe/ci/access_control_test.py b/model/train/yoco_moe/ci/access_control_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..58a90dfc63f804a17869cc6448ae0c4ff401b854
--- /dev/null
+++ b/model/train/yoco_moe/ci/access_control_test.py
@@ -0,0 +1,43 @@
+import os
+import stat
+import sys
+import unittest
+from pathlib import Path
+import xmlrunner
+
+
+# =============================
+# ST test, run with shell
+# =============================
+def success_check(res):
+ if res != 0:
+ sys.exit(1)
+
+
+def success_check_ut(res):
+ if len(res.failures) + len(res.errors) != 0:
+ sys.exit(1)
+
+
+class ST_Test:
+ def __init__(self):
+ self.shell_file_list = []
+
+
+ def run_shell(self):
+ for shell_file in self.shell_file_list:
+ success_check(os.system("sh {}".format(shell_file)))
+
+# ===============================================
+# UT test, run with pytest, waiting for more ...
+# ===============================================
+
+
+if __name__ == "__main__":
+ st_test = ST_Test()
+ st_test.run_shell()
+ test_loader = unittest.TestLoader()
+ discover = test_loader.discover(start_dir="./", pattern="test*.py")
+
+ runner = unittest.TextTestRunner()
+ success_check_ut(runner.run(discover))
diff --git a/model/train/yoco_moe/docs/LOGO.png b/model/train/yoco_moe/docs/LOGO.png
new file mode 100644
index 0000000000000000000000000000000000000000..46b994b2ada90fd3f421120a3c6edf58ea0bf56b
Binary files /dev/null and b/model/train/yoco_moe/docs/LOGO.png differ
diff --git a/model/train/yoco_moe/docs/RELEASENOTE.md b/model/train/yoco_moe/docs/RELEASENOTE.md
new file mode 100644
index 0000000000000000000000000000000000000000..ae88b447289661d5a0cce921e761056dde1c3fed
--- /dev/null
+++ b/model/train/yoco_moe/docs/RELEASENOTE.md
@@ -0,0 +1,54 @@
+# MindSpeed 版本说明书
+- [MindSpeed 1.0](#FrameworkPTAdapter-5-0-RC1md)
+ - [用户须知](#用户须知md)
+ - [新增特性](#新增特性md)
+ - [特性修改](#特性修改md)
+ - [已修复问题](#已修复问题md)
+ - [已知问题](#已知问题md)
+ - [兼容性](#兼容性md)
+
+
+## MindSpeed 1.0
+
+### 用户须知
+
+本框架基于NVIDIA主导的开源Megatron进行修改,采用插件化适配方式,延续原生的Megatron特性,使用NPU进行大模型加速训练;代码重用性好,支持现有的网络只修改设备类型或数据类型,即可迁移到NPU上使用。使能客户大模型业务快速迁移至昇腾设备,并且支持昇腾专有算法。
+
+### 新增特性
+
+**表 1** MindSpeed支持的版本特性列表
+
+| 一级特性 | 二级特性 | 说明 |
+| -------------- | --------------- | --------------- |
+| Megatron原生特性 | 数据并行 | 支持数据并行训练策略 |
+| | 张量并行 | 支持张量并行训练策略 |
+| | 流水并行 | 支持流水并行训练策略 |
+| | 张量并行 | 支持张量并行训练策略 |
+| | 虚拟流水并行 | 支持虚拟流水并行训练策略 |
+| | 序列并行 | 支持序列并行训练策略 |
+| | 重计算 | 支持选择性重计算和完全重计算策略 |
+| | 分布式优化器 | 支持分布式优化器策略,将优化器状态拆分到所有DP组间 |
+| | 异步DDP | 支持异步DDP,在进行梯度更新时,将数据并行组中的通信和计算并行执行 |
+| 昇腾专有算法 | TP 重计算通信优化 | 重计算通信算子消除,优化重计算层划分,实现大模型训练通信性能提升 |
+| | 自适应选择重计算 | 支持通过自动调整训练内存大小来自动选择重新计算策略 |
+| | 计算通信并行优化 | 通过将计算和通信任务分别拆分成更细粒度的子任务来实现相互的流水掩盖 |
+| 昇腾自定义算子 | npu_dropout_add_layer_norm | 支持自定义算子npu_dropout_add_layer_norm调用 |
+
+### 特性修改
+
+不涉及
+
+### 已修复问题
+
+不涉及
+
+### 已知问题
+
+| 已知问题 | 问题描述 |
+| -------------- | --------------- |
+
+### 兼容性
+
+A800-9010:CentOS 7.6/Ubuntu 18.04, 2.04/BC-Linux 7.6/Debian 9.9/Debian 10/OpenEuler 20.03 LTS
+
+A800-9000:CentOS 7.6/Ubuntu 18.04, 2.04/Euler 2.8, 2.10/Kylin v10/BC-Linux 7.6/OpenEuler 20.03 LTS/UOS 20 1020e
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/faq/data_helpers.md b/model/train/yoco_moe/docs/faq/data_helpers.md
new file mode 100644
index 0000000000000000000000000000000000000000..8312c1861a594f640185420e07e056e220b80ec4
--- /dev/null
+++ b/model/train/yoco_moe/docs/faq/data_helpers.md
@@ -0,0 +1,62 @@
+# Data helpers overflow bug
+## 问题现象
+在增大 gbs、iteration 等理论上不影响模型内存的参数后,出现OOM现象,或者在模型预处理数据集的阶段报如下错误:
+```shell
+Traceback (most recent call last):
+ File "pretrain_gpt.py", line 121, in
+ args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}
+ File "/home/ma-user/modelarts/user-job-dir/GPT-3-kernel_ID2728_for_PyTorch_zgcl/megatron/training.py", line 150, in pretrain
+ process_non_loss_data_func)
+ File "/home/ma-user/modelarts/user-job-dir/GPT-3-kernel_ID2728_for_PyTorch_zgcl/megatron/training.py", line 689, in train
+ opt_param_scheduler)
+ File "/home/ma-user/modelarts/user-job-dir/GPT-3-kernel_ID2728_for_PyTorch_zgcl/megatron/training.py", line 417, in train_step
+ optimizer, fwd_bwd_timers, forward_only=False)
+ File "/home/ma-user/modelarts/user-job-dir/GPT-3-kernel_ID2728_for_PyTorch_zgcl/megatron/schedules.py", line 654, in forward_backward_pipelining_without_interleaving
+ timers, collect_non_loss_data)
+ File "/home/ma-user/modelarts/user-job-dir/GPT-3-kernel_ID2728_for_PyTorch_zgcl/megatron/schedules.py", line 118, in forward_step
+ output_tensor, loss_func = forward_step_func(data_iterator, model)
+ File "pretrain_gpt.py", line 84, in forward_step
+ data_iterator)
+ File "pretrain_gpt.py", line 45, in get_batch
+ data = next(data_iterator)
+ File "/home/ma-user/anaconda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 530, in __next__
+ data = self._next_data()
+ File "/home/ma-user/anaconda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 570, in _next_data
+ data = self._dataset_fetcher.fetch(index) # may raise StopIteration
+ File "/home/ma-user/anaconda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
+ return self.collate_fn(data)
+ File "/home/ma-user/anaconda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 157, in default_collate
+ return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
+ File "/home/ma-user/anaconda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 157, in
+ return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
+ File "/home/ma-user/anaconda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 146, in default_collate
+ return default_collate([torch.as_tensor(b) for b in batch])
+ File "/home/ma-user/anaconda/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 138, in default_collate
+ return torch.stack(batch, 0, out=out)
+RuntimeError: stack expects each tensor to be equal size, but got [8193] at entry 0 and [8246] at entry 1
+```
+
+## 问题根因
+在 `megatron/core/datasets/helpers.cpp` 文件里的 `build_sample_idx()` 函数中创建了 `sample_idx` 的 int32 数组去记录每个 sample 的 index,
+而每个 sample 的 index 又是以 `doc_idx_index` 这个 int64 的变量去计算,在 `sample_idx[2 * sample_index] = doc_idx_index;` 这个赋值操作中存在溢出的可能。
+在数据集中的句子较短,而要求训练的步数 * Global Batch Size * Sequence Length 较大的情况下就会出现 `doc_idx_index` 超过 int32 的表达范围而导致最终的 index 溢出。
+
+## 解决方案
+
+#### 规避方案
+1. 减小模型训练步数
+
+
+#### 推荐方案
+1. 将相关变量修改为 int64 数据类型,具体可查看:[PR](https://github.com/NVIDIA/Megatron-LM/pull/598)
+
+ > 可以在 Megatron-LM 目录下,运行`mindspeed -P`命令,自动完成修改。
+ >```shell
+ > mindspeed -P
+ >```
+2. 删除 `megatron/core/datasets/` 下面的 `helpers.cpython-xx-xxx-linux-gnu.so` 文件。
+3. 删除已生成的数据集缓存文件夹,例如 `enwiki/my-t5_text_sentence/cache/GPTDataset_indices`。
+
+
+## 备注
+此问题为 Megatron-LM 原生问题,CPP 代码难以通过 monkey patch 的方式进行修改。已多次提交修复 PR,但似乎 Megatron-LM 较为封闭,无人管理且不接受来自社区的代码提交。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/faq/megatron070_grad_norm_nan.md b/model/train/yoco_moe/docs/faq/megatron070_grad_norm_nan.md
new file mode 100644
index 0000000000000000000000000000000000000000..cf4f2f97c4995dc9399605212a962b5b086aad89
--- /dev/null
+++ b/model/train/yoco_moe/docs/faq/megatron070_grad_norm_nan.md
@@ -0,0 +1,51 @@
+# megatron0.7.0版本长稳测试出现grad norm为nan
+## 问题现象
+在megatron0.7.0版本中,采用mindspeed自定义`--tokenizer-type PretrainedFromHF`, 长稳测试一定步数后发现loss抖动异常最终出现grad norm为nan的问题,报错示例如下:
+```
+2024-09-18 11:14:247 iteration 427/ 5000 consumed samples: 6832 elapsed time per iteration (
+ms): 209.8 | Learning rate: 1.229919E-06 | global batch size: 16 | Lm loss: 8.567080E+00 | loss scale: 1.0 | gr
+ad norm: 35.518 | number of skipped iterations: О | number of nan iterations: 0
+[2024-09-18 11:14:25] iteration 428/ 5000] consumed samples: 6848 elapsed time per iteration (
+ms): 210.5 | Learning rate: 1.229826E-06 | global batch size: _ 16 | lm loss: 7.180392E+00 | loss scale: 1.0 | gr
+ad norm: 36.838 ] number of skipped iterations: О | number of nan iterations:
+Traceback (most recent call last):
+File "pretrain_gpt.py”, line 247, in
+pretrain(
+File "/home/Megatron-LM/megatron/training/training.py”, Line 274, in pretrain
+iteration, num floating point operations so far = train(
+File "/home/Megatron-LM/megatron/training/training.py”, Line 1027, in train
+train step(forward step func,
+File "/home/Megatron-LM/megatron/training/training.py”, Line 550, in train_step
+losses reduced = forward backward func(
+File "/home/Megatron-LM/megatron/core/pipeline parallel/schedules.py”, line 1400, in forward backward
+pipelining without interleaving
+config.finalize model grads func(
+File "/home/Megatron-LM/megatron/core/distributed/finalize model_grads.py”, Line 113, in finalize mode
+l grads
+model chunk.finish grad sync()
+File "/home/Megatron-LM/megatron/core/distributed/distributed data parallel.py”, Line 248, in finish_g
+rad sync
+buffer.finish grad sync()
+File "/home/Megatron-LM/megatron/core/distributed/param and_grad buffer.py”, Line 513, in finish_grad
+sync
+bucket.finish grad sync()
+File "/home/Megatron-LM/megatron/core/distributed/param and_grad buffer.py”, Line 151, in finish_grad
+sync
+self.start grad sync()
+File “/home/Megatron-LM/megatron/core/distributed/param and grad buffer.py”, Line 114, in start_grad_s
+ync
+assert not norm.isnan( ), (
+AssertionError: Rank 13: found NaN in local grad norm in backward pass before data-parallel communication collectie
+ve. Device: 5, node: node-15-11
+```
+
+## 问题根因
+
+1. 问题场景使用的数据集生成时,增加了`--append-eod`参数,这会让每个数据sample末尾增加一个eos结束标志位;
+2. megatron0.7.0对数据集提取过程增加了pad功能(在`class GPTDataset`类中),`PretrainedFromHF`模式下,会将pad标志位与eos标志位配成相同值(`pad_token_id == eos_token_id`)。loss_mask中会去掉pad标志位,但实际去掉的都是eos标志位。
+3. 以上两个原因综合导致了grad norm为nan的问题,这个问题是megatron原生问题,相同配置下实测GPU中也会报错。
+
+
+## 解决方案
+
+在`--tokenizer-type PretrainedFromHF`模式下,不使用`--append-eod`生成数据集
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/faq/torch_extensions.md b/model/train/yoco_moe/docs/faq/torch_extensions.md
new file mode 100644
index 0000000000000000000000000000000000000000..049febfef595e502b7a4994f24854446dbbbc3d0
--- /dev/null
+++ b/model/train/yoco_moe/docs/faq/torch_extensions.md
@@ -0,0 +1,22 @@
+# Torch extensions 卡住
+## 问题现象
+在模型运行时,卡在如下场景,且等待十几分钟无反应。
+```
+Using ~/.cache/torch_extensions/py38_cpu as PyTorch extensions root...
+Using ~/.cache/torch_extensions/py38_cpu as PyTorch extensions root...
+Using ~/.cache/torch_extensions/py38_cpu as PyTorch extensions root...
+Using ~/.cache/torch_extensions/py38_cpu as PyTorch extensions root...
+Using ~/.cache/torch_extensions/py38_cpu as PyTorch extensions root...
+Using ~/.cache/torch_extensions/py38_cpu as PyTorch extensions root...
+Using ~/.cache/torch_extensions/py38_cpu as PyTorch extensions root...
+Using ~/.cache/torch_extensions/py38_cpu as PyTorch extensions root...
+```
+
+## 问题根因
+此问题为 Pytorch extension 编译问题,编译开始前其中一个线程会生成 `.lock` 文件对编译文件夹进行锁定,其他线程会进行等待。
+如果因为其他原因导致编译的线程中途被强制结束,`.lock` 文件不会被清除,导致第二次编译开始时,所有的线程看到存在 `.lock` 文件,就都会开始进行等待。
+
+
+## 解决方案
+
+删除 `~/.cache/torch_extensions/py38_cpu` 文件夹,再重新启动程序。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/Automatic_Parallelism.md b/model/train/yoco_moe/docs/features/Automatic_Parallelism.md
new file mode 100644
index 0000000000000000000000000000000000000000..2cfd65c874e959e0bb6492b663af0e2c81060324
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/Automatic_Parallelism.md
@@ -0,0 +1,157 @@
+## Automatic Parallelism
+
+## 问题分析
+
+当前主流的大模型并行训练方法有PP、TP、DP、SP、CP、Ulyssess Parallel(UP)、VPP、EP等,在内存、计算、通信方面都有不同的优化,直接叠加。大模型端到端训练性能由模型结构、集群规模、并行配置、batch_size等因素共同决定,在调优时需要综合考虑。当前并行配置人工调优需要大量的专家经验、人工分析和实验调优,预计数天~数周,实验成本高。相似模型的最优并行配置也并不相同,仍需花费时间进行优化。随着搜索空间变大,依赖手工调优变的不可行。例如,llama65B模型在4*8的集群规模下,仅考虑PP、TP、DP、SP、VP、mbs六个维度,配置组合有812种,手工调优时间成本太高。因此,需要构建自动并行系统根据模型结构和集群规模给用户自动推荐一个性能较优的并行配置策略。
+
+## 解决方案
+
+针对该问题场景提出多维并行配置自动寻优算法,在给定模型结构、集群配置的条件下,用户仅需要在启动脚本中配置相关参数即可启动多维并行配置自动寻优,在规定时间内找到较优的并行配置推荐给用户。算法原理图如下:
+
+* **内存自适应感知的搜索空间构建**:考虑模型结构和集群信息约束,采用内存灰盒模型排除OOM并行配置,缩小搜索空间;
+* **基于算子不确定性估计的高保序性Cost Model建模方法**:引入低保真数据(单算子调用)作为先验信息,结合算子整网性能数据构建算子执行耗时的不确定性模型,结合通信耗时根据并行策略合成得到端到端性能的概率分布模型。
+* **基于概率匹配的高效搜索算法**:基于Thompson Sampling方法探索并行策略,以高概率探索高价值并行配置,提高探索效率,灵活支持探索早停,提高易用性。
+
+
+
+
+**并行配置的支持情况:**
+
+已支持搜索的并行配置维度:
+
+- [x] PP
+- [x] TP
+- [x] DP
+- [x] CP
+- [x] DeepSpeed-Ulyssess
+- [x] Megatron-SP
+- [x] mbs
+
+正在支持的并行配置维度:
+
+- [ ] MOE
+- [ ] VP
+- [ ] 自适应重计算
+
+## 使用方法
+
+在使用多维自动并行特性时,**需使用python作为脚本启动器,在所有的节点上拉起脚本**,并配置多维自动并行相关的参数。相关参数及其函数如下表所示:
+
+| 参数名 | 参数含义 |
+| ---------------- | ------------------------------------------------- |
+| --auto-parallel | 多维自动并行特性总开关 |
+| --nodes | 集群中节点的个数 |
+| --nproc-per-node | 每个节点中计算设备的个数 |
+| --master-addr | 集群中主节点的IP地址 |
+| --master-port | 用于通信的端口号,各节点需要配置相同的端口号 |
+| --node-rank | 集群中节点的rank,主节点为0,其他节点为1,2,······ |
+
+注:算法最长搜索时间为8小时,支持灵活提前退出,无需人工干预。
+
+下面是基于llama7B模型的配置示例:
+
+```shell
+#!/bin/bash
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export NPU_ASD_ENABLE=0
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+MASTER_ADDR=localhost
+MASTER_PORT=6001
+GPUS_PER_NODE=8
+NNODES=1
+NODE_RANK=0
+
+LOAD_CHECKPOINT_PATH=./ckpt
+SAVE_CHECKPOINT_PATH=./ckpt
+DATA_PATH={your dataset path}
+TOKENIZER_MODEL={your tokenizer model path}
+TP=1
+PP=8
+
+DISTRIBUTED_ARGS="
+ --nproc_per_node $GPUS_PER_NODE \
+ --nnodes $NNODES \
+ --node_rank $NODE_RANK \
+ --master_addr $MASTER_ADDR \
+ --master_port $MASTER_PORT
+"
+
+GPT_ARGS="
+ --tensor-model-parallel-size ${TP} \
+ --pipeline-model-parallel-size ${PP} \
+ --sequence-parallel \
+ --num-layers 32 \
+ --hidden-size 4096 \
+ --ffn-hidden-size 11008 \
+ --num-attention-heads 32 \
+ --tokenizer-type Llama2Tokenizer \
+ --tokenizer-model ${TOKENIZER_MODEL} \
+ --seq-length 2048 \
+ --max-position-embeddings 2048 \
+ --micro-batch-size 4 \
+ --global-batch-size 256 \
+ --make-vocab-size-divisible-by 1 \
+ --lr 1.0e-6 \
+ --train-iters 5000 \
+ --lr-decay-style cosine \
+ --untie-embeddings-and-output-weights \
+ --disable-bias-linear \
+ --attention-dropout 0.0 \
+ --init-method-std 0.01 \
+ --hidden-dropout 0.0 \
+ --position-embedding-type rope \
+ --normalization RMSNorm \
+ --use-fused-rmsnorm \
+ --swiglu \
+ --use-flash-attn \
+ --no-masked-softmax-fusion \
+ --attention-softmax-in-fp32 \
+ --min-lr 1.0e-7 \
+ --weight-decay 1e-1 \
+ --lr-warmup-fraction 0.01 \
+ --clip-grad 1.0 \
+ --adam-beta1 0.9 \
+ --initial-loss-scale 65536 \
+ --adam-beta2 0.95 \
+ --no-gradient-accumulation-fusion \
+ --load ${LOAD_CHECKPOINT_PATH} \
+ --no-load-optim \
+ --no-load-rng \
+ --fp16
+"
+
+DATA_ARGS="
+ --data-path $DATA_PATH \
+ --split 100,0,0
+"
+
+OUTPUT_ARGS="
+ --log-interval 1 \
+ --save-interval 10000 \
+ --eval-interval 1000 \
+ --eval-iters 0 \
+"
+
+SEARCH_ARGS="
+ --auto-parallel \
+ --nnodes $NNODES \
+ --nproc-per-node $GPUS_PER_NODE \
+ --master-addr $MASTER_ADDR \
+ --master-port $MASTER_PORT \
+ --node-rank $NODE_RANK \
+"
+
+python pretrain_gpt.py \
+ $GPT_ARGS \
+ $DATA_ARGS \
+ $OUTPUT_ARGS \
+ $SEARCH_ARGS \
+ --distributed-backend nccl \
+ | tee logs/search_llama_7b.txt
+```
+
+## 使用效果
+
+
+
diff --git a/model/train/yoco_moe/docs/features/activation-function-recompute.md b/model/train/yoco_moe/docs/features/activation-function-recompute.md
new file mode 100644
index 0000000000000000000000000000000000000000..976391342a477323e4819f015da2d985d120f623
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/activation-function-recompute.md
@@ -0,0 +1,94 @@
+# 激活函数重计算
+
+## 问题分析
+
+现有的大模型训练框架中,重计算和反向计算是绑定在一起调度的,这严重限制了重计算的灵活性。在某些场景下,会限制重计算在模型性能上的优化。
+
+比如在模型中存在某个流程:
+
+前向:gelu激活函数模块->后续模块A。
+
+反向:后续模块A的反向(需要gelu输出的激活值)->gelu反向(与重计算绑定)。
+
+gelu激活函数会产生大量的数据,但本身计算量很小。此时进行激活函数的重计算可以在性能劣化极少的代价下,减少内存占用。
+但在现有重计算框架下,如果对gelu激活函数模块做重计算,并不能节省gelu函数的输出。这是因为在反向时,模块A所需要的gelu输出的激活值,会早于gelu激活函数模块的重计算流程,所以前向必须保留激活函数的输出,导致激活函数的输出并不能节省下来。
+
+
+## 解决方案
+
+本特性重新实现了一套重计算框架,可以将重计算灵活地插入到反向计算之前的任意位置。
+
+反向(新框架):
+
+gelu函数重计算->后续模块A的反向。
+
+此时,gelu函数的输出已经早于模块A的反向,在前向时就无须保留gelu函数的输出值。
+
+## 解决思路
+
+通过设计一种传入模块函数进行重计算的机制,在合适的时机,丢弃重计算模块输出的物理存储,保留逻辑视图。在反向时,在恰当时机,利用register_hook插入重计算流程,并利用传入的函数重新进行计算,得到结果。
+
+例如,gelu在mlp中的位置下图所示。反向计算需要前向产生的a,b,c, d。其中b, c的shape为(batch, seq , 4 * hidden_szie),gelu为激活函数,其计算较少,故可将tensor c释放掉,反向在 4h->h 反向前重新计算。
+
+
+
+在前向4h->h计算完毕后,将c释放,保留逻辑视图。在4h->h grad前,需要将c计算回来。这里使用给d打tensor_hook的方式来进行重计算的插入,如下图所示:
+
+
+
+## 使用场景
+
+主要用于训练场景,用户内存不足或要节省内存时。
+
+## 使用方法
+
+脚本中添加:`--recompute-activation-function` 可开启激活函数重计算。
+
+添加:`--recompute-activation-function-num-layers ${num}` 可指定激活函数重计算的层数。
+
+激活函数重计算可以与全重计算同时开启:
+
+1.同时开启时,仅支持 `--recompute-method 为 block`
+
+2.同时开启时,会按照指定的全重计算和激活函数重计算的层数做各自类型的重计算,即不会有一层既做全重计算又做激活函数重计算。
+
+(注意点:执行优先级是先计算全重计算层,后计算激活函数重计算层。在流水线并行未开启的情况下,全重计算层数和激活函数重计算层数之和应该等于总层数。)
+
+3.暂不兼容自适应重计算特性。
+
+## 使用效果
+激活函数重计算在llama2-7B场景下,根据模型配置不同,收益也会发生改变。
+在不同参数场景下,激活函数重计算收益表现如下:
+| 模型参数 | 设备数 | 内存收益 |
+|-----------------------------------------------------------------------------|----------|-------------|
+| seq-length=12288、micro-batch-size=2、TP=4、DP=2 | 8卡(单机) | 8.05G |
+| seq-length=8096、micro-batch-size=2、TP=4、DP=2 | 8卡(单机) | 5.31G |
+| seq-length=16384、micro-batch-size=2、TP=4、DP=2 | 8卡(单机) | 12.49G |
+| seq-length=12288、micro-batch-size=1、TP=4、DP=2 | 8卡(单机) | 4.04G |
+| seq-length=12288、micro-batch-size=2、TP=8、DP=1 | 8卡(单机) | 4.02G |
+
+## 扩展使用
+
+本特性引入的 CheckpointWithoutOutput 类可以自定义对任何模块进行重计算,并且在合适的时机进行重计算恢复。
+
+此处提供一个示例,可以灵活使用 CheckpointWithoutOutput 来对自定义的函数进行重计算:
+
+(注意:如要使用register_hook,需要确保张量有梯度)
+```python
+from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput
+
+
+class Custom_module(torch.nn.Module):
+ def __init__(self):
+ ......
+
+ def forward(self, input):
+ self.activation_checkpoint_manager = CheckpointWithoutOutput()
+ function_output = self.activation_checkpoint_manager.checkpoint(self.custom_function, False, function_input1, function_input2, ...)
+ ...(after used output)
+ self.activation_checkpoint_manager.discard_output()
+ if module_output.requires_grad:
+ module_output.register_hook(self.activation_checkpoint_manager.recompute)
+
+ return module_output
+```
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/adaptive-memory.md b/model/train/yoco_moe/docs/features/adaptive-memory.md
new file mode 100644
index 0000000000000000000000000000000000000000..e85267dbeee5a5a9e499bce60895f1a5c8763079
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/adaptive-memory.md
@@ -0,0 +1,71 @@
+# 自适应内存优化
+
+## 问题分析
+
+在大模型训练中,重计算特性可以有效的减少显存使用,但是策略较为固定,无法最大限度使用显存资源。
+
+## 解决方案
+
+为了在最大限度地利用NPU显存的同时,提高模型训练的性能,我们支持通过自动调整训练内存大小来自动选择计算策略。这一特性称为自适应内存优化。
+
+### 解决思路
+
+自适应内存优化设计主要包括:自适应无损swap、 策略生成、策略搜索、SwapManager功能以及内存管理等几部分。
+
+- 策略生成依赖自适应无损swap去生成策略
+
+- 策略搜索依赖SwapManager功能及时将tensor换到CPU,避免OOM导致训练中断。
+
+自适应内存优化策略流程如下图所示:
+
+ 
+
+SwapManager功能需要内存管理适配PTA的NPUPluggableAllocator接口拦截OOM,让SwapManager功能可以介入,流程如下图所示:
+ 
+
+## 使用场景
+
+该特性主要用于训练场景,如果用户发现开启了全重计算功能后, NPU显存剩余较多,此时若想充分利用显存,从而提高训练性能,可以考虑开启该特性。
+
+## 使用方法
+
+在训练脚本中添加`--adaptive-memory-optimization`
+
+注意:
+1. 当前自适应内存优化与全重计算、自适应选择重计算、预取特性swap-attention、 recompute-in-bubble等不兼容。
+2. 目前自适应内存优化已能够管理一部分使用torch.autograd.Function修饰的auto_function类
+
+ - 在调用auto_function的文件中 添加 `from mindspeed.core.memory.adaptive_memory.adaptive_memory_function import adapt_mem_func_wrapper`
+ - 将 `auto_function.apply(*args)` 修改为 `adapt_mem_func_wrapper(auto_function, *args)`
+ - 以mindspeed.moe.pipe_experts中的PipeExpert类的调用为例,在mindspeed.moe.moe_layer文件中添加`from mindspeed.core.memory.adaptive_memory.adaptive_memory_function import adapt_mem_func_wrapper`,将`expert_output = PipeExpert.apply(*args)`修改为`expert_output = adapt_mem_func_wrapper(PipeExpert, *args)`
+
+## 使用效果
+
+这里的gpt-175B是经过裁剪后的
+
+gpt-175B:
+
+| 特性 | 参数 | NPU卡数 | TFLOPs | 收益 |
+|------------|---------------------------------------------------------------------------------------------------------------------------|----------|-------------| -------------|
+| adaptive-memory-optimization | seq-length=8192、mico-batch-size=10、global-batch-size=40、TP=8、PP=1、DP=1、CP=1、NL=8、hidden-size=12288 | 8卡(单机) | 165.90 | - |
+| 全重计算 | seq-length=8192、mico-batch-size=10、global-batch-size=40、TP=8、PP=1、DP=1、CP=1、NL=3、hidden-size=12288、recompute-num-layers=3 | 8卡(单机) | 145.93 | 13.68% |
+
+
+| 特性 | 参数 | NPU卡数 | TFLOPs | 收益 |
+|------------|---------------------------------------------------------------------------------------------------------------------------|----------|--------|--------|
+| adaptive-memory-optimization | seq-length=8192、mico-batch-size=3、global-batch-size=9、TP=2、PP=4、DP=1、CP=1、NL=8、hidden-size=12288 | 8卡(单机) | 76.30 | - |
+| 全重计算 | seq-length=8192、mico-batch-size=3、global-batch-size=9、TP=2、PP=4、DP=1、CP=1、NL=8、hidden-size=12288、recompute-num-layers=1 | 8卡(单机) | 66.50 | 14.17% |
+
+| 特性 | 参数 | NPU卡数 | TFLOPs | 收益 |
+|------------|---------------------------------------------------------------------------------------------------------------------------------|----------|--------|--------|
+| adaptive-memory-optimization | seq-length=8192、mico-batch-size=2、global-batch-size=8、TP=2、PP=4、VPP=2、DP=1、CP=1、NL=8、hidden-size=12288 | 8卡(单机) | 86.10 | - |
+| 全重计算 | seq-length=8192、mico-batch-size=2、global-batch-size=8、TP=2、PP=4、VPP=2、DP=1、CP=1、NL=8、hidden-size=12288、recompute-num-layers=1 | 8卡(单机) | 75.10 | 14.65% |
+
+## 注意事项
+
+1. 由于自适应内存优化与内存碎片优化两个特性都修改了PyTorch内存管理模块,这两个特性都打开会存在冲突,mindspeed进行了assert判断。
+2. 由于自适应内存优化依赖cpu的绑核,因此需要保证运行环境内含有npu-smi以及lspci命令。
+安装命令:yum install pciutils
+
+
+
diff --git a/model/train/yoco_moe/docs/features/adaptive-recompute.md b/model/train/yoco_moe/docs/features/adaptive-recompute.md
new file mode 100644
index 0000000000000000000000000000000000000000..a2f44bda42c8b9f193c087a612b9c6e08c17d3b8
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/adaptive-recompute.md
@@ -0,0 +1,46 @@
+# 自适应选择重计算
+
+## 问题分析
+
+重计算特性可以有效的减少显存使用,但是策略较为固定,无法最大限度使用显存资源。
+
+## 解决方案
+
+为了在最大限度地利用计算设备显存的同时,提高模型训练的性能,我们支持通过自动调整训练内存大小来自动选择重新计算策略。这一特性称为自适应选择重计算。
+
+### 解决思路
+
+自适应选择重计算设计主要包括重计算策略搜索、SwapManager 功能和内存管理三大部分。
+
+其中重计算策略搜索依赖 SwapManager 功能及时将 tensor 换到 CPU,避免 OOM 导致训练中断。
+
+自动选择重计算策略流程如下图所示:
+
+ 
+
+SwapManager 能需要内存管理适配 PTA 的 NPUPluggableAllocator 接口拦截 OOM,让 SwapManager 功能可以介入,流程如下图所示:
+ 
+
+## 使用场景
+
+该特性主要用于训练场景,如果用户发现开启了全重计算功能后, NPU显存剩余较多,此时若想充分利用显存,从而提高训练性能,可以考虑开启该特性。
+
+## 使用方法
+
+1. 在训练脚本中添加`--adaptive-recompute-device-swap`。
+2. (可选)支持手动调整训练内存大小来自动选择重计算策略,请使用`--adaptive-recompute-device-size`进行设置来指定自适应选择重计算策略的训练内存大小(单位:MB)。内存>0为有效内存,最大内存限度为device最大内存。在该范围内自适应重计算才可以进行最优策略搜寻,不在有效内存范围内会使用读取到的device最大内存信息作为默认值。需要注意的是内存设定较小时,性能会与全重计算一致。该方式如果发生OOM,您需要重新选择一个新的内存值来重启模型训练。您也可以通过二分法的方式获得最优解,对该特性不熟悉请勿使用此选项。
+3. (可选)支持设置停止profiling的训练step,请使用`--adaptive-recompute-profiling-step`进行设置。该参数需要设置为>0的整数。默认在第10步停止profiling。若该值<=0,则采用默认值10,推荐设置该值>5。当step<5或者>总步数的1/10时,会有告警信息,但不影响正常训练,不会对性能和精度有任何影响。
+4. 此特性暂只适用于`--use-legacy-models`。
+
+## 使用效果
+
+相比全重计算,Llama2-7B场景下,性能提升约 16.29%,Llama2-13B 性能提升约12.05%。
+
+## 注意事项
+
+- 自适应选择重计算当前暂只适用于`--use-legacy-models`。
+- 当前自适应选择性重计算与全重计算、选择重计算、重计算独立调度流水线ripipe、激活函数重计算、预取特性swap-attention等特性均不兼容。
+- 由于自适应选择重计算特性修改了PyTorch内存管理模块,打开会存在冲突,mindspeed进行了assert判断。
+- 当使用`--adaptive-recompute-device-swap`时,用户可以通过环境变量`MIN_SWAP_TENSOR_SIZE`来指定允许被换出tensor的最小大小(最小可为1024),如不指定,则默认为1024
+- 当使用`--adaptive-recompute-device-swap`时,用户可以通过环境变量`SWAP_SIZE_MULTIPLE`来指定换出大小与malloc大小的比值(最小可为1),如不指定,则默认为1
+- 自适应重计算通过实现自己的allocator来实现对OOM的拦截,此allocator仍然支持PTA的环境变量`PYTORCH_NPU_ALLOC_CONF`,用户可以参考[此处](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha001/apiref/envref/envref_07_0053.html)来配置该环境变量。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/alibi.md b/model/train/yoco_moe/docs/features/alibi.md
new file mode 100644
index 0000000000000000000000000000000000000000..26356224c86164838ed95b1c28a098019e723689
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/alibi.md
@@ -0,0 +1,33 @@
+# Alibi 位置编码
+
+## 问题分析
+
+当大模型在训练和预测时的输入长度不一致时,模型的泛化能力会下降。若外推能力不佳,大模型在处理长文本或多轮对话时的效果就会受到限制。正弦位置编码的外推能力比较弱,RoPE(Rotary Position Embedding)的外推能力有一定提高但仍然有限。
+
+## 解决方案
+
+支持Alibi位置编码,提高模型外推能力。
+
+### 解决思路:
+
+Alibi算法给attention score添加了一个预设的线性偏置矩阵(如下图所示),使模型能够理解输入之间的相对位置关系。由于位置信息直接作用于attention score上,位置差异性被突出,使模型具有较强的外推能力。
+ 
+
+[原文链接](https://arxiv.org/pdf/2108.12409)
+
+
+## 使用方法
+
+(1)对于不开启`--use-fusion-attn-v2`特性的情况,设置`--position-embedding-type alibi`即可调用该算法。
+
+(2)对于开启`--use-fusion-attn-v2`特性的情况设置,需要设置`--position-embedding-type alibi`和`--alibi-fusion-attn-type 2`(支持0,2,3)。
+0表示生成alibi后传入,1暂不开放, 2和3表示核内生成, 3做pse的时候会做sqrt。
+如果要设置alibi为对角线对称取反,则需设置`alibi_diagonal_opposite`,反之(亦是默认情况,且与2和3时核内生成一致)无需进行设置。
+
+(3)目前alibi位置编码已经支持ring-attention长序列并行,当前只支持mask为causal的场景,以及 `--alibi-fusion-attn-type` 为2,3的压缩模式。暂不支持ulysses长序列并行和混合长序列并行。
+
+(4)开启`--use-fusion-attn-v2`特性和长序列并行时,alibi编码不支持开启dropout。
+
+## 使用效果
+
+模型外推能力提高。
diff --git a/model/train/yoco_moe/docs/features/ampipe.md b/model/train/yoco_moe/docs/features/ampipe.md
new file mode 100644
index 0000000000000000000000000000000000000000..c5d85b23fb466f3dbb671546fc49842fc980c264
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/ampipe.md
@@ -0,0 +1,59 @@
+# Ampipe流水通信隐藏
+
+## 问题分析
+
+MoE模型中引入了alltoall通信算子,用于在ep组中不同rank间交换token。在MoE层前向过程中,专家mlp部分前后各有一个alltoall通信算子,且计算与通信为串行执行,需要减少这部分通信的时间,提升训练性能。
+
+
+## 解决方案
+
+ampipe将transformer模型中从attention到mlp部分的通信和计算的输入切分为多份,每一份数据之间互相独立不存在依赖,使得各个部分的计算和通信可以循环流水并行,同时调整计算和通信的算子执行顺序,实现计算和通信并行达到掩盖通信的目的。
+
+
+
+论文参考:
+https://openreview.net/pdf?id=yLgr02IsXY
+
+## 解决思路
+1. 从attention的输入开始切分,q和attention_mask在seq序列维度进行切分, k, v保持完整输入,可以使得切分attention后再拼接结果等价。
+2. attention之后的dropout、残差、norm归一化以及MLP等计算在seq序列维度上均独立,切分后再拼接结果同样可以等价,所以在中间各个部分不需要拼接,直到所有计算完成后再拼接结果即可。
+3. 切分后重新编排各个切分副本循环流水的顺序,使得计算和通信并行。
+4. 针对主流的megatron的序列并行sequence-parallel以及长序列并行的context-parallel进行适配,可以实现sp开启时mlp部分的all-gather和reduce-scatter通信隐藏。
+
+## 使用场景
+
+在训练MoE模型时,可以开启ampipe特性。
+推荐在`--seq-length`序列长度较长时开启特性,可以获得更好的性能提升。
+
+## 使用方法
+
+1. 在训练脚本中添加`--ampipe-degree N`即可使能ampipe特性,N为切分数。
+2. 推荐开启`--ampipe-tp-sp-comm-overlap`,额外掩盖mlp中tp域内通信以达到最佳性能提升。
+3. 支持同时开启ampipe特性(包含1,2中两个特性开关)以及mlp通信隐藏特性`--use-pipe-experts`,单独或同时设置`--pipe-experts-multi-stream`和`--pipe-experts-multi-data N`来叠加使用“多流水线”和“多副本”的特性。
+
+限制条件:
+1. 需要开启`--moe-model-type deepspeed_moe`以及`--use-flash-attn`的前提下使用特性
+2. 暂不支持`--use-ascend-mc2`、`--overlap-grad-reduce`、`--overlap-param-gather`以及nanopipe `--use-nanopipe`、ripipe `--recompute-in-bubble` `--recompute-in-advance`和自适应选择重计算。
+3. 需要保证设置的`--seq-length`即序列长度可以被`--ampipe-degree`整除,如果需要设置`--sequence-parallel`以及`--context-parallel-size > 1`,需要额外保证设置的`--seq-length`可以被tp和cp整除
+4. 同时开启ampipe特性以及mlp通信隐藏特性时,`--pipe-experts-multi-data N`多副本数量N必须被`--ampipe-degree M`ampipe切分数M整除且N>M,否则`--use-pipe-experts`不生效;同时额外设置`--pipe-experts-multi-stream`时,此限制可以放开至N>=M
+
+## 使用效果
+
+使用该特性可以提升性能。
+
+场景:双机16P, sequence_len = 128k, num_layers = 2, num_experts = 4, recompute_method = block, recompute_granularity = full, recompute_num_layers = 2, hidden_size = 12288, moe_router_topk = 2, ep = 2, tp = 8, dp = 1, cp = 2, pp = 1, sp = True
+
+
+| 对比场景 | ampipe-degree | ampipe-tp-sp-comm-overlap | multi-stream | multi-data | 平均TFLOPs | 提升幅度 |
+|:-----------------------:|:-------------:|:-------------------------:|:------------:|:----------:|:--------:|:-----:|
+| baseline | 1 | 关 | 关 | 1 | 120.56 | / |
+| pipe-experts(baseline2) | 1 | 关 | 开 | 2 | 124.85 | 3.56% |
+| ampipe | 2 | 开 | 关 | 1 | 127.29 | 5.58% |
+| ampipe&pipe-experts | 2 | 开 | 开 | 4 | 126.87 | 5.23% |
+
+
+## 注意事项
+
+- 在开启`--ampipe-degree N`时,若`N`过大,导致输入数据切分过细,会引入多余的 cast 和 add 算子,导致额外的开销,引起性能劣化。 目前仅推荐开启`--ampipe-degree 2`,在开启`--context-parallel-size` > 1的场景下,仅支持设置`--ampipe-degree 2`。
+- 推荐开启`--ampipe-tp-sp-comm-overlap`,尤其在开启`--sequence-parallel`时,可额外掩盖mlp中tp域内通信以达到最佳性能提升。
+- 与部分通信隐藏特性冲突,暂时不支持,参考使用方法中的限制条件。
diff --git a/model/train/yoco_moe/docs/features/async-ddp-param-gather.md b/model/train/yoco_moe/docs/features/async-ddp-param-gather.md
new file mode 100644
index 0000000000000000000000000000000000000000..c99c621c7950c44e8da0ecc57554e191f62578a4
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/async-ddp-param-gather.md
@@ -0,0 +1,41 @@
+# 权重更新通信隐藏
+
+## 问题分析
+
+大模型训练过程中,通常会使用数据并行。在进行梯度更新时,数据并行组中的通信要等反向计算完成后再进行。这样的串行执行顺序会造成计算和通信流存在一定的空闲等待时间,导致执行效率较低。
+
+## 解决方案
+
+通过计算和通信任务并行的方式来实现相互的流水掩盖。
+
+### a. 仅打开 `--use-distributed-optimizer`
+仅打开分布式优化器时(`--use-distributed-optimizer`),运行流程如下图所示,前向和反向计算完成后,会有独立的通信时间,进行梯度的reduce-scatter、计算权重、进行权重的all-gather,获得权重之后再进入下一轮的前向计算。
+ 
+
+### b. 打开 `--use-distributed-optimizer` 和 `--overlap-grad-reduce`
+在打开`--use-distributed-optimizer`的同时打开`--overlap-grad-reduce`,运行流程如下图所示,对梯度的reduce-scatter过程与反向计算过程并行,从而节省了单独的reduce-scatter过程,提高了计算-通信并行效率。
+ 
+
+### c. 打开 `--use-distributed-optimizer` 和 `--overlap-grad-reduce` 和 `--overlap-param-gather`
+在打开`--use-distributed-optimizer`和`--overlap-grad-reduce`的基础上进一步打开`--overlap-param-gather`,运行流程如下图所示,对权重的all-gather过程与下一轮的前向计算并行,从而节省了单独的all-gather过程。
+ 
+
+以上流程对比发现,打开--overlap-param-gather后,通信与计算完全并行,极大提高了计算-通信并行效率,进而提升了模型训练效率。
+
+## 使用场景
+
+在数据并行场景可以开启该特性。
+
+## 使用方法
+
+设置`--overlap-param-gather`即可调用该算法。
+确保同时开启了`--use-distributed-optimizer`和`--overlap-grad-reduce`。
+
+## 使用效果
+
+使用该特性可以提升性能。
+
+## 注意事项
+
+开启该特性后,attention层init的顺序会更正为先创建linear_qkv再创建linear_proj,这是为了修复Megatron的错误init顺序,该bug会导致当linear_qkv和linear_proj被分配在不同bucket时,overlap-param-gather可能会在权重未完成更新时进行下一轮前向计算。
+legacy下,`--overlap-param-gather`暂不支持和`reuse_fp32_param`一起使用。
diff --git a/model/train/yoco_moe/docs/features/async-ddp.md b/model/train/yoco_moe/docs/features/async-ddp.md
new file mode 100644
index 0000000000000000000000000000000000000000..4901eb55ebb011c025fb6f640c0dd968e695f59e
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/async-ddp.md
@@ -0,0 +1,27 @@
+# 异步DDP
+
+## 问题分析
+
+大模型训练过程中,通常会使用数据并行。在进行梯度更新时,数据并行组中的通信(未开启分布式优化器时为AllReduce,开启时为ReduceScatter)要等反向计算完成后再进行。这样的串行执行顺序会造成计算和通信流存在一定的空闲等待时间,导致执行效率较低。
+
+## 解决方案
+
+通过将计算和通信任务分别拆分成更细粒度的子任务来实现相互的流水掩盖。并行原理如下图所示:
+ 
+
+### 解决思路:
+
+设置一个Bucket,存储反向计算的结果。每当Bucket存满时立刻执行桶中结果的通信任务,后续反向计算可以和这部分通信并行执行,从而增大计算和通信流的利用率,提高执行效率。
+
+## 使用场景
+
+使用该特性的前提是模型开启数据并行和虚拟流水并行,脚本中设置了`--num-layers-per-virtual-pipeline-stage N`。
+
+## 使用方法
+
+设置`--overlap-grad-reduce`即可调用该算法。
+
+## 使用效果
+
+开启该特性可以提升性能。
+
diff --git a/model/train/yoco_moe/docs/features/auto_tuning.md b/model/train/yoco_moe/docs/features/auto_tuning.md
new file mode 100644
index 0000000000000000000000000000000000000000..cf97797e8505511c0405d5eee713f504d8c29379
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/auto_tuning.md
@@ -0,0 +1,73 @@
+# 开箱优化-大模型并行策略自动搜索 Auto Tuning 特性说明
+
+## 问题背景
+
+随着大模型并行训练可配置的参数越来越多,例如DP、TP(以及SP)、PP、ZERO、VPP、CP、EP、mbs、重计算等,内存和性能受到各种配置的影响变得越来越复杂,人工调优变得越来越困难。于是,业界开始尝试一些自动调优的方法,主要思路是基于网络模型的结构进行白盒或者灰盒建模,在建模的指导下结合一些profiling,进行配置参数的搜索。
+
+但是,这些方法通常存在以下两个不足之处:
+
+- 白盒或灰盒的建模**对网络模型的结构进行了假设**,而很多用户都会对模型做出修改,这类建模难以捕捉到模型的变化。例如,仅仅是GQA/MQA的修改,就会让此类建模的内存出现偏差。
+- **profiling的规模和实际的负载规模相同**,当要进行大规模(如千卡)的训练时,profiling的开销会变得很大。
+
+因此,我们设计并开发了一种Auto Tuning的特性,该特性和业界已有的自动调优方案相比,完全基于profiling的分析,无需对网络的结构做出假设,并且支持”以小仿大“,即以小规模的profiling预估更大集群上的较优训练配置。
+
+## 解决方案
+
+Auto Tuning特性完全依赖由profiling得出的黑盒建模,与网络结构的变化解耦,并且支持在小规模集群(如双机)上推测大规模集群的配置。
+
+- **阶段1:** 用少量机器拉起auto tuning,该特性会裁剪网络大小,并生成多个profiling的配置,自动多次拉起。这些profiling主要是用作黑盒分析,例如分析配置变化时,哪些tensor会被切分,哪些算子的shape会如何变化,会增加或减少哪些算子等。profiling结束后会对结果文件进行解析,提取出后续黑盒建模需要的信息。
+- **阶段2:** 依据profiling结果进行黑盒建模。内存方面会自动分析各个tensor在不同配置下的切分情况,性能方面会推断算子随不同配置的增减和shape变化,并回归出机内和机间通信的效率。除了基础的性能和内存建模之外,还会分析各个候选重计算模块的性能和内存,从而可以在后续搜索中预估应该选择哪些模块做重计算,以及其对性能和内存的影响。
+- **阶段3:** 根据阶段2得出的建模,进行配置的搜索,给出每个配置下预期的性能和内存。这一步还会依赖一个算子性能知识库,从中查询不同shape的算子的性能。profiling产生的没见过的算子都会被添加到算子性能知识库中。如果某个配置下算子性能知识库覆盖的算子比例小于阈值,则会额外拉起一组profiling,该profiling仍然可以以小仿大,通过同时缩小网络的规模和并行参数,从而得到相同shape的算子。如果算子性能知识库覆盖的算子比例不足以推测算子性能,则未覆盖到的少量算子会通过回归来估计性能。搜索结束后会推荐出内存充足的性能最好的三组配置。
+
+已支持的模型:
+- [x] llama2-7b
+- [x] mixtral-8*7b
+- [x] gpt3-15b
+
+已支持的特性:
+
+- [x] DP
+- [x] TP
+- [x] Megatron-SP
+- [x] PP
+- [x] ZeRO1
+- [x] VPP
+- [x] CP (ring attention)
+- [x] EP (Deepspeed-MOE)
+- [x] MicroBatchSize
+- [x] Token重排
+- [x] 重计算
+- [x] MC2
+
+未来计划支持的特性:
+
+- [ ] ZeRO2
+- [ ] EP (Megatron-MOE)
+- [ ] swap-attention
+- [ ] 激活函数重计算
+- [ ] MoE All2All overlap comm
+
+## 使用方法
+
+在训练脚本的参数列表中加入以下配置开启 Auto Tuning 特性:
+
+```bash
+--auto-tuning \ # 开启 Auto Tuning 特性
+--auto-tuning-work-dir ./auto_tuning_dir \ # 工作目录,在此会保存profiling等文件
+--auto-tuning-ranks 16 \ # 需求搜索的卡数,最低16卡
+--auto-tuning-log-level debug \ # Auto Tuning log记录等级,可选warning, info, debug
+--nnodes $NNODES \ # Profiling拉起的节点数,与基线训练脚本保持一致
+--nproc-per-node $GPUS_PER_NODE \ # 每个节点上运行的进程数,一般与单节点卡数相同,与基线训练脚本保持一致
+--master-addr $MASTER_ADDR \ # 主节点IP,与基线训练脚本保持一致
+--master-port 6005 \ # 主节点端口,设置一个与基线脚本不同的端口
+--node-rank $NODE_RANK \ # 与基线训练脚本保持一致
+```
+
+## 环境变量
+以下环境变量为 Auto Tuning 控制阶段性 Profiling 所用环境变量开关,**仅为 Auto Tuning 内部使用**,**禁止**在正常训练流程中设置
+
+**Auto Tuning会在一个隔离的进程环境中设置以下环境变量,不会export至用户环境中**
+- "OOTB_OPTIMIZER_MODIFIED_ARGV_PATH=${WORK_dir}/auto_tuning_modified_argv.json": 修改Profiling拉起配置参数的文件位置
+- "OOTB_OPTIMIZER_PARSE_ARGS=TRUE": 获取硬件相关信息及模型参数
+- "OOTB_OPTIMIZER_PARSE_MODEL=TRUE": 获取模型结构
+- "OOTB_OPTIMIZER_PROFILING=TRUE": 获取完整Profiling信息及自适应重计算Profiling信息
diff --git a/model/train/yoco_moe/docs/features/automated-pipeline.md b/model/train/yoco_moe/docs/features/automated-pipeline.md
new file mode 100644
index 0000000000000000000000000000000000000000..89974e662a359d49fb9452e6b6c466ce93d4aad7
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/automated-pipeline.md
@@ -0,0 +1,44 @@
+# PP自动并行算法
+
+## 问题分析
+
+流水线并行是将模型网络层切分成多个stage,再把stage映射到不同的设备上,使得不同设备并行计算神经网络的不同部分。流水线并行大大缓解了单卡内存瓶颈问题,并通过多卡之间的流水训练提高了硬件的利用率。流水线并行成为了当前大模型训练最常用的并行方式之一。然而当前流水线并行在内存消耗和性能方面并非最优,主要存在两大问题:
+
+1)内存不均衡:当前流水线常用调度模式(1F1B)下,靠近模型前面层的流水线stage的内存占用远多于后面的stage内存占用,并且内存占用差距有2~3倍,总体上可训的模型规模受限于PP-Stage 0的显存消耗。
+
+2)流水线气泡:流水线1F1B调度策略在每个设备上交替进行小批次数据的前向后向计算,由于各流水设备之间计算负载不均衡或者网络通信的波动,导致设备与设备之间存在等待(流水线气泡),影响训练性能。
+
+## 解决方案
+本系统基于在线profiling+PP建模搜索,通过使能内存优化模块、性能优化模块分别最大化流水线并行训练的内存和性能。内存优化模块旨在通过自动寻找流水线并行中stage的最优层分布和细粒度重计算模块,均匀分配每个卡上的显存,优化存在显存瓶颈的PP-stages,降低峰值内存;性能优化模块采用mbs序列和前向反向调度序列自动寻优和多流异步通信机制,压缩流水线气泡,提升训练性能。
+
+### 内存优化模块
+基于在线profiling+PP建模搜索,自动构建出最优的内存排布方案均衡化各个stage之间的内存开销,降低峰值内存的同时最小化端到端训练时间,具备较好的易用性和泛化性。具体而言,在层分布和细粒度重计算的联合搜索空间自动寻优内存排布方案:
+① PP层分布切分:采用不均匀层切分策略,自动搜索最优层切分方式,均衡化每个卡消耗的显存,从而优化存在显存瓶颈的PP-stages,降低峰值内存。
+② 细粒度重计算:利用流水线气泡时间来做重计算,保证性能不劣化,通过自动寻优细粒度的重计算策略,进一步降低峰值内存。
+
+### 性能优化模块
+在满足训练峰值内存开销不超过设备最大内存容量的条件下,通过自动寻找流水线并行中最优的mbs序列及前向反向调度序列,最小化端到端训练时间。
+① 动态mbs:在给定的gbs下,自动搜索最优mbs序列。通过小mbs加速流水线的启动与冷却,压缩气泡时间,稳态阶段自动寻找最高效的mbs进行计算,缩短稳态阶段计算时间,提升端到端训练性能。
+② 前反向调度:通过调整流水线并行过程中前反向计算的顺序,结合多流异步通信机制,压缩流水线稳态气泡,提升训练性能。
+
+PP自动并行系统如下图所示:
+
+ 
+
+
+## 使用场景
+
+该系统主要用于开启流水线并行的训练场景,使用PP自动并行系统可有效优化内存不足或流水线气泡占比过大的问题。
+**使用条件:**
+1. `--pipeline-model-parallel-size >= 2`;
+2. 内存、性能优化模块不能同时使用。
+
+
+## 使用方法
+
+(1)当内存不足时,可启用PP自动并行内存优化模块,请首先在训练脚本中添加 `--automated-pipeline` 标志启用功能。
+(2)当流水线气泡过大导致训练性能不优时,可启用PP自动并行性能优化模块,请首先在训练脚本中添加 `--automated-pipeline-perf` 标志启用功能。
+
+## 使用效果
+
+PP自动并行内存优化模块收益:LLaMA2-7B,LLaMA-13B,LLaMA2-70B等使用流水线并行PP配置训练的模型,叠加本算法后平均峰值内存减少11.5%,平均性能劣化小于1%。性能优化模块收益:LLaMA2-7B,LLaMA-13B,LLaMA3-8B等使用流水线并行PP配置训练的模型,叠加本算法后平均性能提升7.6%。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/automatic_parallelism_mm.md b/model/train/yoco_moe/docs/features/automatic_parallelism_mm.md
new file mode 100644
index 0000000000000000000000000000000000000000..9d29229d94b9227abf3aa44cca1152f590956c7e
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/automatic_parallelism_mm.md
@@ -0,0 +1,179 @@
+## Automatic Parallelism For Multi-Modal
+
+## 问题分析
+当前多模态大模型的并行训练方法越来越丰富,主要有TP\PP\DP\CP\VPP等,每种并行方法在计算、内存、通信上均有不同的优势。在当前的生产过程中,主要是使用基于专家经验的人工调优,一般需要数天甚至数周的时间。相似模型或者一个模型的不同训练阶段,最优的并行配置也并不相同。随着并行方法的不断丰富,并行搜索空间不断变大,使用人工调优的方法变得越来越不可行。因此需要构建一个面向多模态大模型的并行配置自动调优算法,可以自动化得根据集群资源、模型结构得出最优的并行方法。
+
+## 解决方案
+针对多模态大模型结构丰富,训练阶段多样的特点,我们将网络进行切分和子图归并,然后使用基于黑盒Profiling的方法对多种并行配置采样,最后使用基于整数规划的方法进行非均匀的网络层切分:
+
+- 采样性能:
+遵照多模态大模型原有的训练方法和调用逻辑,将模型进行切分和子图归并,然后使用少量资源进行block级别的性能采样,这样的采样方案兼顾了子图外推的灵活性和采样操作低开销的要求。
+- 端到端建模:
+根据性能采样得到的子图性能、内存数据,使用白盒建模的方法得到网络的峰值内存,以及仿真得到的单步迭代时间。
+- 并行策略调优:
+根据集群的资源和模型支持的并行策略,构建全量的并行策略搜索空间;针对每种并行策略,将PP非均匀层最优切分问题转化为整数规划问题,联合考虑PP流水调度、内存限制和重计算策略,优化目标为端到端时间最短。遍历所有可行的并行策略,得到最优的并行方案;
+
+
+
+## 使用方法
+在使用多维自动并行特性时,**需使用python作为脚本启动器,在所有的节点上拉起脚本**,并配置多维自动并行相关的参数。相关参数及其函数如下表所示:
+
+| 参数名 | 参数含义 |
+| --------------------------- | -------------------------------------------------- |
+| --auto-parallel-mm | 多维自动并行特性总开关 |
+| --nnodes | 采样集群中节点的个数 |
+| --nproc-per-node | 采样集群中每个节点计算设备的个数 |
+| --master-addr | 采样集群中主节点的IP地址 |
+| --master-port | 采样集群用于通信的端口号,各节点需要配置相同的端口 |
+| --node-rank | 采样集群中节点的rank,主节点为0,其他节点为1,2,······ |
+| --simulated-nnodes | 待训练集群的节点个数 |
+| --simulated-nproc-per-node | 待训练集群每个节点的设备数 |
+
+下面是基于QWen2VL-72B模型的配置示例:
+```shell
+#!/bin/bash
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export ASCEND_SLOG_PRINT_TO_STDOUT=0
+export ASCEND_GLOBAL_LOG_LEVEL=3
+export TASK_QUEUE_ENABLE=2
+export COMBINED_ENABLE=1
+export CPU_AFFINITY_CONF=2
+export HCCL_CONNECT_TIMEOUT=1200
+export NPU_ASD_ENABLE=0
+export ASCEND_LAUNCH_BLOCKING=0
+export HOST_CACHE_CAPACITY=20
+export ACLNN_CACHE_LIMIT=100000
+export MULTI_STREAM_MEMORY_REUSE=2
+export PYTORCH_NPU_ALLOC_CONF="expandable_segments:True"
+# 根据机器实际情况填写
+NPUS_PER_NODE=8
+MASTER_ADDR=localhost
+MASTER_PORT=6010
+NODE_RANK=0
+NNODES=1
+WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
+echo $MASTER_ADDR
+echo $NODE_ADDR
+echo $NODE_RANK
+echo $NNODES
+
+
+MM_DATA="./examples/qwen2vl/data_72b.json"
+MM_MODEL="./examples/qwen2vl/model_72b.json"
+MM_TOOL="./mindspeed_mm/tools/tools.json"
+LOAD_PATH="ckpt/Qwen2-VL-72B-Instruct"
+SAVE_PATH="save_dir"
+
+TP=4
+PP=2
+CP=1
+SEQ_LEN=1024
+MBS=1
+GRAD_ACC_STEP=32
+DP=$(($WORLD_SIZE/$TP/$PP/$CP))
+GBS=$(($MBS*$GRAD_ACC_STEP*$DP))
+
+DISTRIBUTED_ARGS="
+ --nproc_per_node $NPUS_PER_NODE \
+ --nnodes $NNODES \
+ --node_rank $NODE_RANK \
+ --master_addr $MASTER_ADDR \
+ --master_port $MASTER_PORT
+"
+
+GPT_ARGS="
+ --use-mcore-models \
+ --tensor-model-parallel-size ${TP} \
+ --pipeline-model-parallel-size ${PP} \
+ --micro-batch-size ${MBS} \
+ --global-batch-size ${GBS} \
+ --num-layers 80 \
+ --hidden-size 8192 \
+ --ffn-hidden-size 29568 \
+ --num-attention-heads 64 \
+ --tokenizer-type NullTokenizer \
+ --vocab-size 152064 \
+ --seq-length 8192 \
+ --max-position-embeddings 32768 \
+ --make-vocab-size-divisible-by 1 \
+ --init-method-std 0.01 \
+ --normalization RMSNorm \
+ --use-fused-rmsnorm \
+ --swiglu \
+ --use-fused-swiglu \
+ --lr 1.0e-5 \
+ --lr-decay-style cosine \
+ --weight-decay 0 \
+ --train-iters 5 \
+ --lr-warmup-fraction 0.1 \
+ --clip-grad 0.0 \
+ --adam-beta1 0.9 \
+ --adam-beta2 0.999 \
+ --no-gradient-accumulation-fusion \
+ --no-load-optim \
+ --no-load-rng \
+ --no-save-optim \
+ --no-save-rng \
+ --seed 42 \
+ --bf16 \
+ --load $LOAD_PATH \
+ --variable-seq-lengths \
+ --enable-one-logger \
+ --use-distributed-optimizer \
+ --reuse-fp32-param
+"
+
+MM_ARGS="
+ --mm-data $MM_DATA \
+ --mm-model $MM_MODEL \
+ --mm-tool $MM_TOOL
+"
+
+SEARCH_ARGS="
+ --auto-parallel-mm \
+ --nnodes $NNODES \
+ --nproc-per-node $NPUS_PER_NODE \
+ --master-addr $MASTER_ADDR \
+ --master-port $MASTER_PORT \
+ --node-rank $NODE_RANK \
+ --simulated-nnodes 8 \
+ --simulated-nproc-per-node 16 \
+"
+
+OUTPUT_ARGS="
+ --log-interval 1 \
+ --save-interval 10000 \
+ --eval-interval 10000 \
+ --eval-iters 5000 \
+ --save $SAVE_PATH \
+"
+logfile=$(date +%Y%m%d)_$(date +%H%M%S)
+mkdir -p logs
+
+python pretrain_qwen2vl.py \
+ $GPT_ARGS \
+ $MM_ARGS \
+ $OUTPUT_ARGS \
+ $SEARCH_ARGS \
+ --distributed-backend nccl \
+ | tee logs/train_${logfile}.log 2>&1
+
+chmod 440 logs/train_${logfile}.log
+```
+
+## 调优效果
+
+
+
+
+## 搜索结果说明
+
+搜索算法的调优结果会存放在执行目录下的 auto_parallel_search_optimal_config.json 文件中,以下表格是对搜索结果的解释
+
+| 参数名 | 参数含义 |
+| --------------------------- | -------------------------------------------------- |
+| parallel_config | 并行配置,包含PP/TP/DP/MBS维度 |
+| layer_placement | 层切分配置,其中包含ViT及LLM的PP层切分策略 |
+| layer_recompute | 细粒度重计算层数,包含ViT及LLM的重计算层数 |
+| e2e_time | 仿真的端到端时间 |
+| throughput | 仿真的模型吞吐率 |
diff --git a/model/train/yoco_moe/docs/features/communication-over-computation.md b/model/train/yoco_moe/docs/features/communication-over-computation.md
new file mode 100644
index 0000000000000000000000000000000000000000..eb2dd8bd51534c709f1ed17e51eb37cf6c8006d5
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/communication-over-computation.md
@@ -0,0 +1,82 @@
+# 计算通信并行 CoC (Communication Over Computation)
+
+## 问题分析
+
+大模型训练过程中,其ColumnParallelLinear和RowParallelLinear部分的前反向均存在相互毗邻、顺序依赖的计算通信组合,计算为Matmul,而通信则为AllReduce(不开启序列并行)或AllGather和ReduceScatter(开启序列并行)。这些计算通信的组合因为存在顺序依赖(即后一个的输入是前一个输出),常常被串行执行,但这时候计算和通信流都存在一定的空闲等待时间,该过程的执行效率没有被最大化。
+
+## 解决方案
+
+通过将计算和通信任务分别拆分成更细粒度的子任务来实现相互的流水掩盖。
+
+### 解决思路
+
+#### Python脚本侧实现
+将张量进行进一步切分(2/4/8份),通过Python脚本的方式实现每个子tensor之间计算和通信的并行,从而增大计算和通信流的利用率;
+
+
+#### 融合算子实现
+基于MTE远端内存访问能力,以融合大Kernel方式在算子实现的内部将计算和通信任务分别拆分成更细粒度的子任务来实现相互的流水掩盖;
+
+## 使用场景
+该特性目前主要用于训练场景,当Attention模块和MLP模块串行执行且计算通信存在顺序依赖与位置毗邻关系时适用。
+
+使用Python脚本侧实现时,对Matmul左矩阵的m轴有一定要求,必须是切分数(2/4/8)的倍数,且不适用于计算与通信片段耗时相差较大的情况。需要注意的是,脚本侧实现在切分矩阵、切分数量较大时,容易出现host bound问题,从而不能得到预期的收益。支持ALL_REDUCE, ALL_GATHER, REDUCE_SCATTER三个通信场景,支持灵活设置先通信或先计算。
+
+对于计算通信融合算子,目前已支持:
+1. MATMUL_ALL_REDUCE融合算子(先计算后通信)及其确定性计算;
+2. MATMUL_REDUCE_SCATTER融合算子(先计算后通信)及其确定性计算;
+3. ALL_GATHER_MATMUL, ALL_GATHER_MATMUL_V2融合算子(先通信后计算)(V2版本接口支持ALL_GATHER中间结果获取);
+4. 量化场景:MATMUL_ALL_REDUCE融合算子支持fp16格式的w8A16伪量化,粒度包含per tensor / per channel / per group;
+
+## 使用方法
+
+当前计算通信并行有两种实现方法:python脚本使能、融合算子使能,两者选其一即可。两个方式都需要替换原Megatron框架中的ColumnParallelLinear和RowParallelLinear这两个class的forward函数,替换脚本已经根据MindSpeed指定Megatron版本进行编码和适配,位于mindspeed/core/tensor_parallel/lcal_coc/目录下。
+
+请根据需要选择下列两种场景中的一个进行使用。
+
+设置--use-ascend-coc使能计算通信并行功能,使用方式通过如下变量进行设置:
+
+### 1. 使用通过Python脚本使能的计算通信并行特性
+
+```shell
+--use-ascend-coc
+--coc-parallel-num 2 # 或者4,或者8
+```
+
+### 2. 使用通过融合算子使能的计算通信并行特性
+注意:计算通信并行融合算子需要安装ATB后才能使用!
+
+ATB安装方法:
+
+- 二进制包安装:安装CANN-NNAL包之后, source /usr/local/Ascend/nnal/atb/set_env.sh
+```shell
+--use-ascend-coc
+--coc-fused-kernel # 注意:当前只支持TP=8的场景!
+```
+
+融合算子的环境变量拥有更高优先级,即当 coc-parallel-num > 1 且 使能coc-fused-kernel时,前者不会生效。
+
+
+## CFG自定义方法
+
+用户可以自定义mindspeed/core/tensor_parallel/lcal_coc/user_config.py中的coc_cfgs字典,来达到自定义COC的部分配置。
+
+【只对通过Python脚本使能的计算通信并行实现适用】
+'matmul_soc_friendly':是否对输入matmul的张量做transpose/padding操作,使其以NPU亲和的shape进入Matmul算子从而获得一定性能提升,默认为True;
+'customized_coc': 自定义指定shape的matmul的COC切分份数,默认为{}。如果需要设置指定shape的matmul的CoC切分份数为1(不开COC)或与coc-parallel-num不同的值,可以按照这个例子设置:
+'customized_coc': {"[16384, 5120, 1920]": 8, "[16384, 1920, 5120]": 1}
+
+【只对通过融合算子使能的计算通信并行实现适用】
+'enable_coc_in_column_backward': 是否在ColumnParllelLinear的反向中使用COC(ColumnParallelLinear的反向中本来就有非互相依赖的计算通信并行),默认为False;
+
+【对脚本实现和融合算子实现都适用】
+'recompute_all_gather': 是否在ColumnParallelLinear的反向中重新计算all gather,默认为True。若为False,则将从前向保存all gather结果到反向,会减少反向计算时间但是会增加训练过程中的峰值内存占用;
+
+## COC融合算子使用效果
+
+在BLOOM 7B模型中获得端到端性能收益约3.20%,在BLOOM 176B模型中获得端到端性能收益约5.47%,在LLAMA2 70B模型中获得端到端性能收益约7.85%。精度相对误差控制在2%的范围内。
+
+## 注意事项
+
+暂不兼容 --use-ascend-mc2 特性 。
+当前暂未适配MoE模型。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/context_parallelism_kv_cache.md b/model/train/yoco_moe/docs/features/context_parallelism_kv_cache.md
new file mode 100644
index 0000000000000000000000000000000000000000..bcf2b6a47d6f4e41fbd344ebee25bd2006a00efe
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/context_parallelism_kv_cache.md
@@ -0,0 +1,66 @@
+# Context Parallelism特性中的KV缓存优化
+
+## 问题分析
+Context Parallelism特性在attention计算前将序列在sequence维度切分,在attention计算时,需要计算全量的序列,因此在attention计算时,会有CP的通信产生。
+
+1. Ring CP在attention计算时通过send/recv方式循环接收其他卡的KV,最终保证Q能关注到全量KV,保持和不拆解情况的数学等价性。当前,前向计算完成后KV被丢弃,反向时需要再次send-recv拿到kv。 当在短序列计算过程中通信时间会大于计算时间,导致计算时间无法掩盖通信时间的情况,从而导致性能变差。因此,需要针对计算时间不足以掩盖通信时间的情况进行优化来加速改场景下的训练需求。
+
+2. Ulysses CP方案在GQA模型下,开启TP后,每个rank通常会只有一个head,在这种情况下,使用All2All的通信量与AllGather通信量相同,而All2All方案在只有一个head的情况下,需要对KV进行repeat,在数据layerout通常为sbh或sbnd的情况下,对h维做repeat,地址不连续,会导致算子存在效率问题,并且需要插入transpose等操作,而allgather直接操作s维,地址连续,无需额外操作。
+
+3. Ulysses CP在有repeat产生的情况下,传入attention反向的Key和Value相较于repeat前的Key和Value内存扩大了CP倍,这将会导致内存的消耗增加出现out of memory的情况。
+
+## 解决方案
+
+1. 在Ring Attention长序列并行的基础上加入KV缓存功能,可选择进行(1)缓存所有K,V,(2)只缓存K以及(3)设置分层缓存的方式在长序列并行的前向中使前向计算接收的kv始终保留至反向计算,直接计算出梯度结果,减少通信时间。
+
+2. 在GQA模型,一个head的情况下,Ulysses Attention长序列并行的基础上加入Allgather KV + All2All Q的方案,减少repeat操作以及transpose等内存非连续的开销,提高训练性能。
+
+3. 在Ulysses使用All2All和AllGather方案加入KV缓存功能,可选择进行(1)缓存所有K,V,(2)只缓存K以及(3)设置分层缓存的方式在前向中将通信前的KV进行缓存始终保留至反向再进行重通信进行计算,节省内存。All2All方案只能在做了Repeat的情况下可以开启KV缓存。
+
+### 解决思路:
+1. Ring方案中序列被切分成CP份并行计算,在不同rank上计算出自己的K和V,同时send-recv其他rank的K和V。例如rank0上的K0/V0和K7V7发送给“下游”的rank,同时接收“上游”rank发送过来的K3/V3和K4/V4,每张卡重复执行相同的动作CP-1次,最终每个切分后的序列可以“关注”到全局的KV,计算得到完整attention结果。反向计算逻辑同理,初始时每个rank有自己的KV,在计算出自己的gradient后,之后步骤将接收到的K和V分块以及dK和dV发送给其他rank,同时接收其他rank的K、V分块以及dK和dV分块,并把接收到的K和V作为输入计算和更新梯度,实现计算和通信并行。
+反向过程关键的一点,rank间通信需要发送K、V、dK、dV四个数据块,一共要发送CP-1次,其中K和V在前向已经在各个rank间逐次接收发送,如果在前向过程中将K、V缓存,反向的通信时间将减半。在CP比较大时,缓存全部K、V对内存压力增大,通过支持缓存K、V的一部分,或者每经过N个Layer缓存一次,支持按需灵活配置。
+
+2. 在GQA模型,一个head的情况下,使用AllGather KV的通信方式替换原有的Repeat-All2All KV方式获取全量的sequence,对Q仍然使用All2All方案。
+
+3. Ulysses方案中,将在前向进行Repeat-All2All或者AllGather通信前的KV进行缓存带到反向,并使用通信后的KV进行计算确保计算的正确性,反向在拿到Repeat-All2All或者AllGather通信前的KV的时候,对KV进行Repeat-All2All或者AllGather重通信进行梯度计算。因为进行重通信会有性能损失,因此可以缓存K、V的一部分,或者每经过N个Layer缓存一次,灵活组合,在内存限制内达到最优的性能。
+
+灵活缓存方案如下,
+1. 支持配置缓存K、V的layer间隔:缓存部分K、V可通过考虑在不同layer之间进行缓存来实现,通过增加一个参数interval来控制缓存的间隔层数。例如interval=1时,那么就会在编号为0,2,4,...的layer中对K、V进行缓存,依次类推。缓存间隔支持从0开始,不超过rank上的layer数量,间隔默认值等于0。
+
+2. 支持缓存K、V的一部分:在每个layer上,可支持只缓存K(K和V的size一样),这种方法通过使用一个参数对其控制,当参数的值为half时,只对K缓存,配置full则缓存K和V,默认缓存K和V。此配置和按layer间隔配置缓存可同时开启,配置后的缓存效果叠加,互不冲突
+
+## 使用场景
+
+训练过程中开启长序列并行的情况下。
+
+需使用FlashAttention,目前已默认开启FlashAttention。
+
+在Ring Attention中想要使用KV缓存获得收益,需要使得计算时间小于通信时间,理论上需要确保每个计算块分到的序列长度需要`c < F/B`。其中`F`是每个device的FLOPS,`B`是每个device间的带宽。
+
+在Ulysse Attention中,想要使用AllGather KV + All2All Q获得收益,需要使用GQA模型,并且需要在通信量相同的前提下,即KV仅有一个head的情况下。
+
+在Ulysses Attention中,想要使用KV缓存获得收益,Repeat-All2All方案需要在使用repeat的情况下,才能获得内存收益,而AllGather KV + All2All Q开启CP即可以获得内存收益。
+
+## 使用方法
+
+| 重要参数 | 参数说明 |
+|------------------------------------------------|----------------------------------------------------------|
+| --context-parallel-kv-cache-policy [full/half] | 开启CP前向计算过程缓存KV及其级别,默认full缓存K和V,half缓存K |
+| --context-parallel-cache-interval [int] | 设定执行CP前向计算过程缓存KV的layer间隔层数,默认为0,即每一个layer都需要缓存,根据用户需求配置。 |
+| --use-ulysses-allgather-kv | 设定Ulysses Attention启用AllGather方案,默认为False,不启用。 |
+
+## 使用效果
+
+在Ring Attention中计算时间无法掩盖通信时间的场景下,开启KV缓存特性会使得训练时间变短,提升训练性能,但会导致内存增加。
+
+在Ulysses Attention中开启AllGather KV,在允许的场景下,会使得训练时间变短,提升训练性能。
+
+在Ulysses Attention中开启KV缓存,在Repeat-All2All做了Repeat的情况下,内存使用会减少,但会导致性能下降。 AllGather情况下,内存使用会减少,但会导致性能下降。
+
+## 注意事项:
+
+1. 开启--context-parallel-kv-cache-policy时需要同时开启Context Parallel,否则特性不支持。
+2. 开启--context-parallel-cache-interval时需要同时开启--context-parallel-kv-cache-policy并且interval的值需要小于layer的数量,否则特性不支持。
+3. 开启--use-ulysses-allgather-kv时需要开启Context Parallel且设置--context-parallel-algo ulysses_cp_algo,并且需要开启--group-query-attention,且KV每个rank的head数量为1, 否则特性不支持。
+4. 开启--context-parallel-kv-cache-policy以及--context-parallel-algo ulysses_cp_algo的情况下,需要使KV做Repeat操作,否则特性不支持。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/conv3d_sequence_paralle.md b/model/train/yoco_moe/docs/features/conv3d_sequence_paralle.md
new file mode 100644
index 0000000000000000000000000000000000000000..dc4b99bc21f805f4c91df86fc886553439e1da9a
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/conv3d_sequence_paralle.md
@@ -0,0 +1,36 @@
+# conv3d 序列并行
+## 问题分析
+在多模态、机器视觉等领域的模型结构中经常会采用conv3d模块用于特征图的三维卷积操作。在大模型中,卷积操作的耗时会随着特征图规模的增加而增加。
+由于特征图的每一个卷积区块的卷积过程是顺序执行,但实际上各个区块的执行顺序并不存在先后顺序上的约束关系。在分布式训练中需要对三维卷积操作进行并行化处理来提高卷积速度。
+
+## 解决思路
+构造Conv3DSequenceParallel类,将输入特征图按照卷积核的depth维度进行切分后进行并行卷积。
+- **前向过程** :
+ 构造Conv3DSequenceParallel类,将输入特征图按照卷积核的depth维度进行切分,分发到不同的进程组中进行conv3d三维卷积操作,将卷积结果进行gather操作后输出到下游模块。
+- **反向过程** :
+ Conv3DSequenceParallel类会将下游反向得到的梯度进行split操作,实现梯度的depth维度进行切分,分发到并行的三维卷积模块上进行反向传播,再将并行的三维卷积模块的反向梯度进行gather操作后输出到上游模块。
+
+## 使用场景
+训练含有conv3d(非padding模式)模块的模型。
+
+## 使用方法
+将原有的conv3d模块替换为Conv3DSequenceParallel并指定相关参数,以实现并行加速。
+Conv3DSequenceParallel模块接口如下:
+
+`Conv3DSequenceParallel(pg, in_channels, out_channels, kernel_size, stride, dilation, bias, param_async, dtype, sp_size)`
+- `pg`:必选输入,数据类型为list(int),表示通信进程组。
+- `in_channels`:必选输入,数据类型为int,表示输入通道数。
+- `out_channels`:必选输入,数据类型为int,表示输出通道数。
+- `kernel_size`:可选属性,数据类型为tuple(int,int,int),默认值:(1, 1, 1),表示卷积核大小。
+- `stride`:可选属性,数据类型为tuple(int,int,int),默认值:(1, 1, 1),表示各个维度卷积步长大小。
+- `dilation`:可选属性,数据类型为float,默认值:1.0,表示扩张率。
+- `bias`:可选属性,数据类型为bool,默认值:True。表示是否开启偏置。
+- `param_async`:可选属性,数据类型为bool,默认值:False。表示是否开启参数异步通信。
+- `dtype`:可选属性,表示数据类型,默认值:torch.bfloat16。表示数据类型。
+- `sp_size`:可选属性,数据类型为int,默认值:1。表示序列并行大小。
+
+## 使用影响
+将逐卷积区域的卷积操作分发到进程组中进行并行化执行,提高三维卷积效率。
+
+## 注意事项
+Conv3DSequenceParallel模块并不支持padding模式,因此使用了padding的conv3d模块不能使用Conv3DSequenceParallel模块替换。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/data-parallel.md b/model/train/yoco_moe/docs/features/data-parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..7e34b53bdf89c6e52ceb00cdd9a06e854e2913f7
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/data-parallel.md
@@ -0,0 +1,31 @@
+# 数据并行
+
+## 问题分析
+
+对于数据集过大的模型训练场景,其训练时间过长,要将数据集进行切分,让一个计算设备只处理一部分数据。
+
+## 解决方案
+
+数据并行将数据集切分为多个batch,并且均分给不同计算设备。每个计算设备只负责处理自己的batch。
+数据并行满足:
+1.每个设备上模型的组网和参数相同。
+2.每个设备处理不同batch的数据。
+
+### 解决思路
+
+1.每个计算设备上都存储一份完整的模型复制。
+2.数据集被切分为多个batch,并且平均分给不同的计算设备,各计算设备处理不同的数据。
+3.前向计算完成得到梯度之后,需要通过all-reduce操作将各计算设备得到的梯度取平均后再发给各计算设备,保证各计算设备的参数保持一致。
+
+## 使用场景
+
+训练数据集过大,训练时间过长,且可用于训练的计算设备比较充足,可以存储多份完整模型,可以开启数据并行,来加快训练速度,减轻单设备的计算压力。
+
+## 使用方法
+
+框架中数据并行通过总设备数(world_size)、模型并行数(tensor_model_parallel_size)、流水线并行数(pipeline_model_parallel_size)、长序列并行数(context_parallel_size)计算得到。
+数据并行数(data_parallel_size) = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size)
+
+## 使用效果
+
+利用多个设备,增加了总的显存占用量,但是加快了训练速度,减轻了单设备的计算压力。
diff --git a/model/train/yoco_moe/docs/features/deepspeed_moe/deepspeed-moe-efficient-moe.md b/model/train/yoco_moe/docs/features/deepspeed_moe/deepspeed-moe-efficient-moe.md
new file mode 100644
index 0000000000000000000000000000000000000000..558353721a2a135b08a7c3bd52a89b0d089e6338
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/deepspeed_moe/deepspeed-moe-efficient-moe.md
@@ -0,0 +1,52 @@
+# MoE token dropless性能优化
+
+### 问题分析
+
+现有MoE无token丢弃方案在训练过程中实时all reduce通信全局最大专家容量作为所有rank的容量,引入频繁的all reduce开销。
+
+### 优化方案
+
+采用滑窗预测,无需每一个step都进行all reduce通信,all reduce通信频次降低为1/window_size,对于超过预测值场景,使用实际最大专家容量值替换预测capacity。
+
+#### 优化思路:
+
+1. MoE模型训练过程中capacity具有一定的连续性,维护一个滑动窗口来保存近几次统计的capacity来预估下一个窗口的capacity。
+2. 在Gate中不再每一个step都直接进行all reduce统计全局最大capacity,而是各个进程先判断当前的capacity能否满足不丢弃token,通过reduce统计全局的判断信息,若都能满足则无需进行all reduce通信,否则进行all reduce通信取得实际max capacity.
+ 
+
+### 使用场景
+
+请在MoE模型中的无token丢弃训练场景下使用此优化特性,以提升训练速度。当训练脚本`--num-experts`等于`--moe-train-capacity-factor`即`专家数量`等于`专家容量`时,为无token丢弃场景。
+
+### 使用方法
+
+设置`--moe-no-drop`: 表示开启MoE无token丢弃训练模式,Top1 Gate &Top2 Gate均已支持, 请搭配aux loss/sinkhorn负载均衡方式使用,避免无token丢弃场景负载均衡情况劣化严重
+
+设置`--moe-dynamic-padding`: 表示开启MoE无token丢弃训练优化,需要搭配`--moe-no-drop`同时开启,
+附加功能
+
+设置`--moe-use-sinkhorn`: 表示开启sinkhorn负载均衡功能
+
+
+### 使用效果
+
+在保持精度的同时提升训练速度。
+
+训练模型:Mixtral(4层)
+
+精度对比图如下:
+ 
+
+top2 多种并行方式 提速效果:
+ 
+
+top1 多种并行方式 提速效果:
+ 
+
+同时开启此优化减少显存占用3%:
+ 
+
+## 注意事项:
+
+暂不兼容 流水线并行特性,即需满足--pipeline_model_parallel_size <= 1。
+
diff --git a/model/train/yoco_moe/docs/features/deepspeed_moe/deepspeed-moe-token-rearrange.md b/model/train/yoco_moe/docs/features/deepspeed_moe/deepspeed-moe-token-rearrange.md
new file mode 100644
index 0000000000000000000000000000000000000000..f2d12bce89fee00ebd0bb502b9791f831c75aa0c
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/deepspeed_moe/deepspeed-moe-token-rearrange.md
@@ -0,0 +1,28 @@
+# Token 重排性能优化
+
+## 问题分析
+
+DeepSpeed MoE的token重排采用了两个BatchMatmul实现,时间复杂度为o(s^2),而token重排进行计算时由于矩阵的稀疏性导致一些不必要的计算,存在优化空间。
+
+## 解决方案
+
+重排操作可以通过等价的pytorch API: index_select来实现,降低计算时间复杂度到o(s),从而提高训练性能。
+
+### 解决思路:
+
+1. 重排过程:top1gating/top2gating 函数计算出每个专家选择的token的索引:expert_select_token_idx,shape为: [E*C],MoE前向过程中根据此索引通过index_select API实现token的重排;
+
+2. 反重排过程:top1gating/top2gating 函数同时需要计算每个token在各个专家输出的索引位置:token_rearrange_ec_idx,shape为:[S]。在MoE前向过程中,token经过专家输出后通过index_select API 从[E*C, M]的专家输出中恢复token的输出:[S, M],最后乘以token选择对应专家的权重,得到MoE layer的输出。
+
+## 使用场景
+
+进MoE层时实际序列长度8K以上。
+
+## 使用方法
+
+设置`--enable-token-rearrange-opt`,即可调用该算法。
+
+## 使用效果
+
+预期性能收益在2%~3%左右。
+
diff --git a/model/train/yoco_moe/docs/features/deepspeed_moe/deepspeed-moe.md b/model/train/yoco_moe/docs/features/deepspeed_moe/deepspeed-moe.md
new file mode 100644
index 0000000000000000000000000000000000000000..60c7e4bd52e3062e79c05214f11ce04d4ec53ac8
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/deepspeed_moe/deepspeed-moe.md
@@ -0,0 +1,79 @@
+# Ascend DeepSpeed MoE 相关特性
+
+## 整体方案:
+
+Mixture of Expert(MoE)是指混合专家模型功能。在大模型训练中使用该功能可以将常规的稠密大模型变成稀疏的MoE大模型,在计算量不显著增加的情况下大幅提升模型的参数量。
+
+通过使用专家并行(Expert Parallelism,EP),把专家分配到多个计算设备上,减轻单个计算设备的显存压力,也就是说专家并行(Expert Parallelism,EP),对全量专家进行分组。
+如图所示,一个包含6个专家的MoE模型在EP=2时的专家分布情况。可以把专家并行理解成模型并行的一种形态(模型被切分成多份),但是输入的数据又是不同的(DP),因此token在经过Router之后,可能会选中别的卡上的专家,此时就需要将这些token发送过去,即EP进程组内需要通过All2All通信交换token。值得注意的是,该MoE模型在token选择专家时,如果超过容量会drop掉token。
+
+
+
+## 特性背景:
+
+支持moe模型及相关特性的兼容和适配,包含MoE基础模型、MoE适配序列并行(SP)、MoE适配长序列(CP)、MoE token重排性能优化。
+
+1.Mindspeed新增MoE混合专家模型(Mixtral 8*7B),支持使用MoE模型进行训练。
+
+2.MoE支持序列并行(sequence parallel),支持MoE与序列并行同时开启,减少MoE模块计算,提升MoE模块的训练性能。
+
+3.MoE适配长序列(context parallel)特性,支持MoE和CP特性同时开启。
+
+4.MoE适配分布式优化器特性,支持MoE和分布式优化器同时开启,降低内存,减少OOM风险。
+
+5.MoE token重排性能优化,减少token选择专家的gate计算量,提升训练性能。
+
+## 使用场景
+
+在需要处理大规模数据集和复杂任务的情况下,使用基于 moe 结构的大模型,以及其他SP、CP等特性。此特性暂只适配`--use-legacy-models`。
+
+### 使用建议:
+
+MoE+SP MoE开启TP需要同时开启SP。
+
+MoE+cp 建议处理长序列时开启。
+
+MoE+分布式优化器 建议默认开启。
+
+Token重排优化 建议默认开启。
+
+
+## 使用方法
+
+MoE特性基础功能:
+
+| 重要参数 | 参数说明 |
+|-------------------------------------|--------------|
+| --moe-model-type deepspeed_moe | 使用mixtral模型 |
+| --num-experts [int] | 专家数 |
+| --expert-model-parallel-size [int] | 专家并行 |
+| --expert-interval [int] | 专家层数间隔 |
+| --moe-train-capacity-factor [float] | 专家容量因子 |
+| --noisy-gate-policy | gate噪声策略 |
+| --no-use-rts | 不使用随机token选择 |
+
+
+MoE支持序列并行:
+
+|重要参数| 参数说明 | 注意事项 |
+| ---- | ---- | ---- |
+|--sequence-parallel |开启SP | MoE场景开启SP,由于影响数据分发,节省内存的同时可能引起性能波动|
+
+MoE适配长序列:
+
+| 重要参数| 参数说明 |
+| ---- | ---- |
+|--context-parallel-algo megatron_cp_algo |配置CP算法|
+
+
+MoE适配分布式优化器:
+
+|重要参数| 参数说明 |
+| ---- | ---- |
+|--use-distributed-optimizer |开启分布式优化器特性|
+
+MoE特性token重排优化:
+
+|重要参数| 参数说明 |
+| ---- | ---- |
+|--enable-token-rearrange-opt |开启token重排|
diff --git a/model/train/yoco_moe/docs/features/deepspeed_moe/moe.png b/model/train/yoco_moe/docs/features/deepspeed_moe/moe.png
new file mode 100644
index 0000000000000000000000000000000000000000..c5cf063886329023e6c36bbc0e5f660bc46127a3
Binary files /dev/null and b/model/train/yoco_moe/docs/features/deepspeed_moe/moe.png differ
diff --git a/model/train/yoco_moe/docs/features/dist-train.md b/model/train/yoco_moe/docs/features/dist-train.md
new file mode 100644
index 0000000000000000000000000000000000000000..8863678102b8048dc44fb2d676c7570c0f847dc1
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/dist-train.md
@@ -0,0 +1,53 @@
+# 问题分析
+多模态模型的训练中,由于不同模态模型对算力和内存需求的异构性,会产生以下问题:
+- 不同模态模型的最优并行配置不同,全部使用同一种并行配置,造成负载不均衡、资源利用不充分;
+- 多模型融合部署,造成静态内存占用偏高,训练内存资源利用率不佳。
+
+
+# 解决方案
+新增dist-train功能,通过对异构模型配置不同的计算资源和并行配置,减少冗余的静态资源和异构模型间的气泡,使能异构模型之间的运行速度达到最优匹配。
+
+
+# 使用方法
+在启动脚本中添加参数`--dist-train`。
+需要在MindSpeed-MM仓库中,对应模型目录下的`model.json`中添加`dist_config`字段,具体配置示例如下:
+```json
+{
+ "dist_config": {
+ "model_name": "internvl2", // 多模态模型名称
+ "use_multiparam_send_recv": false, // 模型间是否传递tensor列表
+ "model_config": [
+ {
+ "name": "vit", // 内部模型名称
+ "model_index": 0, // 模型位于流水线中的序号
+ "world_size": 1, // 模型使用卡数
+ "tensor_model_parallel_size": 1,
+ "pipeline_model_parallel_size": 1,
+ "context_parallel_size": 1,
+ "forward_only": false // 是否不做反向计算
+ },
+ {
+ "name": "gpt",
+ "model_index": 1,
+ "world_size": 3,
+ "tensor_model_parallel_size": 1,
+ "pipeline_model_parallel_size": 3,
+ "context_parallel_size": 1,
+ "forward_only": false,
+ "main_dp": true // 配置该项时,代表DP数量以该模型为准,只在需要开启inner_dp时配置该项,且该配置唯一
+ }
+ ]
+ }
+}
+```
+
+
+# 使用效果
+根据模型不同、参数量不同,效果各有差异,可以针对SPS、MFU等指标进行调优,均有收益。
+
+
+# 注意事项
+- 需要配合多模态模型仓库MindSpeed-MM使用,目前支持模型和对应的子模型名称:internvl2 - [vit, gpt], opensoraplan1.3 - [vae, dit];
+- 需要注意在配置并行策略时,若原生模型不支持某种并行策略,则dist-train配置中也不应该开启此种并行策略;
+- 配置并行策略时,需要保证各个模型的DP数量保持一致,若配置有main_dp,则以main_dp的DP数量为准;
+- 需要使用dist-train配套的权重转换脚本,和MindSpeed-MM中使用的权重互相转换。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/distributed-optimizer.md b/model/train/yoco_moe/docs/features/distributed-optimizer.md
new file mode 100644
index 0000000000000000000000000000000000000000..888a43fcf3e6c724d46452b2d1befd91bcd4dea4
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/distributed-optimizer.md
@@ -0,0 +1,15 @@
+# Megatron 分布式优化器
+## 问题分析
+数据并行(DP)场景下, 梯度all-reduce之前,不同的设备并行执行不同的工作。然而all-reduce以后, 所有设备都利用同样的梯度、参数以及优化器状态做相同的梯度更新操作,因此存在计算和存储冗余。
+
+## 解决思路
+将重复的内存存储和计算步骤拆分给不同的设备,通过通信进行联通,最终在同时节省显存和计算时间的条件下,达到和原始DP相同的效果。
+
+## 使用场景
+主要用于训练场景,当DP不为1时,将优化器状态拆分到所有DP组间,在对梯度进行reduce-scatter后,每个DP组分别进行部分权重更新,最后再用一个all-gather操作收集模型权重。
+
+## 使用方法
+脚本中添加`--use-distributed-optimizer`开启分布式优化器。
+
+## 使用影响
+降低显存开销。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/eod-reset.md b/model/train/yoco_moe/docs/features/eod-reset.md
new file mode 100644
index 0000000000000000000000000000000000000000..542f96d7f56017930730f967c96da4bd8a52a9e2
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/eod-reset.md
@@ -0,0 +1,56 @@
+# 支持EOD Reset训练场景
+
+## EOD Reset训练场景
+通常一个批次中输入进模型的文本序列是由多个文档(doc)拼接得到。在默认情况下,多个文档被视为同一序列,互相间的self attention没有掩盖。在特定情况下,多个文档间要求独立,文档间不能互相做self attention,在这种情况下attention mask和position ids需要在每个文档结束的位置(EOD)被重新设置。--reset-position-ids参数关闭时,整个序列计算位置编码;开启时,在每个序列内独立计算位置编码。
+
+## 解决方案
+通过调用底层flash-attention算子的可变长模式,支持EOD Reset训练场景。同时在EOD Reset训练场景下,支持Ring Attention长序列并行,对超长序列场景进行加速。
+
+## 使用方式
+### 1. Megatron代码修改
+1. 在 Megatron-LM 目录下修改`pretrain_gpt.py`文件中的`get_batch`函数。
+ ```diff
+ def get_batch(data_iterator):
+ """Generate a batch."""
+
+ - # TODO: this is pretty hacky, find a better way
+ - if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
+ - return None, None, None, None, None
+
+ # get batches based on the TP rank you are on
+ batch = get_batch_on_this_tp_rank(data_iterator)
+
+ # slice batch along sequence dimension for context parallelism
+ batch = get_batch_on_this_cp_rank(batch)
+
+ + # TODO: this is pretty hacky, find a better way
+ + if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
+ + return None, None, None, None, None
+
+ return batch.values()
+ ```
+
+2. 在 Megatron-LM 目录下修改`pretrain_gpt.py`文件中的`get_batch`函数。
+
+ ```diff
+ def is_dataset_built_on_rank():
+ - return (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and mpu.get_tensor_model_parallel_rank() == 0
+ + return mpu.get_tensor_model_parallel_rank() == 0
+ ```
+
+### 2. 数据准备
+首先确保每一个文档的末尾都添加了EOD Token。
+
+
+### 3. 参数设置
+前提,确保`--attention-mask-type`设置为`general`。
+
+#### 不启用长序列并行(CP)
+打开 `--reset-attention-mask`和`--reset-position-ids`选项
+#### 启用长序列并行
+首先确保`--context-parallel-size`大于`1`。
+
+打开`--reset-attention-mask`和`--reset-position-ids`选项。
+
+### 4. 注意事项
+Ascend EOD Reset训练场景下mask-type为general时,Ring/Hybrid Attention比Ulysses下降较多,为正常现象。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/flash-attention.md b/model/train/yoco_moe/docs/features/flash-attention.md
new file mode 100644
index 0000000000000000000000000000000000000000..15555927c5749bbbeb2572c38da3252c1438cffb
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/flash-attention.md
@@ -0,0 +1,35 @@
+# flash attention
+
+## 问题分析
+
+由于self-attention的时间计算复杂度和内存复杂度与序列长度成二次方关系,因此transformer在长序列上的处理时间、内存开销较大。近似的注意力方法可以优化这一问题,但会降低模型质量。
+
+## 解决方案
+
+加速注意力的关键在于优化IO访存,即降低片上内存的读/写次数。
+
+### 解决思路:
+
+Flash Attention 是一种优化IO访存开销的精确注意力方法,原理如下图所示[1],通过Tiling切片、重计算、Kernel Fusion等方式来减少高带宽内存(片上内存)和SRAM之间的内存读/写次数。NPU上提供了相同解决方案,可参考[fusion attention 对外接口](../ops/fusion_attention.md) 。
+
+a. Tiling切片:利用更高速的SRAM代替片上内存,但SRAM的内存容量较少,无法一次性完成所有数据的完整注意力计算,因此需要进行分块计算。
+
+b. 重计算:放弃中间结果写回,需要使用时重新计算,用计算换访存。
+
+c. Kernel Fusion:将多个操作融合为一个操作,基于Tiling利用一个kernel完成整个计算。
+
+ 
+
+[原文链接](https://arxiv.org/pdf/2205.14135)
+
+## 使用场景
+
+本方法适用于self-attention相关模型,尤其适用于长序列输入场景,开启长序列并行时该特性默认开启。
+
+## 使用方法
+
+设置`--use-flash-attn`即可调用该算法。
+
+## 使用效果
+
+在模型训练时间、模型质量等方面可以提升性能。
diff --git a/model/train/yoco_moe/docs/features/fused_ema_adamw_optimizer.md b/model/train/yoco_moe/docs/features/fused_ema_adamw_optimizer.md
new file mode 100644
index 0000000000000000000000000000000000000000..79c2279549d9b4a045754f715364a48cb4976316
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/fused_ema_adamw_optimizer.md
@@ -0,0 +1,26 @@
+# fused_ema_adamw 优化器
+## 问题分析
+多模态领域在模型训练过程中往往会额外生成ema模型用于后续任务,因此需要在训练过程中生成和保存ema模型数据,fused_ema_adamw优化器可以在模型训练过程中额外维护一份ema模型参数,在权重保存时ema模型将自动保存到模型权重文件中。
+
+## 解决思路
+在训练过程中,fused_ema_adamw优化器会为模型参数维护一份```ema_params```状态,并在每次优化器迭代过程中更新。ema_params状态更新公式如下:
+
+ ema_params = ema_decay * ema_params + (1 - ema_decay) * model_params
+
+```model_params```为模型参数,```ema_decay```为超参数。其中,```ema_decay```可在训练脚本中使用'--ema-decay 数值'来指定,若脚本中未指定,则默认ema_decay为0.9999。
+
+## 使用场景
+主要用于需要保存ema模型用于后续任务的多模态训练场景。
+
+## 使用方法
+1.在脚本中添加`--optimizer-selection fused_ema_adamw`,可开启fused_ema_adamw优化器,优化器```ema_params```状态保存功能与ema模型权重保存功能会一同开启。
+2.在脚本中添加`--ema-decay 数值`,可指定ema_decay,如未指定,则默认为0.9999。
+
+## 使用影响
+1.由于fused_ema_adamw优化器在训练时需要额外维护```ema_params```状态,内存开销会有所增加。
+2.权重保存时,优化器的```ema_params```优化器状态将会存储在distrib_optim.pt文件中。
+3.权重保存时,ema模型权重数据将会存储在model_optim_rng.pt文件中的```ema_model```字段中。
+
+## 注意事项
+1.fused_ema_adamw优化器不支持和参数副本复用特性同时开启,使用本优化器时请勿在训练脚本中添加`--reuse-fp32-param`参数。
+2.fused_ema_adamw优化器在训练时需要额外维护一份ema数据,内存开销会有所增加。不同的训练配置内存开销增加幅度不同,使用时请根据实际硬件内存变化情况,适当调整训练脚本或模型结构。
diff --git a/model/train/yoco_moe/docs/features/fusion-attn-v2.md b/model/train/yoco_moe/docs/features/fusion-attn-v2.md
new file mode 100644
index 0000000000000000000000000000000000000000..757c3f576dc633e2d08cfdf20dfe5845e9c6e175
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/fusion-attn-v2.md
@@ -0,0 +1,15 @@
+# fusion_attention_v2
+
+## 使用场景
+
+本方法为FlashAttention的V2版本,对V1版本进行了一定功能拓展,当前仅支持特定场景如[Alibi位置编码](./alibi.md),默认关闭。
+
+其它场景原则上与V1版本无差异,不建议主动开启。算子说明详情见[接口说明](../ops/fusion_attention.md)。
+
+## 使用方法
+
+设置`--use-fusion-attn-v2`即可调用该算法。
+
+## 使用效果
+
+基础效果等同于Flash attention,特定场景如[Alibi位置编码](./alibi.md)需手动开启。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/generate-mask.md b/model/train/yoco_moe/docs/features/generate-mask.md
new file mode 100644
index 0000000000000000000000000000000000000000..3e8f083445c666d675109a620e8948bb1d345793
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/generate-mask.md
@@ -0,0 +1,57 @@
+# MindSpeed Mask归一实现阐述
+
+## 问题分析
+### 1. Megatron源码阐述
+[1] 各device通过 `pretrain_gpt.py#L93`-`def get_batch` 去获取各项数据,包括AttnMask。
+
+[2] PP的首尾节点通过 `megatron/training/utils.py#L276`-`def get_batch_on_this_tp_rank` 去获取各项数据,包括AttnMask。其他节点直接返回None。
+
+[3] TP的首节点通过 `megatron/core/datasets/gpt_dataset.py#L675`-`def _get_ltor_masks_and_position_ids` 生成AttnMask。
+
+[4] TP其他节点,直接生成与首节点相同shape的empty矩阵,通过broadcast获取首节点生成的AttnMask。
+
+Tips: 以上操作默认开启,生成的AttnMask全部为下三角形状,可以通过 `--no-create-attention-mask-in-dataloader` 关闭。
+
+### 2. 问题发现
+[1] 昇腾的FA需要外部生成AttnMask,所以除了基础下三角模式,需要额外接口生成自定义AttnMask。
+
+[2] 非PP首尾节点的AttnMask为None,无法使能FA加速。
+
+[3] AttnMask生成、拷贝及广播操作,在每个micro_step都会跟随get_batch重复。
+
+[4] 长序列下,生成的AttnMask占用显存过大。
+
+
+## 解决方案
+### 解决思路
+[1] 提供统一AttnMask生成接口,同一进程复用一个全局变量AttnMask,避免重复生成和拷贝。
+
+[2] 适配AttnMask压缩模式,减少显存占用。
+
+
+## 使用场景
+目前支持FA和非FA的Mask生成,传入AttnMask可以为None和下三角模式。
+
+FA场景,当序列长度大于2048或使用ring_cp_algo时,默认走压缩模式。
+
+其他场景使用完整Mask。
+
+
+## 使用方法
+[1] 针对以上问题和思路,在MindSpeed中,直接默认使能AttnMask,不再使用原生mask生成方式。
+
+[2] 提供全局变量 `mindspeed/model/transformer.py`-`global _GLOBAL_ATTN_MASK`。
+
+[3] 提供 `--sparse-mode` 传参,配合FA多种模式调用。`--sparse-mode`的不同模式信息可以参考[torch_npu.npu_fusion_attention算子文档](https://www.hiascend.com/document/detail/zh/Pytorch/60RC2/apiref/apilist/ptaoplist_001095.html)。
+
+[4] 提供 `mindspeed/model/transformer.py`-`def set_attention_mask`、`def get_attention_mask`、`def generate_attention_mask` 三个额外接口,实现正常流程外的设置、获取和生成功能。
+
+[5] 提供 `mindspeed/model/transformer.py`-`def parallel_transformer_forward_wrapper` 接口,避开get_batch的缺陷,在第一次正向的时候使能generate_attention_mask来生成mask。
+
+
+## 使用效果
+例如下三角模式,压缩模式下设sparse_mode=2,mask.shape固定为[2048,2048],将大幅提升性能并降低显存。
+
+
+## 注意事项
+当前FA场景仅支持下三角及Band模式,其他自定义AttnMask模式需要手动set_attention_mask,或修改get_attention_mask逻辑。
diff --git a/model/train/yoco_moe/docs/features/hccl-group-buffer-set.md b/model/train/yoco_moe/docs/features/hccl-group-buffer-set.md
new file mode 100644
index 0000000000000000000000000000000000000000..271e3e71ba3c170588601c9a27205ace6d2c82d1
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/hccl-group-buffer-set.md
@@ -0,0 +1,33 @@
+# Hccl Group Buffer Set
+
+## 问题背景
+当前 MindSpeed 的通信域 Buffer,只能通过环境变量 HCCL_BUFFSIZE 进行统一设置(默认为 200M ),但是往往不同的通信域所需的 Buffer 大小不能一概而论
+
+## 解决方案
+### 1.手动配置
+对外呈现开关,使得用户可以根据自己需求自己设置通信域缓冲区大小
+### 2.自动配置(推荐)
+使用自适应方案,MindSpeed 根据网络参数自适应通信域缓冲区大小
+
+## 使用方法
+### 1.手动配置
+打开--hccl-group-buffer,并指定所需要设定的组以及大小(例如:dp:200;tp:300;exp:400),单位是 M 。
+
+手动配置目前支持通信组:
+
+["dp", "dp_cp", "cp", "mp", "mp_exp", "tp", "pp", "embd", "tp_dp_cp", "tp_dp", "tp_cp", "tp_exp",
+ "exp", "dp_modulo_exp", "pp_new_stream", "cp2", "cp_ulysses", "cp_ring","cp_ring_intra", "cp_ring_intra_overlap",
+ "nd1_dim1", "ag_x_sd_rcv_overlap", "nd1_dim2", "ag_y_sd_rcv_overlap", "nd2_dim1", "nd2_dim2"]
+
+### 2.自动配置
+打开 --hccl-group-buffer-adaptive ,会自适应设置 tp、cp、pp 相关通信组大小; 需要注意的是,对于 ep 相关的通信组(exp、tp_exp、tp),用户可自行根据当前模型 MOE 负载不均衡的程度指定系数--hccl-ep-group-buffer-adaptive-factor 从而得到合适的通信域 Buffer,该系数代表当前负载不均衡的程度( 例如,设置--hccl-ep-group-buffer-adaptive-factor 大小为 1, 代表的是负载均衡情况下需要开启的buffer大小;设置为n,代表当前缓冲区大小是负载均衡情况下的 n 倍,n 配置过大有可能会导致OOM )
+
+自动配置目前支持通信组:
+
+[ "cp", "mp", "mp-exp", "tp", "pp", "tp_cp", "tp_exp", "exp", "pp_new_stream", "cp2", "cp_ulysses", "cp_ring", "cp_ring_intra","cp_ring_intra_overlap"]
+
+## 使用效果
+llama 系列模型,开启自适应方案,性能不下降的同时节约显存;MOE 相关模型,开启自适应方案并设置合适的负载不均衡系数,性能不下降的同时节约显存。
+
+## 使用限制
+本特性依赖 PTA:FrameworkPTAdapter 7.0.RC1.B020 (包含该版本)之后的版本
diff --git a/model/train/yoco_moe/docs/features/hccl-replace-gloo.md b/model/train/yoco_moe/docs/features/hccl-replace-gloo.md
new file mode 100644
index 0000000000000000000000000000000000000000..10836b439137b3111394f66fad34a334fd6b9a76
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/hccl-replace-gloo.md
@@ -0,0 +1,39 @@
+# Gloo 存档落盘优化
+
+## 问题分析
+在大规模集群下,Gloo 通信存在规模限制和稳定性问题。一方面,容易出现 Gloo 通信组创建失败的情况;另一方面,与 Hccl 通信相比,Gloo 通信较慢。
+
+对于Gloo通信组创建失败报错`Gloo connectFullMesh failed with ...`的问题,本质上是由于N张卡链接到主Master来完成建链,集群规模大时,Master处理能力不足,可能造成建链失败。可通过调整和网络建链相关参数进行规避(云上8k卡场景验证有效):
+
+```
+net.ipv4.tcp_max_syn_backlog = 65536
+net.core.netdev_max_backlog = 65536
+```
+
+此外,MindSpeed设计了 Gloo 通信优化方案使用Hccl通信替代Gloo。
+
+## 解决方案
+
+### 解决思路
+[1] 采用 Hccl 通信组替换 Gloo 通信组,实现在原有功能基础上的替代。
+
+[2] 采用切片方式减少单次通信的数据量,避免通信量过大导致的显存消耗。
+
+## 使用场景
+当 Gloo 通信频繁出现建链失败时,模型启动效率较低,此时需要通过替换 Gloo 通信组来提升效率。
+
+## 使用方法
+[1] 在训练脚本中加入 `--disable-gloo-group`,以启用该特性。
+
+[2] 在脚本中定义 `--hccl-slice-size N`(可选),设置 DP 组保存和加载分布式优化器状态时的通信量大小。该参数的有效区间为 (0, bucket_size/dp],其中 bucket_size 为分布式优化器中每个桶的大小。建议在显存允许的情况下,尽量增大该参数,以提高通信效率。
+
+## 使用效果
+### 通信效率分析
+理论上,分布式优化器状态保存和加载的通信效率在一定范围内随着 `hccl-slice-size` 增加而提升。
+
+
+### 显存增量分析
+开启该特性后,显存的增加量为 `hccl-slice-size * (dp + 1) * 4B`。
diff --git a/model/train/yoco_moe/docs/features/hierarchical-alltoallv.md b/model/train/yoco_moe/docs/features/hierarchical-alltoallv.md
new file mode 100644
index 0000000000000000000000000000000000000000..ec853b8986ac1cd25a06a04eb6eead6f96a6e0b4
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/hierarchical-alltoallv.md
@@ -0,0 +1,37 @@
+# MoE Hierarchical Alltoallv分层通信
+
+## 问题分析
+
+大模型训练过程中,在MoE Layer中前向和反向均有两次alltoallv进行数据传输。在原本的实现中,多卡多机组成了一个规模较大的通信组,机间机内同时进行alltoallv通信,造成机间的冗余通信较多。
+
+如下图所示,假设有两个节点,其中topk=3,节点1内的某个token选择了节点2上不同卡的3个专家,那么这个token通过alltoallv传输了3次,也就是topk次。
+
+
+
+## 解决方案
+
+将规模较大的通信组分成两个相互正交的inner group和outer group。在outer group内,也就是ep组内,使用allgather收集token,再在inner group内,也就是tp组内,使用alltoallv传输token,提供分层通信的功能,节省topk倍的冗余通信,提升性能。
+
+
+
+## 使用场景
+
+在多机情况下,deepseekv2类moe模型,开启tp_extend_ep特性,且需要提升性能的场景下。
+
+## 使用方法
+
+在启动脚本中添加参数--moe-hierarchical-alltoallv,即可开启分层通信开关。
+
+## 使用效果
+
+开启分层通信moe_hierarchical_alltoallv前后,5000step精度对齐,性能收益4.28%
+
+在内存优化特性memory level0开启的情况下,开启分层通信moe_hierarchical_alltoallv前后对比,5000step精度对齐,性能收益3.02%
+
+在内存优化特性memory level1开启的情况下,开启分层通信moe_hierarchical_alltoallv前后对比,5000step精度对齐,性能收益4.34%
+
+## 注意事项:
+
+1.仅支持在多机情况下,moe_tp_extend_ep和moe_alltoall_overlap_comm特性开启的情况下
+
+2.Megatron和MindSpeed版本均为使用core_r0.8.0分支。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/hybrid-context-parallel.md b/model/train/yoco_moe/docs/features/hybrid-context-parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..57762f73b206fe30667dbb184498c503c79904ca
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/hybrid-context-parallel.md
@@ -0,0 +1,45 @@
+# 混合长序列并行
+
+## 问题分析
+
+从生成性AI到科研模型,长序列训练正在变得非常重要。 在生成性AI领域,会话式AI、长文档摘要和视频生成等任务都需要在空间和时间层面对长上下文进行推理。 同样,章节和书籍级别的摘要(数万甚至数十万字)在会话式AI和摘要任务中也受到重视。现有的数据、张量和流水线等并行方法无法在序列维度进行切分。当序列维度(S)增长时,训练内存开销会以 $O$($S^2$) 的速度增长。因此需要针对长序列场景进行特定的优化解决长训练场景的训练需求。
+
+目前流行的序列并行方案,Ulysses和Ring Attention存在各自的局限性。
+
+Ulysses需要确保attention head数可以被序列并行维度整除,在GQA、MQA场景下序列并行的大小有限制,导致序列长度的扩展有限。
+
+Ring Attention的并行维度不受attention head数限制,因此理论上序列长度可以无限拓展。但相比于Ulysses,Ring Attention不能充分利用通信和计算带宽,在序列块大小较低时性能劣于Ulysses。
+
+## 解决方案
+对Ulysses和Ring Attention做融合,实现混合序列并行,以此解决两个方案各自缺陷。
+
+## 使用场景
+
+可兼容FlashAttention,目前已默认开启FlashAttention。
+
+序列并行维度被分为Ulysses维度和ring attention维度,Ulysses维度和ring attention维度乘积即为序列并行维度。
+
+## 使用方法
+
+设置`--context-parallel-size`,默认为1,根据用户需求配置。
+
+设置`--context-parallel-algo hybrid_cp_algo`,以使能混合序列并行。
+
+设置`--ulysses-degree-in-cp`,需要确保`--context-parallel-size`可以被该参数整除且大于1。例如当设置`--context-parallel-size=8`时,可以设置`--ulysses-degree-in-cp=2`或`--ulysses-degree-in-cp=4`。
+
+同时需要确保`--ulysses-degree-in-cp`可以被attention head数整除。
+
+混合长序列并行支持Ring Attention长序列并行相关特性,包括send receive overlap功能、Mask计算类型配置。
+
+## 使用效果
+
+利用多个计算设备对输入序列进行并行切分,降低单设备的内存消耗,相比不开启序列并行单步耗时增加,相比重计算计算效率提升。
+
+## 鸣谢
+
+1. GitHub项目地址:
+https://github.com/feifeibear/long-context-attention
+
+2. 论文预印本地址:
+USP: A Unified Sequence Parallelism Approach for Long Context Generative AI
+https://arxiv.org/abs/2405.07719
diff --git a/model/train/yoco_moe/docs/features/mc2.md b/model/train/yoco_moe/docs/features/mc2.md
new file mode 100644
index 0000000000000000000000000000000000000000..204116f1a9e660d2ed62a1536813b3e8589b7ed9
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/mc2.md
@@ -0,0 +1,63 @@
+# Ascend MC2
+
+## 使用前提
+仅限于版本标识为`CANN 8.0.RC2`和`Ascend HDK 24.1.RC2`及其后续所有迭代版本的系统环境。
+
+若在非指定版本中尝试采用该配置,可能触发系统级的异常行为,包括但不限于运行时错误。
+
+## 问题分析
+
+在开启了TP和SP的大模型训练场景下,存在matmul计算和all-reduce操作的强依赖关系(不开启SP),或存在matmul计算和all_gather/reduce_scatter操作的强依赖关系(开启SP)。当模型参数量较大时,此处通信量和计算量都较大,在串行执行时,会引入较长的等待闲置时间。
+
+## 解决方案
+
+针对存在强依赖关系的matmul计算和通信操作,昇腾自研MC2解决方案。
+
+MC2通过融合算子的方式将matmul计算和集合通信操作进行融合,将较大的计算和通信任务切分成了较小的计算子任务和通信子任务,并通过流水的方式使得通信子任务和计算子任务可以互相掩盖,从而减少等待和闲置时间,提高利用率。
+
+## 解决思路:
+
+MindSpeed在python脚本侧将原本串行的matmul和all_gather/reduce_scatter操作通过MC2融合算子接口进行了融合。
+
+具体实现参见[link](../../mindspeed/core/tensor_parallel/ascend_turbo/mc2_linears_seq_parallel.py)。
+
+MC2算子接口参见[link](https://www.hiascend.com/document/detail/zh/Pytorch/60RC1/apiref/apilist/ptaoplist_000449.html)。
+
+## 使用场景
+
+当开启了TP和SP时,建议开启MC2进一步优化。模型权重冻结和模型权重不冻结两个场景均支持。
+
+### 说明
+可以通过设置`requires_grad`属性为`False`来实现权重冻结。
+```python
+# 举例1:冻结所有参数
+for param in model.parameters():
+ param.requires_grad = False
+```
+
+```python
+# 举例2:除了output_layer,冻结所有ColumnParallelLinear和RowParallelLinear
+from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+for name, module in model.named_modules():
+ if ('output_layer' not in name
+ and (isinstance(module, ColumnParallelLinear) or isinstance(module, RowParallelLinear))):
+ for param in module.parameters():
+ param.requires_grad = False
+```
+
+## 使用方法
+
+设置--use-ascend-mc2即可使能MC2算子。
+
+
+**同时需要确保开启**`--sequence-parallel`。
+
+## 使用效果
+
+在开启TP和SP的训练场景下,使用MC2可以减少内存开销并提高计算效率。
+
+## 注意事项
+
+1. MoE模型暂不支持开启MC2。
+2. 暂不兼容计算通信并行 CoC 特性 --use-ascend-coc 。
+3. 该特性不支持在 Atlas 900 A3 硬件上使用。
diff --git a/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-adaptive-recompute-activation.md b/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-adaptive-recompute-activation.md
new file mode 100644
index 0000000000000000000000000000000000000000..9f2aac7712e9973ec6004b04826acbc7e2241ce3
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-adaptive-recompute-activation.md
@@ -0,0 +1,22 @@
+# MoE 负载感知内存均衡算法
+
+## 问题分析
+
+MoE在预训练前期负载均衡 aux_loss 未起作用时,token 在专家层的分配不均会导致全局内存负载不均衡问题,并引入大量碎片内存,导致显存波动巨大,并呈现逐步递增的趋势,大集群训练下更容易出现OOM。
+
+## 优化方案
+
+根据模型设定参数(DP/TP/SeqLength/NumExpert等),设定token分布不均的判定阈值,当超过该阈值后执行重计算,避免产生内存不均衡的激活值。
+
+## 使用限制
+
+1. 使用时**建议**开启`--moe-router-load-balancing-type aux_loss`,这样会使得训练过程中token分布快速趋向于平衡。
+2. 开启全局重计算后该功能无效。
+3. 仅支持`--moe-token-dispatcher-type`是all-gather时可用。
+4. 不兼容--recompute-in-bubble特性。
+
+## 使用方法
+
+- 必选参数:脚本中加入`--moe-adaptive-recompute-activation`即可。
+
+- 可选参数:如果希望调节判定阈值,则修改`--moe-adaptive-recompute-activation-scale`即可,默认值为2.0,表示阈值为平均分配在每个专家上的序列的两倍长度。
diff --git a/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-allgather-dispatcher.md b/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-allgather-dispatcher.md
new file mode 100644
index 0000000000000000000000000000000000000000..46570ac226df147b910e6a7d013866459be01c74
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-allgather-dispatcher.md
@@ -0,0 +1,32 @@
+# Allgather Dispatcher 分支优化
+
+## 问题分析
+### 1. gather & scatter 算子替换
+
+在 Megatron MoE 中的 Allgather 分支,存在使用 gather/scatter 操作。gather/scatter 功能为沿 dim 轴根据索引逐元素进行取值/赋值操作,此操作会有大量的随机地址,对性能造成巨大影响。
+
+在 Megatron MoE 中对 gather/scatter 的调用主要是以下调用方式,通过对 index 做 expand 操作对维度进行扩展,再通过扩展后 index 对 hidden_states 进行逐元素取值/赋值。
+
+```python
+self.global_local_map = global_local_map.view(-1, 1).expand(-1, hidden_states.shape[-1])
+local_hidden_states = torch.gather(global_hidden_states, 0, self.global_local_map)
+```
+
+### 2. 异步通信
+在 Allgather dispatcher 分支中,会 permutate 函数开头分别对 hidden_states、max_ind、max_prob 三个数据做 allgather 通信,这些操作为串行操作,但各计算任务之间并非串行依赖关系。
+
+
+## 解决方案
+### 1. gather & scatter 算子替换
+由于 index 是通过 expand 进行扩展的,因此它的每一行中的内容都是一致,而我们没有必要使用 gather/scatter 进行逐元素的操作,可通过 index 算子以及 indexput 算子进行逐行操作,对 gather/scatter 进行等价替换。
+
+
+### 2. 异步通信
+通过对通信任务进行重新排序,并使用 async=True 参数进行异步下发,达到计算和通信并行的目的。
+
+## 使用场景
+
+在使用 mcore MoE 的场景下,开启了 `--moe-token-dispatcher-type allgather`。
+
+## 使用方法
+开启参数 `--moe-permutation-async-comm`。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-alltoall-dispatcher.md b/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-alltoall-dispatcher.md
new file mode 100644
index 0000000000000000000000000000000000000000..6b30ad7e99704a687e480bca464ea38e702a7b34
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-alltoall-dispatcher.md
@@ -0,0 +1,25 @@
+# Alltoall Dispatcher 分支优化
+
+## 问题分析
+### 1. repeat_interleave 并行
+在 Alltoall dispatcher 分支中,调用了 repeat_interleave 算子,此算子只使用了单个 block dim 在单个下发流上进行串行计算,且耗时较长,算子的输出也是在 alltoall、permute、alltoallv 之后才用到。
+
+### 2. 计算通信并行
+在 alltoall 分支中的 permutation 函数最后会进行 allgather 操作,对所有 tokens 被切分的 H 维进行补全,然后再对数据分块进行专家计算。此项操作为串行操作,但各专家间的 tokens 并没有存在依赖关系,可修改为并行操作。
+
+
+## 解决方案
+### 1. repeat_interleave 并行
+通过新建一条下发流,将 repeat_interleave 算子调用分到新的流上,在 block dim 资源充足的情况下,可进行两个算子的并行计算,节省耗时。
+
+### 2. 计算通信并行
+可按照每个专家需要的 tokens 进行切分,然后逐个对 tokens 进行 allgather 通信 + 专家计算,由于第一个专家计算只依赖第一个通信,专家之间无依赖关系,因此在做第一个专家计算的时候可同步进行第二专家的通信,达到计算和通信并行。
+
+## 使用场景
+在使用 mcore MoE 的场景下,开启了 `--moe-token-dispatcher-type alltoall`。
+
+## 使用方法
+开启参数 `--moe-permutation-async-comm`。
+
+## 场景限制
+由于开启 `--moe-grouped-gemm` 后,专家计算被单一算子合并,因此计算通信并行优化会失效。
diff --git a/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-bmm-fused.md b/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-bmm-fused.md
new file mode 100644
index 0000000000000000000000000000000000000000..35cc98bcf45418f8249d26eb6dee0872e834e2d8
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-bmm-fused.md
@@ -0,0 +1,26 @@
+# Megatron MoE BMM
+
+## 问题分析
+
+针对MoE的drop and pad场景,所有专家上tokens数量相同,使用bmm融合算子(融合前后的通信操作)替换gmm算子能达到更好的效果。
+
+## 解决方案
+
+通过调用bmm通算融合算子(alltoall_allgather_bmm和bmm_reducescatter_alltoall)替换gmm算子及前后的通信操作,达到加速效果。
+
+## 使用方法
+在drop and pad场景
+
+前置条件需要设置`--moe-grouped-gemm`
+
+设置`--moe-bmm-mc2`: 表示通过BMM的融合算子计算。
+
+## 使用效果
+在ep=8的场景下,开启融合算子替换,性能提升2%左右。
+
+## 使用限制
+1.仅支持megatron_moe的alltoall分支,且开启tp和ep。
+
+2.仅支持昇腾Atlas A3 AI处理器。
+
+3.不支持`--moe-tp-extend-ep`和`--moe-alltoall-overlap-comm`特性。
diff --git a/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-gmm.md b/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-gmm.md
new file mode 100644
index 0000000000000000000000000000000000000000..9d58ac54553ab432620e1e3563ff346acd448561
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-gmm.md
@@ -0,0 +1,41 @@
+# Megatron MoE Grouped GeMM
+
+## 问题分析
+
+针对MoE单卡多专家计算,存在细碎的专家计算操作与通信,通过Grouped GeMM算子对多专家计算进行合并,提升MoE单卡多专家训练性能。
+
+## 解决方案
+
+通过调用 gmm 融合算子,对多个专家计算进行融合,达到加速效果。
+
+## 使用方法
+
+设置`--moe-grouped-gemm`: 表示开启Grouped GeMM计算。
+
+## 效果说明
+
+典型场景:
+
+- EP变小导致单卡专家数量增大 & DeepSeek MoE专家数量较多等场景。
+- DeepSeek MoE finegrained expert单个专家较小 & FFN规模不大 & TP变大导致单卡切分的计算变小。
+
+1. 随着FFN规模提升,计算不再细碎,单专家计算效率提升,Grouped GeMM 收益变小。
+
+表1:grok模型FFN大小和性能加速对比
+
+|ffn_hidden_size| 32768 | 16384| 8192| 4096|
+|--|--|--|--|--|
+|baseline|2280|1780|1537|1446|
+|GeMM|2416|1719|1448|1331|
+|性能提升|-5.30%|3.53%|6.12%|8.60%|
+
+
+2. TP越大,EP越小,收益更大。
+
+表2:Mixtral8*7B模型配置不同性能收益
+
+|配置| tp4 ep2 16expert | tp4 ep2 8expert | tp2 ep4 16expert| tp2 ep4 8expert|
+|--|--|--|--|--|
+|baseline|27969|20127|11976|13981|
+|GeMM|19415|17361|11049|14290|
+|性能提升|44.06%|17.93%|8.39%|-2.19%|
diff --git a/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-tp-extend-ep.md b/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-tp-extend-ep.md
new file mode 100644
index 0000000000000000000000000000000000000000..f6b8e697201b3bb5ab05a08cda7b28ac432ba1d4
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/megatron_moe/megatron-moe-tp-extend-ep.md
@@ -0,0 +1,28 @@
+# Megatron MoE TP拓展EP
+
+## 问题分析
+
+开启TP+EP后,专家层TP组切分专家参数,MoE细粒度小专家场景TP切分后GMM算子效率下降严重。
+
+## 解决方案
+
+针对小专家场景TP切分后GMM算子效率下降问题,专家层TP组不切分专家参数,切分专家数量。
+
+## 使用方法
+
+打开`--moe-tp-extend-ep`启用该特性。
+
+同时需要开启:
+- `--moe-permutation-async-comm`
+- `--moe-grouped-gemm`,目前仅支持Grouped MLP。
+
+同时需要确保`--num-experts`能被`tp * ep`整除。
+
+当前该特性不支持Moe Token drop and pad模式,即`--moe-expert-capacity-factor`需要为None。
+
+## 适用场景
+
+细粒度小专家,类DeepSeek-V2模型,每个专家的参数量较小。
+
+
+
diff --git a/model/train/yoco_moe/docs/features/moe-experts-pipeline-degree.md b/model/train/yoco_moe/docs/features/moe-experts-pipeline-degree.md
new file mode 100644
index 0000000000000000000000000000000000000000..7b7bb81afada7577429ec4facf07a12a9a4a9fa4
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/moe-experts-pipeline-degree.md
@@ -0,0 +1,31 @@
+# MoE Experts Pipeline Degree大专家流水
+
+## 问题分析
+
+该功能在面向megatron+mindspeed框架MoE类模型大专家的情况下,专家的计算时间和通信时间在每层神经网络中占比大,严重影响性能和内存。
+
+## 解决方案
+
+本方案中,将专家分组做流水,使专家计算内部的通信和计算相互掩盖,只有头尾未掩盖通信开销。在负载基本均衡的情况下,专家的allgather和reducescatter的未掩盖通信变成了1/moe_experts_pipeline_degree。(moe_experts_pipeline_degree表示流水次数)
+
+
+
+## 使用场景
+
+基于gpt-moe大专家模型,专家的计算时间和通信时间在每层神经网络中占比大的情况。
+
+## 使用方法
+
+在启动脚本中添加并合理配置 --moe-experts-pipeline-degree [int]
+
+其中,[int]是大专家的流水粒度,是大于1小于专家数num_experts,并且可以被专家数num_experts整除的整数。
+
+## 使用效果
+
+配置大专家流水moe_experts_pipeline_degree前后, 5000step精度对齐。性能收益2.11%,内存收益4.38%
+
+## 注意事项:
+
+1.仅支持tp>1, sp和moe_alltoall_overlap_comm特性开启的情况下
+
+2.Megatron和MindSpeed版本均为使用core_r0.8.0分支。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/moe-token-permute-and-unpermute.md b/model/train/yoco_moe/docs/features/moe-token-permute-and-unpermute.md
new file mode 100644
index 0000000000000000000000000000000000000000..e8f87fb722c642f35057e9cb66435fd876b53687
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/moe-token-permute-and-unpermute.md
@@ -0,0 +1,23 @@
+# Moe Token Permute and Unpermute 融合优化
+
+## 问题分析
+在MoE架构中,MoEAlltoAllTokenDispatcher调度器负责将token令牌分配给各个专家进行处理,并将处理后的结果重新组合回原始的token顺序。这个过程通常涉及到以下步骤:
+Token路由:确定每个token应该由哪个专家处理。这可以通过专家门控机制(gating mechanism)来完成,门控机制为每个token选择最适合的专家。
+数据重排(Permute):将token按选择的专家进行分组,以便每个专家可以并行处理属于它的token。这通常涉及到对token的重排操作。
+专家处理:每个专家并行处理属于它的token。
+结果重组(Unpermute):处理完成后,需要将来自不同专家的结果重组回原始的token顺序。
+在上述流程中,数据重排和结果重组步骤是性能瓶颈之一。这是因为这两个步骤涉及到大量的数据移动,特别是在使用分布式训练时。
+## 解决方法
+为了优化这一过程,可以考虑将数据重排和结果重组步骤合并成一个操作。MindSpeed对将Moe Token Permute和Unpermute操作分别融合成一个算子,提升模型训练性能。算子接口分别见[link](../ops/npu_fused_moe_token_permute.md),[link](../ops/npu_fused_moe_token_unpermute.md)。
+
+## 使用方法
+### 前提条件
+开启专家并行,并且`moe-token-dispatcher-type`设置为`alltoall`
+`--expert-model-parallel-size M \
+--num-experts N \
+--moe-token-dispatcher-type alltoall`
+
+设置`--use-fused-moe-token-permute-and-unpermute`即可调用Moe Token Permute和Unpermute融合算子。
+
+## 使用效果
+启用融合算子后,不仅能够有效节省内存资源,还能提升模型训练性能。
diff --git a/model/train/yoco_moe/docs/features/multi_parameter_pipeline.md b/model/train/yoco_moe/docs/features/multi_parameter_pipeline.md
new file mode 100644
index 0000000000000000000000000000000000000000..268284d395a8cce2a5f8edadbde244686cc99abb
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/multi_parameter_pipeline.md
@@ -0,0 +1,47 @@
+# PP支持多参数传递
+
+## 背景与挑战
+
+在深度学习的大模型分布式训练中,流水线并行(Pipeline Parallelism, PP)是一种将模型分割为多个阶段并在不同设备上并发执行以提高效率的技术。然而,在多模态场景下引入多参数传递支持时,PP面临特定的挑战:
+
+- **通信部分的设计**:传统PP通常只涉及单一张量的传输,但在多参数传递的情况下,需要处理多个变量的传递。这不仅增加了通信复杂度,还要求对每个变量的shape、dtype等属性进行精确管理,这些属性往往与整体模型架构紧密相关,具有高度定制性。
+- **前向传播的变量传递**:在前向计算过程中,不仅要根据定义的shape正确传递多个变量,还要确保每个阶段接收到的数据格式符合预期,这对数据流的设计提出了更高的要求。
+- **反向传播的运算扩展**:对于反向传播,除了对首个输出进行梯度计算外,还需对其他所有输出进行相应的运算,确保整个训练过程的完整性和准确性。
+
+## 解决方案
+
+针对上述挑战,我们设计了以下解决方案,旨在使PP能够有效支持多参数传递:
+
+- **优化的通信机制**:开发了一套高效的通信机制,支持多种类型和格式的数据传输。针对每个阶段的具体需求定制化配置传输参数。
+- **增强的梯度计算逻辑**:改进了反向传播算法,使得系统可以自动识别并处理来自多个输出的梯度信息,保证每个输出都能参与到最终的权重更新中。
+
+## 使用场景
+
+本特性特别适用于以下场景:
+- 需要处理大量多模态数据(如文本、图像、音频)的大型神经网络训练任务,并且流水线并行各个阶段传递多参数。
+
+## 使用方法
+
+**注意事项**:
+- 用户需代码配置`args.pipeline_tensor_shapes`明确指定各阶段间传递的具体参数及其属性(如shape、dtype)。
+- args.pipeline_tensor_shapes配置参考`tests_extend/system_tests/multi_modal/multi_parameter_pipeline/pretrain_multi_parameter_pipeline_test.py`
+
+
+**设置训练脚本参数**
+- 支持PP场景
+```shell
+# PP >= 2
+--pipeline-model-parallel-size ${PP} \
+--use-multiparameter-pipeline-model-parallel \
+```
+- 支持VPP场景
+```shell
+# PP >= 2, num-layers-per-virtual-pipeline-stage不为None
+--pipeline-model-parallel-size ${PP} \
+--num-layers-per-virtual-pipeline-stage 1 \
+--use-multiparameter-pipeline-model-parallel \
+```
+
+## 使用效果
+
+采用PP支持多参数传递后,用户可以在保持高通信效率的同时,更灵活地处理复杂的多模态数据。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/multi_parameter_pipeline_and_variable_seq_lengths.md b/model/train/yoco_moe/docs/features/multi_parameter_pipeline_and_variable_seq_lengths.md
new file mode 100644
index 0000000000000000000000000000000000000000..9722eb479964f5ab2f6e95747bf2b94de16b9d82
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/multi_parameter_pipeline_and_variable_seq_lengths.md
@@ -0,0 +1,49 @@
+# PP支持多参数传递和动态形状
+
+## 背景与挑战
+
+在深度学习的大规模分布式训练中,流水线并行(Pipeline Parallelism, PP)通过将模型分割为多个阶段并在不同设备上并发执行来提高效率。然而,在处理复杂的多模态数据时,PP面临了新的挑战:
+
+- **对于多参数传递**:传统PP通常只涉及单一张量的传输,但在多参数传递的情况下,需要处理多个变量的传递,这不仅增加了通信复杂度,还要求对每个变量的shape、dtype等属性进行精确管理 。
+- **对于动态形状**:当输入数据的序列长度不固定时,传统的方法是将所有序列调整到统一长度,这导致了内存和计算资源的浪费 。
+
+## 解决方案
+
+为了应对这些挑战,开发了一系列优化措施:
+
+- **多参数传递**:开发了一套高效的通信机制,支持多种类型和格式的数据传输,并改进了反向传播算法,使得系统可以自动识别并处理来自多个输出的梯度信息 。
+- **动态形状**:引入对动态形状的支持,允许每个微批次中的序列保持其原始长度。这样可以通过在发送张量之前,提前通信张量的形状信息,在各个流水线阶段之间同步即将接收的数据形状,确保内存分配和预处理的准确性 。
+
+## 使用场景
+
+- **多参数传递**:适用于需要处理大量多模态数据的任务,如文本、图像和音频等大型神经网络训练任务,其中流水线并行的各个阶段都需要传递多参数 。
+- **动态形状**:非常适合于处理文本长度差异很大的任务,比如文档分类和机器翻译,同时也增强了模型的泛化能力 。
+
+## 使用方法
+
+**注意事项**:
+- 用户需代码配置`args.pipeline_tensor_shapes`明确指定各阶段间传递的具体参数及其属性(如shape、dtype)。
+- args.pipeline_tensor_shapes配置参考`tests_extend/system_tests/multi_modal/multi_parameter_pipeline/pretrain_multi_parameter_pipeline_test.py`
+
+
+**设置训练脚本参数**
+- # 支持PP场景
+```shell
+# 配置举例
+# PP >= 2
+--pipeline-model-parallel-size ${PP} \
+--use-multiparameter-pipeline-model-parallel \
+--variable-seq-lengths \
+```
+- # 支持VPP场景
+```shell
+# 配置举例
+# PP >= 2, num-layers-per-virtual-pipeline-stage不为None
+--pipeline-model-parallel-size ${PP} \
+--num-layers-per-virtual-pipeline-stage 1 \
+--use-multiparameter-pipeline-model-parallel \
+--variable-seq-lengths \
+```
+
+## 使用效果
+同时支持在流水线并行中各阶段间传递多个参数和处理变长输入数据。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/nanopipe-pipeline-parallel.md b/model/train/yoco_moe/docs/features/nanopipe-pipeline-parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..fcc4a09894d230a73aff2010d924e660bc3312d4
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/nanopipe-pipeline-parallel.md
@@ -0,0 +1,74 @@
+# nanopipe流水线并行
+
+## 问题分析
+
+流水线并行是AI大模型大规模分布式训练的关键组成部分之一,但其效率收到流水线中bubble的影响,为了提高吞吐量,需要降低其bubble比例。
+
+## 解决方案
+
+在大模型流水线调度中,反向的input梯度和weight梯度通常是一起调度计算的,然而,通过分析他们计算的依赖关系,可以发现其实只有input梯度的计算存在相互层间的依赖关系。因此,通过独立调度反向的input梯度和weight梯度的计算,我们可以减少流水线调度的bubble。
+
+反向input梯度和weight梯度一起调度的Interleaved 1F1B如下图所示:
+
+
+
+独立调度input梯度和weight梯度的nano-pipe如下图所示:
+
+
+
+独立调度weight计算展示图如下图所示:
+
+
+
+### 解决思路:
+
+* 分离weight梯度计算流程,通过修改RowParallelLinear和ColumnParallelLinear的backward实现,将对weight的梯度计算进行剥离,先存储在调度器的dw计算队列中。
+* 在需要对dw计算时,从调度器的dw计算队列中pop出一个计算,然后计算对应的梯度。
+
+## 使用场景
+
+在训练模型时,降低bubble的比例,从而提升计算效率,达到更好的流水线并行。此特性暂只适配`--use-legacy-models`。
+
+## 使用方法
+
+nanopipe依赖于vpp,设置`--num-layers-per-virtual-pipeline-stage N`。要求`--pipeline-model-parallel-size` > 2
+设置`--use-nanopipe`,默认为False,根据用户需求配置。
+
+## 使用效果
+
+提升计算效率,减少bubble占比。如下表所示:
+
+| device | TP | SP | PP | SEQ | hidden-size | Nano vs vpp收益 |
+| :-----: | :----: | :----: | :-----:| :----: | :----: | :-----: |
+| 单机 | 1 | 关 | 4 | 4096 | 4096 | 3.24% |
+| 双机 | 4 | 开 | 4 | 8192 | 8192 | 1.02% |
+
+# nanoswap
+
+## 问题分析
+
+使用nano时grad从前向到反向需要持续存储在npu上,生命周期过长,多次累加会增大npu内存的峰值。
+
+## 解决方案
+
+将过多的张量做offload动作存储到cpu上,在内存峰值过后再将其张量reload回npu上。
+
+### 解决思路
+
+在前向时将上一轮过多的张量offload到cpu,再在连续的反向运算中途reload回npu上,通过swap流控制不会让reload和offload出现顺序错误。
+
+完整nanopipe-swap原理图如下图所示:
+
+
+
+## 使用方法
+
+基于nanopipe的基础上再开启`--use-nanopipe-swap`。
+
+## 使用效果
+
+优化设备内存峰值,如下表所示:
+
+| device | TP | SP | PP | SEQ | hidden-size | mc2 | Nano内存峰值 |swap内存峰值 | Nano vs swap内存峰值下降 |
+| :-----: | :----: | :----: | :-----:| :----: | :----: | :-----: | :-----: | :-----: | :-----: |
+| 单机 | 2 | 开 | 4 | 1024 | 4096 | 开 | 5520.62 | 5177.72 | 6.21% |
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/nd-matmul.md b/model/train/yoco_moe/docs/features/nd-matmul.md
new file mode 100644
index 0000000000000000000000000000000000000000..ca390c7fae097e6a302887d1f6217705edb353c9
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/nd-matmul.md
@@ -0,0 +1,27 @@
+## ND_MatMul
+
+## 问题分析
+
+传统的1d张量并行中,输入数据仅在张量并行组内简单复制,造成单卡静态内存较大;同时,attention和ffn的partial输出结果都需要做一次all_reduce,这一部分通信开销较大。
+
+
+
+## 解决方案
+
+针对attention和ffn中的矩阵乘,将矩阵乘的参数同时进行行和列切分,即mp=mp_row*mp_col,同时在一个张量并行组内将输入x列切mp份,每张卡只保留输入数据的1/mp,通过插入更小规模的all_gather和reduce_scatter通信算子保证计算的准确性。算法原理图如下:
+
+
+
+## 使用方法
+
+设置`--use-nd-matmul`,打开ND_MatMul特性的总开关。
+
+设置`--nd1-dim1-size`,默认为1,需要确保`--nd1-dim1-size`能够被`--tensor-model-parallel-size`整除。
+
+设置`--nd2-dim1-size`,默认为1,需要确保`--nd2-dim2-size`能够被`--tensor-model-parallel-size`整除。
+
+示例:`--tensor-model-parallel-size`为32,`--nd1-dim1-size`可以设置为2、4、8、16,`--nd2-dim1-size`可以设置为2、4、8、16,出于性能考虑(建议`--nd1-dim1-size`或者`--nd2-dim1-size`大于等于8),可配置`--nd1-dim1-size`为8、`--nd2-dim1-size`为4。
+
+## 使用效果
+
+降低单卡显存占用效果明显,在`--nd1-dim1-size`或者`--nd2-dim2-size`较大(>8)时,相比megatron TP性能提升。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/noop-layers.md b/model/train/yoco_moe/docs/features/noop-layers.md
new file mode 100644
index 0000000000000000000000000000000000000000..ca8af14b31b9f5838c9de6d2de99a91771065c96
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/noop-layers.md
@@ -0,0 +1,30 @@
+## Ascend 自定义空操作层
+## 问题分析
+
+在神经网络训练过程中,初始层的嵌入(Embedding)操作及终端层的对数几率(Logits)计算通常属于计算密集型任务,这可能对整个网络的运行效率产生显著影响。具体而言:
+
+- **嵌入层(Embedding Layer)**:在处理文本或分类数据时,嵌入层将高维稀疏特征转换为低维稠密向量表示。此过程涉及索引查找和潜在的大规模矩阵乘法,特别是在自然语言处理应用中,词汇表可能包含数十万乃至数百万词条。高维度的查找与转换操作会消耗大量计算资源。
+
+- **对数几率层(Logits Layer)**:位于网络末端的对数几率层通常是一个全连接层,其功能是将最后一层的隐藏状态映射到输出空间,为后续损失函数计算提供未归一化的预测值。如果分类任务具有大量类别,那么该层的权重矩阵将非常庞大,导致矩阵乘法运算成为性能瓶颈。
+
+上述操作的计算复杂度随着输入特征数量和类别数量的增加而上升,可能导致训练速度降低,并且在计算资源有限的环境中形成性能瓶颈。
+
+## 解决方案
+
+为应对上述挑战,我们引入了“自定义空操作层”功能,允许用户通过指定特定层为“空操作层”(No-Op Layers)来动态调整模型在训练流水线中的计算负载。此机制有助于在多个计算节点间更均匀地分配工作负载,从而优化整体计算资源的利用。
+
+## 使用场景
+
+当用户遇到由于计算资源分配不均导致的性能瓶颈时,此功能尤为适用。通过对计算密集型任务进行重新分配,可以有效减少流水线中的空闲时间(即空泡),从而提高系统的吞吐量和效率。
+
+## 使用方法
+
+要启用此功能,用户需通过命令行参数设置目标层为无操作层。例如,原模型共126层,该模型参数为 `--num-layers 126`,即执行实际计算的层有126层。若在该模型首和尾各自添加1层空层,则该模型参数应设置为 `--num-layers 128 --noop-layers 0,127` 表示总共128层,首尾层(即第0层和第127层,层数从0开始计数)为不执行实际计算的空操作层,中间的126层为执行实际计算的层。
+
+## 使用效果
+
+通过实施自定义增加无操作层的策略,预期能够显著减少流水线中的空泡现象,从而优化计算流程并提升系统性能。这不仅有助于加速模型训练过程,还能最大化硬件资源的利用率。
+
+## 注意事项
+
+使用“Ascend 自定义空操作层”特性增加空层后总层数发生变化,需要根据包含空层的总层数重新调整流水线(虚拟流水线)的配置。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/norm-recompute.md b/model/train/yoco_moe/docs/features/norm-recompute.md
new file mode 100644
index 0000000000000000000000000000000000000000..8d80faebde08d9a91db49e104b144fb1af6ddfd1
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/norm-recompute.md
@@ -0,0 +1,33 @@
+# Norm重计算
+
+## 问题分析
+
+大模型训练过程中,往往会面临的显存不足的问题。
+
+## 解决方案
+
+类似于激活函数重计算,本特性支持了Norm层的重计算。
+
+## 解决思路
+
+运用激活函数重计算特性中的 `checkpoint` 机制,对norm层进行重计算处理,具体细节如下文所示:
+[原文链接](https://www.usenix.org/conference/atc24/presentation/yuan)
+
+## 使用场景
+
+主要用于训练场景,用户内存不足或要进一步节省内存时。
+
+## 使用方法
+
+脚本中添加:`--recompute-norm` 可开启Norm重计算。此特性仅支持mcore分支。
+
+添加:`--recompute-norm-num-layers ${num}` 可指定Norm重计算的层数。
+
+Norm重计算兼容激活函数重计算、全重计算同时开启:
+
+1.同时开启时,仅支持 `--recompute-method` 为 `block`。
+
+2.同时开启时,会按照指定的全重计算和Norm重计算的层数做各自类型的重计算,即不会有一层既做全重计算又做Norm重计算。
+
+(注意:执行优先级是先计算全重计算层,后Norm重计算层。)
+
diff --git a/model/train/yoco_moe/docs/features/npu_deterministic.md b/model/train/yoco_moe/docs/features/npu_deterministic.md
new file mode 100644
index 0000000000000000000000000000000000000000..740f95df62810caefbfe4225a17ffac70a959875
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/npu_deterministic.md
@@ -0,0 +1,21 @@
+## Ascend 确定性计算
+## 问题分析
+
+在训练过程中,各类随机因素会使得每次实验的训练过程并不完全一致,从而影响其LOSS曲线、性能曲线等无法完全重合。
+然而,在重复实验与对比实验中有时需要确定性的计算结果,保证结果的可复现性。
+
+## 解决方案
+
+为满足上述需求,引入了“确定性计算”功能,允许用户通过昇腾(Ascend)芯片确保多次训练结果的一致性,从而帮助性能调优、对照实验等工作。
+
+## 使用场景
+
+需要进行性能对比、特定场景复现时。
+
+## 使用方法
+
+要启用此功能,在脚本中加入`--npu-deterministic`即可。
+
+## 使用效果
+
+通过确定性计算功能,可保证同参数下多次实验具有相同的实验结果。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/npu_matmul_add.md b/model/train/yoco_moe/docs/features/npu_matmul_add.md
new file mode 100644
index 0000000000000000000000000000000000000000..ebd3a28ca6d94ae537c1c5618e783effbda0270e
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/npu_matmul_add.md
@@ -0,0 +1,29 @@
+# matmul_add融合优化
+
+## 问题分析
+模型训练中开启了梯度累加功能,但累加效率较慢,梯度累加中的 Add 算子占比较高。
+
+## 解决方法
+MindSpeed将matmul操作和add操作合并成一个融合算子。算子接口见[link](../ops/npu_matmul_add.md)。
+
+## 使用场景
+llama、gpt大模型均使用。
+
+## 使用方法
+先安装CANN-NNAL并初始化添加环境,例如:
+CANN-NNAL默认安装路径
+source /usr/local/Ascend/nnal/atb/set_env.sh
+
+去掉`--no-gradient-accumulation-fusion`即可调用npu_matmul_add_fp32融合算子。
+
+## 使用效果
+在显存未打满情况下,开启融合算子,llama2_70B_4k_tp2_pp2_vpp1_dp2性能可提升1.5%。
+
+## 使用限制
+1.npu_matmul_add_fp32暂不支持mfu统计
+
+2.融合算子与小算子之间存在精度差异,精度差异的原因是:
+小算子dtype变化过程:`bf16*bf16=fp32->bf16->fp32+fp32=fp32`
+融合算子dtype变化过程:`bf16*bf16=fp32+fp32=fp32`
+差异点在于融合算子做了升精度的操作,故导致精度与小算子存在差异
+
diff --git a/model/train/yoco_moe/docs/features/ops_flops_cal.md b/model/train/yoco_moe/docs/features/ops_flops_cal.md
new file mode 100644
index 0000000000000000000000000000000000000000..c06328b540985982b952244d20937ed287636221
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/ops_flops_cal.md
@@ -0,0 +1,31 @@
+# TFLOPS计算
+
+## 问题分析
+
+当前大模型在计算MFU时,依靠框架理论打印值TFLOPS/有效算力得到,但是理论值计算适用于一般模型,如果针对模型结构进行变动,将不再适用,同时HFU的计算目前需要手动计算。
+
+## 解决方案
+
+提供接口可以统计所有涉及MatMul计算的算子的浮点计算次数,同时能统计到模型正反向训练以及重计算的总浮点计算次数。
+
+### 解决思路:
+
+目前支持的涉及MatMul的算子有MatMul、BatchMatMul、FlashAttention、MC2相关融合算子、coc相关融合算子、GEMM相关融合算子、matmul_add_fp32融合算子。
+
+## 使用方法
+
+对于开启此功能,设置`--op-cal-tflops`即可调用。
+
+## 使用效果
+
+通过打印值`actual throughput per NPU (TFLOP/s/NPU)`和`actual throughput per NPU with recompute (TFLOP/s/NPU)`可以方便计算MFU
+和HFU。
+
+## 注意事项
+
+(1)由于此功能统计的是各卡的TFOPS信息,在CP/EP/PP场景下,各卡计算量不同,因此在最后需要汇总各卡的信息进行平均,额外增加一个all_reduce通信。
+
+(2)使用此功能由于会增加一个额外通信以及计算各算子的浮点计算次数,可能影响性能。
+
+(3)由于在Ring Attention长序列并行方案中,在causal场景下,由于算法优化缘故,会有部分计算减少,因此会导致理论值和实际统计值不符合的现象,理论计算上FA
+的计算减少值为`(CP-1)/2CP`。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/pipeline-experts.md b/model/train/yoco_moe/docs/features/pipeline-experts.md
new file mode 100644
index 0000000000000000000000000000000000000000..5d9883623aae0eaab2352aa13f043079bdc895c4
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/pipeline-experts.md
@@ -0,0 +1,80 @@
+# MLP通信隐藏
+
+## 问题分析
+
+大模型训练过程中,通信和计算往往存在依赖关系,这样的串行执行顺序会造成计算和通信流存在一定程度的空闲等待时间,导致执行效率较低。
+
+## 解决方案
+
+对通信和计算算子做更为细粒度的切分,保证细粒度间的计算和通信任务不存在依赖关系,是创造可并行执行任务的前提。
+
+再对算子调度/执行顺序进行编排,实现计算和通信的并行执行,在计算过程能掩盖中间部分的通信过程。
+
+
+
+### a. MLP通信隐藏:`--use-pipe-experts`
+开启后,将对每个experts进行细粒度切分,对前向和反向的执行顺序进行编排,实现通信和计算之间的掩盖,提高效率。
+
+### b. 多流水线:`--pipe-experts-multi-stream`
+需要在打开`--use-pipe-experts`的基础上开启使用。开启后,能够保证ep的alltoall通信和tp的allgather/reduce-scatter之间串行执行,避免集合通信出现链路冲突。
+
+### c. 多副本:`--pipe-experts-multi-data N`
+需要在打开`--use-pipe-experts`的基础上开启使用,`N`表示使用N份副本。开启后,能将输入数据切分为多个副本,将不同副本间的计算和通信类比为多个experts的计算和通信。
+
+## 使用场景
+
+在 local_experts 大于等于 2 时,可以考虑使用专家间的计算通信流水实现通信隐藏的目的。
+
+在 local_experts 等于 1 时,即 ep = num_expert 时,可以考虑使用多副本间的计算通信流水实现通信隐藏的目的。
+
+可开启多流水线`--pipe-experts-multi-stream`规避集合通信上出现的链路冲突。
+
+## 使用方法
+
+需要在保证开启了`--moe-model-type deepspeed_moe`的前提下,开启`--use-pipe-experts`才会生效。
+进一步,可以在`--use-pipe-experts`的前提下,单独或同时设置`--pipe-experts-multi-stream`和`--pipe-experts-multi-data N`来叠加使用“多流水线”和“多副本”的特性。
+
+## 使用效果
+
+使用该特性可以提升性能。
+
+8机, world_size = 64, sequence_len = 128k, num_layers = 4, recompute_granularity = full, hidden_size = 12288, moe_router_topk = 2, ep = 4, tp = 8, dp = 2, cp = 4, pp = 1, sp = True
+
+场景1:num_experts = 4 (num_local_experts = 1)
+
+| pipe-experts | multi-stream | multi-data | 平均TFLOPs | 提升幅度 |
+|:------------:|:------------:|:---------------:|:--------:|:------:|
+| 关 | 关 | 关 = 1 (Default) | 104.88 | / |
+| 开 | 关 | 开 = 2 | 108.01 | 2.99% |
+| 开 | 关 | 开 = 4 | 110.96 | 5.80% |
+| 开 | 开 | 开 = 2 | 110.21 | 5.08% |
+| 开 | 开 | 开 = 4 | 111.43 | 6.25%★ |
+
+场景2:num_experts = 16 (num_local_experts = 4)
+
+| pipe-experts | multi-stream | multi-data | 平均TFLOPs | 提升幅度 |
+|:------------:|:------------:|:---------------:|:--------:|:------:|
+| 关 | 关 | 关 = 1 (Default) | 103.15 | / |
+| 开 | 关 | 关 = 1 (Default) | 109.27 | 5.93% |
+| 开 | 关 | 开 = 2 | 109.20 | 5.86% |
+| 开 | 开 | 关 = 1 (Default) | 109.49 | 6.14%★ |
+| 开 | 开 | 开 = 2 | 108.32 | 5.01% |
+
+场景3:num_experts = 8 (num_local_experts = 2)
+
+| pipe-experts | multi-stream | multi-data | 平均TFLOPs | 提升幅度 |
+|:------------:|:------------:|:---------------:|:--------:|:-------:|
+| 关 | 关 | 关 = 1 (Default) | 103.98 | / |
+| 开 | 开 | 关 = 1 (Default) | 109.32 | 5.13%★ |
+| 开 | 开 | 开 = 2 | 108.38 | 4.23% |
+
+## 注意事项
+1、在开启`--pipe-experts-multi-data N`时,若`N`过大,导致输入数据切分过细,会引入多余的 cast 和 add 算子,导致额外的开销,引起性能恶化。
+2、目前 8 机推荐在 num_local_experts = 1 时开启`--pipe-experts-multi-data 4`来获得最佳性能,在 num_local_experts > 1
+时,不推荐开启`--pipe-experts-multi-data N`。
+3、单机,当 num_local_experts 为 1 或 2 时,`N`推荐设置为 2,当 num_local_experts 为 4 及以上时,不推荐开启多副本。
+4、`--pipe-experts-multi-data N`特性主要被用来提供 num_local_experts 为 1 时无法进行 experts 间的细粒度切分的替代方案。
+5、虽然兼容 num_local_experts > 1 的场景,开启后可以进一步提高计算通信掩盖比例,但会新引入 cast 和 add
+算子操作,当掩盖的收益不足以抵消新引入算子的拖慢时,就会导致性能恶化。
+6、在未开启SP`--sequence-parallel`时,无法开启多流水线`--pipe-experts-multi-stream`。
+7、未适配MoE token dropless特性。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/pipeline-parallel.md b/model/train/yoco_moe/docs/features/pipeline-parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..f46f7164d41018c612caaf76cd766a4118144742
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/pipeline-parallel.md
@@ -0,0 +1,30 @@
+# 流水线并行
+
+## 问题分析
+
+在大模型时代,单一设备无法存储整个模型。模型并行可以在训练过程中将模型加载到多个设备上。在朴素的模型并行中,设备需要等待前一阶段的计算结果,导致计算资源的严重利用率不足。同时,设备需要储存计算的中间结果,存储开销大。
+
+## 解决方案
+
+采用流水线的思想,减少不同机器之间等待的时间。同时尽可能地缩短前向计算与反向计算之间的距离,以减少内存消耗
+
+### 解决思路:
+
+* 将整个网络分阶段(stage),不同阶段在不同的设备上,前后阶段流水分批工作,通过一种“接力”的方式并行。
+* 开始训练时,会先进行预热。预热完成后,每进行一个前向运算,就安排一个后向运算。最后进行冷却,完成剩余阶段。如下图所示
+
+
+
+[原文链接](https://arxiv.org/pdf/1806.03377)
+## 使用场景
+
+在训练模型时,为了降低单个设备的存储开销,提升计算效率,将模型加载到多卡来进行流水线并行。
+
+## 使用方法
+
+设置`--pipeline_model_parallel_size`,默认为1,根据用户需求配置。
+
+## 使用效果
+
+提升计算效率,减少内存消耗
+
diff --git a/model/train/yoco_moe/docs/features/recomputation.md b/model/train/yoco_moe/docs/features/recomputation.md
new file mode 100644
index 0000000000000000000000000000000000000000..eb0b5c282e33b6aa6132f7b1c03a0cd589ffec3d
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/recomputation.md
@@ -0,0 +1,32 @@
+# Megatron 重计算
+## 问题分析
+
+大模型训练过程中,通常要求保留前向计算的激活值用于后续的反向梯度计算,并且需要保存结果的数量会随着模型层数的增加线性增加,大大增加芯片的内存压力。
+
+## 解决思路
+
+在前向过程和loss计算时直接删除激活值,反向梯度计算需要用时再重新计算一遍激活值,从而有效缩短激活值的生命周期,缓解内存压力。
+
+## 使用场景
+主要用于训练场景,重计算分为:选择性重计算和完全重计算。
+
+选择性重计算(推荐使用):只重计算transformer中的core_attention部分,将占用较少内存存储空间且重计算开销较高的激活保留在内存中,并将占用较多内存存储空间但重新计算开销相对较低的激活重新计算。
+
+完全重计算:对于内存非常有限场景,仅将输入保存,重新计算所有激活值。
+
+## 使用方法
+
+选择性重计算:脚本中添加`--recompute-activations`开启选择性重计算。
+
+完全重计算:脚本中配置`--recompute-granularity full`开启完全重计算,开启完全重计算时使用`--recompute-method uniform/block` 确认具体重计算方式。
+
+`--recompute-method uniform`:将Transformer层均匀划分组(每组大小`--recompute-num-layers`),按组存储输入和激活值。
+
+`--recompute-method block`:将前`--recompute-num-layers`个transformer层重计算,剩余层不进行重计算。
+
+同时配置`--recompute-activations` 、`--recompute-granularity full`生效选择性重计算。
+
+当脚本配置了`--recompute-method block`、`--recompute-granularity full`、`--num-layers-per-virtual-pipeline-stage N`参数时,用户可以通过`--recompute-num-layers N`参数来配置每个vpp stage做多少层重计算,参数`--enable-recompute-layers-per-pp-rank`可用于修改此情况下`--recompute-num-layers N`参数的语义,新的语义表示无视vpp,按每个pp stage来配置重计算层数。
+
+## 使用影响
+显存开销降低、性能降低。
diff --git a/model/train/yoco_moe/docs/features/recompute_independent_pipelining.md b/model/train/yoco_moe/docs/features/recompute_independent_pipelining.md
new file mode 100644
index 0000000000000000000000000000000000000000..0f2511a6047da7b90a8deaae480c0a7eef095f39
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/recompute_independent_pipelining.md
@@ -0,0 +1,33 @@
+# 重计算流水线独立调度
+## 问题分析
+
+在目前的流水线调度中,重计算由反向计算触发,与反向计算绑定在一起调度,意味着重计算需要等待下一个stage返回梯度才可以开始计算。然而重计算并不需要用到反向计算的梯度,这导致bubble的增多和性能的下降。
+
+## 解决方案
+
+为了将重计算和反向计算独立调度,需要将重计算的调度修改为由调度器主动触发,并修改调度器,将重计算作为一个调度单元加入到调度器中,这使我们获得了自由地插入或去除部分重计算的能力,进而可以在内存和性能方面做出优化。
+
+### 解决思路
+通过torch的saved_tensors_hooks实现一种新的重计算方法,在反向计算前合适的时机主动触发或者直接去除部分重计算,从而实现对内存或性能的优化。
+
+## 使用场景
+
+在pipelining_with_interleaving调度中,若用户未开启重计算,则可以利用bubble主动插入重计算,以极小的性能代价换取内存峰值的降低,将需要保留激活值的前向计算块的个数减少到pp * vp。
+ 
+
+在pipelining_with_interleaving调度中,若用户已开启重计算,则可以通过解除重计算与后一个stage的反向计算间的依赖关系从而提前重计算,以及去除模型最后一层的重计算,实现计算性能的提升。
+ 
+
+## 使用方法
+
+脚本中添加: --recompute-in-bubble 可开启利用bubble进行重计算功能,实现内存节省。
+使用条件:必须开启虚拟流水并行特性,使用此功能前不能开启重计算,recompute_num_layers参数需为None或0。
+
+脚本中添加: --recompute-in-advance 可开启提前重计算以及去除不必要重计算功能,实现训练性能提升。
+使用条件:必须开启虚拟流水并行特性,使用此功能前需要开启重计算,且不支持recompute_method为uniform,recompute_num_layers不能为None或0。
+
+#### 注意:
+
+1. --recompute-in-bubble 特性暂不兼容完全重计算uniform、完全重计算block、选择重计算、自适应选择重计算特性,不兼容moe场景下的--moe-adaptive-recompute-activation、--moe-layer-recompute特性。
+2. --recompute-in-advance 特性暂不兼容完全重计算uniform、选择重计算、自适应选择重计算特性。
+3. --recompute-in-bubble和--recompute-in-advance两者不可同时开启。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/reuse-fp32-param.md b/model/train/yoco_moe/docs/features/reuse-fp32-param.md
new file mode 100644
index 0000000000000000000000000000000000000000..c5ae38fb989f68f48837e68dca2e8b773bc26851
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/reuse-fp32-param.md
@@ -0,0 +1,46 @@
+# 参数副本复用
+
+## 问题分析
+
+当前所有大模型训练客户场景都采用混精训练,导致在训练过程中持续保存了计算权重与状态权重。实际上,两者所使用的生命周期并没有重叠,故可通过数值变换的方式**等价**去除冗余。
+
+
+## 解决方案
+
+基于大模型混合精度训练中BF16计算参数(负责前后向计算)及FP32参数副本(负责参数更新)不需同时存在的特点,和两者在数值上的对应关系,设计内存共用算法。
+
+
+
+### 解决思路
+
+具体算法步骤如下:
+1. FP32 = BF16 + Residual;
+2. 前向计算开始前将FP32转换为BF16并保存Residual;
+3. 优化器更新前基于BF16和Residual恢复FP32参数并进行更新;
+4. 使用int32加减法来等价模拟原始逻辑中FP32<->BF16的相互转换(IEEE745向偶数舍入)。
+
+
+
+参数副本复用流程如下图所示:
+ 
+
+数值变化的详细逻辑如下图所示:
+ 
+
+## 使用场景
+
+1. 该特性主要用于使用BF16的训练场景。
+
+## 使用方法
+
+设置`--reuse-fp32-param`,即可调用该算法。
+
+## 使用效果
+
+1. 对于Float16OptimizerWithFloat16Params,整体能够节省`sizeof(bfloat16)*模型参数量`的静态内存,性能劣化在多个模型上测试小于1%。
+2. 对于开启分布式优化器的训练,整体能够节省`sizeof(bfloat16)*模型参数量 / DP`的静态内存,性能劣化在多个模型上测试小于1%。
+
+## 注意事项
+
+1. 使用legacy model训练时,`reuse_fp32_param`暂不支持和`--overlap-param-gather`一起使用。
+2. 使用fused_ema_adamw优化器时,不支持同时开启`reuse_fp32_param`。
diff --git a/model/train/yoco_moe/docs/features/ring-attention-context-parallel.md b/model/train/yoco_moe/docs/features/ring-attention-context-parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..becbb1eb4e7e1505e6b00a55aa280c2937b2d9dc
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/ring-attention-context-parallel.md
@@ -0,0 +1,47 @@
+# Ring Attention长序列并行
+
+## 问题分析
+
+从生成性AI到科研模型,长序列训练正在变得非常重要。 在生成性AI领域,会话式AI、长文档摘要和视频生成等任务都需要在空间和时间层面对长上下文进行推理。 同样,章节和书籍级别的摘要(数万甚至数十万字)在会话式AI和摘要任务中也非常重要。现有的数据、张量和流水线等并行方法无法在序列维度进行切分。当序列维度(S)增长时,训练内存开销会以 $O$($S^2$) 的速度增长。因此需要针对长序列场景进行特定的优化解决长训练场景的训练需求。
+
+## 解决方案
+
+支持Ring Attention长序列并行方案,以此解决序列维度扩展问题。具体细节参见原文:
+> Ring Attention with Blockwise Transformers for Near-Infinite Context (https://arxiv.org/pdf/2310.01889)
+
+### 解决思路:
+
+Ring Attention借鉴了分块Softmax原理,在不需要获取整个序列的完整矩阵情况下进行分块attention计算。因此作者提出以分块方式执行自注意力和前馈网络计算,跨多个设备分布序列维度。具体地,该方法在进程之间构建注意力计算块的环状通信结构(Ring),每个进程具有一个切分后的本地QKV块。在计算完本地的attention后,通过向后发送和向前获取KV块,遍历进程设备环,以逐块的方式进行注意力和前馈网络计算。同时,本地的attention计算和KV块的通信理想情况下可以互相掩盖,从而消除了额外引入的通信开销。另外该方案在计算attention的过程中全程不需要数据拼接,支持的序列长度理论上可以无限拓展。
+
+## 使用场景
+
+当使用GPT类模型进行训练,同时数据进MoE层时实际序列长度8K以上。
+
+不同于Ulysses方案,该方案不需要确保head_size被cp_size整除。
+
+可兼容FlashAttention,目前已默认开启FlashAttention。
+
+如果想要使得计算和通信可以互相掩盖,理论上需要确保每个计算块分到的序列长度$c \geq F/B$。其中F是每个device的FLOPS,B是每个device间的带宽。具体推导过程参见原文。在实践中,需要确保每个计算块分到的序列长度足够大,才能较好掩盖。
+
+
+## 使用方法
+
+| 重要参数 | 参数说明 |
+|-------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| --context-parallel-size [int] | 开启CP对应的数量,默认为1,根据用户需求配置。 |
+| --seq-length [int] | 输入序列的长度。 |
+| --use-cp-send-recv-overlap | 建议开启,开启后支持send receive overlap功能。 |
+| --attention-mask-type [general/causal] | 可选,设置Mask计算类型,默认是causal(倒三角)Mask计算,设置general代表全量计算。 |
+| --context-parallel-algo megatron_cp_algo | 长序列并行算法选项,默认项为`ulysses_cp_algo`,当设置为`megatron_cp_algo`时开启Ring Attention。 |
+| --megatron-cp-in-bnsd | 开启后,FA使用BNSD计算。 |
+
+## 使用效果
+
+利用多个计算设备对输入序列进行并行切分,降低单设备的内存消耗,相比不开启序列并行单步耗时增加,相比重计算计算效率提升。
+
+
+## 注意事项:
+
+1. 开启Context Parallel时需要同时开启Flash Attention特性,否则特性不支持。
+2. 在使用GPT类模型进行训练的场景下,建议attention-mask-type设置为causal。
+3. 在8k的序列长度情况下,由于计算的时间缩短,cp功能分割之后的send receive的时间反而会长于计算时间,造成性能的下降,所以建议配置seq-length / context-parallel-size> 8k以获取最佳效果。具体公式参考:S/(Talpha) >= 1/(Wbeta),其中,S=seq-length / context-parallel-size, T表示芯片的理论算力,alpha表示计算效率,W表示理论通信带宽,beta表示带宽利用率。
diff --git a/model/train/yoco_moe/docs/features/rms_norm.md b/model/train/yoco_moe/docs/features/rms_norm.md
new file mode 100644
index 0000000000000000000000000000000000000000..9612bdc350df8f2252f9612fd1bfc4f465b72398
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/rms_norm.md
@@ -0,0 +1,15 @@
+# rms_norm融合优化
+## 问题分析
+rms_norm常见于LLaMA、LLaMA2、Baichuan等LLM模型中用于归一化,由于torch侧没有提供rms_norm算子的接口,因此在模型中通常是以自定义的形式出现,这种形式的执行效率相对较低。
+
+## 解决方法
+MindSpeed对将rms_norm操作合并成一个算子,减少数据传输和临时存储。算子接口见[link](../ops/rms_norm.md)。
+
+## 使用场景
+模型使用rms_norm作为归一化方式,脚本中设置了`--normalization RMSNorm`。
+
+## 使用方法
+设置`--use-fused-rmsnorm`即可调用rms_norm融合算子。mcore分支下仅支持使能该融合算子。
+
+## 使用效果
+开启融合算子可以节省内存,提升性能。
diff --git a/model/train/yoco_moe/docs/features/rotary-embedding.md b/model/train/yoco_moe/docs/features/rotary-embedding.md
new file mode 100644
index 0000000000000000000000000000000000000000..f3afd4d11c73ec9d532dca72813daea998f50c87
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/rotary-embedding.md
@@ -0,0 +1,25 @@
+# Rotary Postion Embedding 融合优化
+
+## 问题分析
+
+Rotary Position Embedding(RoPE)是一种大模型文本位置信息编码(Position Embedding)的解决方案。RoPE通过绝对位置编码的形式实现了相对位置信息的注入,融合了绝对和相对位置编码的优点,同时具备较好的长度外推性。目前RoPE方案已经被较多的大模型采用,例如LLaMA和GLM。
+
+然而,目前torch并没有针对RoPE做特定的实现和优化,在模型侧通常是通过自定义的方式实现,且Rotary Embedding的计算方式较为复杂,实现方式的计算和内存开销需要优化。
+
+## 解决方案
+`torch_npu`侧将Rotary Embedding操作合并成一个算子,减少数据传输和临时储存,优化模型训练性能。MindSpeed调用`torch_npu`侧接口实现算子融合。
+
+## 使用场景
+
+模型侧使用了Rotary Embedding作为Position Embedding解决方案。
+
+## 使用方法
+
+首先确保`--position-embedding-type`选项设置为`rope`。
+
+同时开启`--use-fused-rotary-pos-emb`选项,以启用融合算子。
+
+## 使用效果
+
+使用融合算子可以提升训练性能。
+
diff --git a/model/train/yoco_moe/docs/features/sequence-parallel.md b/model/train/yoco_moe/docs/features/sequence-parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..84c7f294ddd02bff4973e2de37c21ed51ed834b4
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/sequence-parallel.md
@@ -0,0 +1,33 @@
+# 序列并行
+
+## 问题分析
+
+张量模型并行可以降低显存占用,加快训练速度,但是它要求将模型各层划分为独立的、可管理的块,所以不适用于 LayerNorm 和 Dropout 等操作。虽然 LayerNorm 和 Dropout 等操作的计算成本很低,但它们确实需要大量冗余内存来存储激活。为了分摊张量并行中无法切分的显存和计算,引入了序列并行的方法。
+
+## 解决方案
+
+在张量模型并行的基础上,进一步对 LayerNorm 和 Dropout 模块的序列维度进行切分。
+
+### 解决思路:
+
+将 LayerNorm 以及 Dropout 等操作的输入按序列维度进行了切分,使得各个设备上面只需要做一部分的 Dropout 和 LayerNorm 等操作即可。
+
+为了方便理解,以下图为例:假设输入$X$的大小为$ s \times b \times h $,按照序列维度切分$X=[X_1^s,X_2^s]$,经过LayerNorm操作后的结果为$Y=[Y_1^s,Y_2^s]$,随后进行张量模型并行。
+
+
+
+[原文链接](https://arxiv.org/pdf/2205.05198)
+
+## 使用场景
+
+使用训练模型时,将模型加载到多卡,使用张量模型并行后显存依旧占用过高或超出了处理器显存限制,或者训练时间过长,可以开启序列并行来降低显存占用,加快训练速度。
+
+## 使用方法
+
+首先确保训练参数中加入`--tensor-model-parallel-size N`,设置张量模型并行。
+
+同时添加`--sequence-parallel`,开启序列并行。
+
+## 使用效果
+
+利用多个设备,降低显存开销,加快训练速度。
diff --git a/model/train/yoco_moe/docs/features/shared-experts.md b/model/train/yoco_moe/docs/features/shared-experts.md
new file mode 100644
index 0000000000000000000000000000000000000000..91c510fabc1f5c38e7f4c942d365c8efe066e4cd
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/shared-experts.md
@@ -0,0 +1,28 @@
+# 共享专家特性
+
+## 方案介绍:
+
+随着混合专家模型MoE的演进,产生了路由专家和共享专家的概念。针对路由专家,输入数据会经过路由模块选择概率较高的专家进行计算;而对于共享专家,输入数据无需经过路由模块计算,所有数据都会经过共享专家计算。路由专家和共享专家的计算结果相加后作为MoE模块最终的计算结果。
+
+通过将共享专家和路由专家结合,MOE模型能够在不同的输入情况下既关注到输入数据的共性也能关注到输入数据的差异性,从而提高模型的泛化能力。
+
+共享专家如下图c所示(参考论文:https://arxiv.org/pdf/2401.06066 ):
+
+
+## 使用场景
+
+MoE场景下使用:`--moe-model-type megatron_moe`
+
+## 使用方法
+
+共享专家相关命令和参数说明:
+
+| 命令参数 | 参数说明 |
+|--------------------------|------------------------|
+| `--n-shared-experts [int]` | 共享专家数量 |
+
+## 注意事项
+
+1. 开启共享专家需要在mcore模式下,即没有设置`--use-legacy-models`
+
+2. 共享专家中间隐藏层大小的配置命令与路由专家相同:`--ffn-hidden-size [int]`
diff --git a/model/train/yoco_moe/docs/features/smart_swap.md b/model/train/yoco_moe/docs/features/smart_swap.md
new file mode 100644
index 0000000000000000000000000000000000000000..7142eb6f7f0cca94148752ea38eef6a1de8bdbed
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/smart_swap.md
@@ -0,0 +1,39 @@
+# SmartSwap
+
+## 问题分析
+
+在用户训练过程中产生的OOM问题,现有的内存方案主要为重计算和Swap两个方法。重计算会增加计算开销,而Swap路线需要用户自己编写和控制异步换入换出时机和内存管理,增加较多的使用成本。
+
+## 解决方案
+
+为了在最大限度地利用计算设备显存的同时,提高模型训练的性能,我们支持通过自适应迭代生成Swap策略,这一特性称为SmartSwap。
+
+此功能通过数据采样,策略生成,策略执行等流程的循环迭代,选择有限次数验证下的最优策略。
+在迭代中分为3个阶段。
+
+- WarmUp阶段,仅执行数据采样。采集Tensor生命周期信息供后续分析。此时OOM时会通过覆盖底层的内存异常,使得模型能够继续运行。
+- SearchPolicy阶段,执行数据采样和策略执行。 在策略生成中,包括候选内存过滤,内存策略生成,内存模拟排布等步骤。
+- Stable阶段,仅执行策略执行。在策略执行中,通过多流异步执行内存Swap,掩盖对计算流的耗时影响。
+
+
+
+## 使用场景
+
+1. OOM场景:当前训练配置下,出现OOM报错;可开启此功能,将OOM报错拦截,自动生成Swap策略,使训练在可用最大显存内运行。
+2. 非OOM场景:当前训练配置下,未出现OOM报错;可开启此功能,根据配置文件中的减少显存值,自动生成Swap策略,使训练在指定显存内运行。
+3. 重计算的替代场景:减少模型代码中的重计算生效范围,节省重计算过程。
+
+## 使用方法
+
+1. 在训练脚本中添加此功能的使能参数:`--smart-swap`。
+2. (可选)修改此功能的配置文件`mindspeed/core/memory/smart_swap/swap_policy_config.py`进行调试。
+
+## 使用效果
+
+1. 通过减少TP和PP数,获得性能收益;例如在llama2(8p,pp1,seqlen 8K,layer 32),将tp8改为tp1,性能收益25%;
+2. 通过关闭或者部分关闭全重计算,获得性能收益;例如在llama2(8p,pp1,seqlen 16K,layer 40),将全重计算关闭,性能收益28%;
+
+## 注意事项
+
+1. SmartSwap适配静态序列场景;暂未适配动态场景,例如MOE类场景。
+2. SmartSwap将占用Host内存,例如单机8卡,若每卡均换出`30 GB`到Host,则单机至少需要Host内存`8*30=240 GB`。
diff --git a/model/train/yoco_moe/docs/features/swap_attention.md b/model/train/yoco_moe/docs/features/swap_attention.md
new file mode 100644
index 0000000000000000000000000000000000000000..72003deb93595c601cc41e4493a5a601ff10698f
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/swap_attention.md
@@ -0,0 +1,51 @@
+# swap-attention
+
+## 问题分析
+
+大模型训练过程中,使用重计算功能可以大幅度减少内存,但会增加训练过程的计算时长,导致执行效率较低。
+
+## 解决方案
+
+新增swap-attention功能,利用设备内存和CPU内存来存放激活值,在梯度反传的同时从CPU内存预取激活值来减少重计算,充分利用H2D高带宽的优势以网补存、以网强算,提升MFU,加速大模型的训练。
+
+
+
+## 使用场景
+
+### a. 优化性能:
+
+在需要开启全重计算的场景下,可以通过开启`--swap-attention`和`--recompute-num-layers [int]`替换全重计算,以达到提升性能的目的。
+
+### b. 内存节省:
+
+对于不需要重计算的场景,只开启`--swap-attention`,可以在几乎不损耗性能的情况下,节省内存,以支持更大的模型的配置。
+
+
+## 使用方法
+
+需要添加参数`--swap-attention`。使用前提是开启flash attention融合算子。
+
+可选参数`--swap-modules`:参数类型为string,默认值为"input_norm,self_attention,post_attention_norm",可根据模型自行配置module,在mcore场景下默认仅预取self_attention module。
+
+### a. 仅开启预取功能:`--swap-attention`
+
+开启后,将对每一层的attention层的激活值进行预取,提高计算效率。
+
+
+
+### b. 开启预取功能并且指定重计算层数:`--swap-attention`和`--recompute-num-layers [int]`
+
+开启后,将对每一层的attention层的激活值进行预取,同时,对前[int]层的全连接层进行重计算。
+
+
+
+## 使用效果
+
+与完全重计算相比 ,有性能收益;
+与不重计算相比,有内存收益;
+
+## 注意事项:
+
+1. `--recompute-num-layers [int]`中的[int]层数指的是每一个pp stage的层数。[int]的取值应该小于等于num-layers/pipeline-model-parallel-size.
+2. 暂不兼容自适应选择重计算特性.
+3. 若出现性能严重劣化,可能是跨NUMA内存访问引起,可尝试通过进程绑核缓解,实现方法可参考[mstt绑核工具](https://gitee.com/ascend/mstt/tree/master/profiler/affinity_cpu_bind)
diff --git a/model/train/yoco_moe/docs/features/swiglu.md b/model/train/yoco_moe/docs/features/swiglu.md
new file mode 100644
index 0000000000000000000000000000000000000000..8653eb7d7536c41d01da837c507d6808640c303d
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/swiglu.md
@@ -0,0 +1,17 @@
+# swiglu融合优化
+
+## 问题分析
+swiglu常见于LLaMA、LLaMA2、Baichuan等大模型中的激活层,由于torch侧没有提供swiglu算子的接口,因此在模型中通常是以小算子的形式出现,这种形式的执行效率相对较低。
+
+## 解决方法
+MindSpeed将swiglu操作合并成一个融合算子,减少数据传输和临时存储。算子接口见[link](../ops/swiglu.md)。
+
+## 使用场景
+模型使用swiglu作为MLP层激活函数,脚本中设置了`--swiglu`。
+
+## 使用方法
+设置`--use-fused-swiglu`即可调用swiglu融合算子。mcore分支下仅支持使能该融合算子。
+
+## 使用效果
+开启融合算子可以节省内存,提升性能。
+
diff --git a/model/train/yoco_moe/docs/features/tensor-parallel-2d.md b/model/train/yoco_moe/docs/features/tensor-parallel-2d.md
new file mode 100644
index 0000000000000000000000000000000000000000..f88473c0152758b148d4e9cd407114f6d962f4c0
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/tensor-parallel-2d.md
@@ -0,0 +1,156 @@
+# 高维张量并行
+
+## 问题分析
+
+大模型训练时,张量并行(TP)将模型参数切分到多个设备上以减少其内存的占用,在训练过程中为了更新参数梯度信息等,需要引入allreduce通信。当集群规模较大时,如果设置TP域很大时,其通信开销会变得很大,使得训练效率降低。
+
+## 解决方案
+
+为了提高大规模TP域通信效率,采用高维张量并行,其将激活值和参数同时切分到多个计算设备上,相对1D-TP降低了通信域、减少通信次数,从而减少通信时间,提升模型训练的性能。
+
+### 解决思路
+
+#### 2D张量并行策略
+
+给定TP域大小,通过建立多通信域,在原Megatron(ColumnParallelLinear、RowParallelLinear)增加了一维的切分维度。将原tp通信域进行分解为两个子通信域tp_x和tp_y,需要满足`tp = tp_x * tp_y`。以MLP层为例,其实现过程如下:
+
+
+
+#### 分布式normalization
+
+在transformer网络中,normalization会将每一层神经元的输入都转成均值方差都一样的,加快其收敛。在MLP和attention层分别进行2D张量并行时,其输入和输出都分别在first-dim和last-dim做了tp_x和tp_y的切分,如果继续使用原LayerNorm或者RMSNorm需要先将input进行沿first-dim进行all-gather(x)和沿last-dim进行all-gather(y)操作,才能保证input数据的完整性。为了提升这部分的性能,采用了分布式normalization。其处理流程如下:
+
+##### **步骤1:计算输入的总和**
+
+首先,计算输入张量$\mathbf{x}$ 在最后一个维度上的总和:
+
+$$
+e_x = \sum_{i=1}^{H} x_i
+\
+$$
+
+##### **步骤2:分布式归约操作(All-Reduce)**
+
+将步骤1中的总和 $e_x$ 在所有tp_y通信域进程中进行归约(求和),确保每个进程都拥有其通信域全局总和:
+$$
+\
+e_x^{\text{global}} = \text{AllReduce}\left( e_x \right) = \sum_{p=1}^{P} \sum_{i=1}^{H} x_i^{(p)}
+\
+$$
+
+其中:
+- $P$ 是分布式进程的数量。
+- $x_i^{(p)}$ 表示第 $p$ 个进程中第 $i$ 个元素的值。
+
+##### **步骤3:计算输入元素的平方和**
+
+接下来,计算输入张量每个元素的平方和:
+
+$$
+s_x = \sum_{i=1}^{H} x_i^2
+$$
+
+##### **步骤4:分布式归约操作(All-Reduce)**
+
+将步骤3中的平方和 $s_x$ 在所有tp_y通信域进程中进行归约(求和),确保每个进程都拥有其通信域全局平方和:
+
+$$
+s_x^{\text{global}} = \text{AllReduce}\left( s_x \right) = \sum_{p=1}^{P} \sum_{i=1}^{H} \left( x_i^{(p)} \right)^2
+$$
+
+##### **步骤5:中心化输入数据**
+
+将输入数据 $\mathbf{x}$ 中心化,即减去平均值。平均值 $\mu$ 计算如下:
+
+$$
+\mu = \frac{e_x^{\text{global}}}{H}
+$$
+
+然后,中心化输入:
+
+$$
+x'_i = x_i - \mu \quad \forall i \in \{1, 2, \dots, H\}
+$$
+
+##### **步骤6:计算总和的平方**
+
+计算全局总和的平方:
+
+$$
+e_x'^2 = \left( e_x^{\text{global}} \right)^2
+$$
+
+##### **步骤7:计算归一化因子**
+
+计算归一化因子 $\gamma$,用于标准化输入数据。公式如下:
+
+$$
+\gamma = \frac{1}{\sqrt{ \left( \frac{s_x^{\text{global}}}{H} \right) - e_x'^2 + \epsilon }}
+$$
+
+这里:
+- $\frac{s_x^{\text{global}}}{H}$ 是全局平方和的平均值。
+- $e_x'^2$ 是全局总和的平方。
+- $\epsilon$ 是一个小常数,防止分母为零,增加数值稳定性。
+
+##### **步骤8:标准化输入数据**
+
+将中心化后的输入数据 $\mathbf{x}'$ 与归一化因子 $\gamma$ 相乘,得到标准化后的数据 $\mathbf{\hat{x}}$:
+
+$$
+\hat{x}_i = x'_i \cdot \gamma \quad \forall i \in \{1, 2, \dots, H\}
+$$
+
+##### **步骤9:应用权重和偏置**
+
+最后,将标准化后的数据与权重向量 $\mathbf{W}$ 相乘,并根据是否存在偏置向量 $\mathbf{b}$ 来决定最终输出。
+
+- **如果存在偏置**:
+
+$$
+\text{output}_i = b_i + W_i \cdot \hat{x}_i \quad \forall i \in \{1, 2, \dots, H\}
+$$
+
+- **如果不存在偏置**:
+
+$$
+\text{output}_i = W_i \cdot \hat{x}_i \quad \forall i \in \{1, 2, \dots, H\}
+$$
+
+
+## 使用场景
+
+当TP通信域需要设置较大时,通信效率较低,需要通过分解通信域来提升其通信效率。
+
+## 使用方法
+
+在训练脚本的参数列表中加入 `--tp-2d`,开启2D张量并行,`--tp-x N1`和`--tp-y N2`分别设置其x轴、y轴的切分大小,其中需满足`tp = N1 * N2`(N1 > 1, N2 > 1)。
+
+其他优化参数,用于辅助高维张量并行特性进行通信隐藏,需要开启tp-2d时生效:
+- `--enable-overlap-ag-with-matmul`: 在linear层forward计算时,开启all-gather通信和matmul进行隐藏,以便加速
+- `--enable-overlap-matmul-with-rs`: 在linear层forward计算时,开启matmul计算和reduce-scatter通信进行隐藏,以便加速
+- `--coc-fused-kernel`: 在linear层forward计算时,开启计算通信融合算子,将matmul计算与all-gather、reduce-scatter都进行算子级融合,实现进一步加速(该特性不与前两个特性兼容,依赖ATB加速库)
+- `--enable-backward-overlap-ag-with-matmul`: 在linear层backward计算梯度时,开启all-gather通信和matmul进行隐藏,以便加速(该特性依赖ATB加速库)
+
+上述3个forward计算优化参数`--enable-overlap-ag-with-matmul`、`--enable-overlap-matmul-with-rs`、`--coc-fused-kernel`只能同时开启1个。
+
+注意事项:
+
+Megatron-MOE支持情况如表1所示
+
+**表1** 高维张量并行支持Megatron-MOE
+
+| Dispather | MLP | 支持情况 | 性能优化 |
+|-----------|----------------|------|------|
+| AllGather | Sequential MLP | ✅ | ❌ |
+| AllGather | Grouped MLP | ✅ | ❌ |
+| All2All | Sequential MLP | ❌ | ❌ |
+| All2All | Grouped MLP | ❌ | ❌ |
+
+当前高维张量并行特性不与`--sequence-parallel`、`--use-fused-rmsnorm`等特性相兼容,请根据实际情况调整配置。
+
+## 使用效果
+
+在llama3-405B模型训练时,tp=16情况下,开启2D张量并行,tp_x=8,tp_y=2,相比原Megatron 1D张量并行性能提升5%+。
+开启coc-fused-kernel和enable-backward-overlap-ag-with-matmul通信计算融合优化后,进一步提升性能5%+。
+其他场景下,由于计算效率和通信组的划分差异,需根据tp_x和tp_y实际调优情况进行配置,部分配置不能保证效率提升。
diff --git a/model/train/yoco_moe/docs/features/tensor-parallel.md b/model/train/yoco_moe/docs/features/tensor-parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..5632722d9f508078d8f96d7b4ef95161e092438a
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/tensor-parallel.md
@@ -0,0 +1,35 @@
+# 张量并行
+
+## 问题分析
+
+随着模型越来越庞大,其尺寸远远超出了处理器内存的限制,并且模型训练时间也变得很长。所以需要把一个模型切分,每个计算设备只负责一部分模型的存储与计算。
+
+## 解决方案
+
+张量并行将模型分成多份并存储在多个计算设备上,这样模型的权重和优化器状态可以分布在多个计算设备上,以此来克服单个计算设备显存无法容纳整个大型模型的问题。并且因为各计算设备只需要处理一部分的模型计算,训练速度也得到显著提高。这种分片策略叫做张量并行。
+
+### 解决思路
+
+#### 参数矩阵横切
+
+1.参数矩阵横切策略按照参数矩阵的行来切分模型,该切分策略需要将输入矩阵也进行按列切分。
+2.横切策略前向时,先切分输入矩阵,对应部分的输入矩阵进入对应部分的模型进行前向计算,之后用all-reduce操作来将各部分模型计算结果累加得到最终前向计算结果。
+3.横切策略反向时,可以计算得出最终输出的梯度和各部分模型的输出梯度相等,先将最终输出的梯度传到各部分模型的输出张量,再用all-gather操作将切分后的输入矩阵的梯度拼接得到最初输入矩阵的梯度。
+
+#### 参数矩阵纵切
+
+1.参数矩阵纵切策略按照参数矩阵的列来切分模型,该切分策略输入矩阵无需进行切分。
+2.纵切策略前向时,先将输入矩阵送入各部分模型,各部分模型分别进行前向计算得到输出结果,之后用all-gather操作来将各部分模型输出结果拼接得到最终前向计算结果。
+3.纵切策略反向时,先将最终输出的梯度进行切分并将对应的部分传到对应部分模型的输出张量,之后用all-reduce操作将各部分模型的输入矩阵的梯度累加得到最初输入矩阵的梯度。
+
+## 使用场景
+
+如果用户发现训练显存占用过高或超出了处理器显存限制,或者训练时间过长,可以开启张量并行来降低单设备显存占用,加快训练速度。
+
+## 使用方法
+
+在训练脚本的参数列表中加入 `--tensor-model-parallel-size N`,设置张量并行的size。
+
+## 使用效果
+
+利用多个设备,降低显存占用,加快训练速度。
diff --git a/model/train/yoco_moe/docs/features/ulysses-context-parallel.md b/model/train/yoco_moe/docs/features/ulysses-context-parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..2e9b101efd8d835793a47949ecad4b4a71d92454
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/ulysses-context-parallel.md
@@ -0,0 +1,31 @@
+# Ulysses长序列并行
+
+## 问题分析
+
+从生成式AI到科研模型,长序列训练正在变得非常重要。 在生成式AI领域,会话式AI、长文档摘要和视频生成等任务都需要在空间和时间层面对长上下文进行推理。 同样,章节和书籍级别的摘要(数万甚至数十万字)在会话式AI和摘要任务中也非常重要。现有的数据、张量和流水线等并行方法无法解决序列维度的扩展问题。
+
+## 解决方案
+
+支持 Ulysses长序列并行方案,以此解决序列维度扩展问题。
+
+### 解决思路:
+
+Ulysses 将各个样本在序列维度上分割给参与的计算设备。然后,在 attention 计算之前,它对已分割的查询(Q)、键(K)和值(V)执行 all-to-all 通信操作,以使每个计算设备接收完整的序列,但仅用于注意力头的非重叠子集。这使得参与的计算设备可以并行计算不同的注意力头。最后,Ulysses 还使用另一个 all-to-all 来在注意力头上收集结果,同时重新在序列维度上进行分区。
+
+## 使用场景
+
+num_head 要能被 tp_size*cp_size 整除。
+
+## 使用方法
+
+设置`--context-parallel-size`,默认为1,根据用户需求配置。
+同时设置`--context-parallel-algo ulysses_cp_algo`。
+
+## 使用效果
+
+利用多个计算设备对输入序列进行并行切分,降低单设备的内存消耗,相比不开启序列并行单步耗时增加,相比重计算计算效率提升。
+
+## 鸣谢
+
+1.GitHub项目地址:
+https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/unaligned-ulysses-context-parallel.md b/model/train/yoco_moe/docs/features/unaligned-ulysses-context-parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..9eb16a9e1ba07ef370464c208eab7cd32b0c3a79
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/unaligned-ulysses-context-parallel.md
@@ -0,0 +1,57 @@
+# 非对齐Ulysses长序列并行
+
+## 背景与挑战
+
+随着生成式AI和科研模型领域的发展,长序列训练变得越来越重要。然而,传统的Ulysses设计要求序列长度(sequence length)必须能够被长序列并行大小(Context Parallel size, CP size)整除。这在处理动态或不规则输入时带来了限制,特别是在多模态应用中,输入数据的序列长度可能无法预测且经常变化。因此,需要一种机制来支持这些非对齐情况下的操作,以适应更广泛的应用场景。
+
+
+## 解决方案
+
+为了解决传统Ulysses设计在处理非对齐序列长度时的局限性,“非对齐 Ulysses”机制通过引入一个抽象基类 `GatherSizeCalculator` 来提供计算 gather size 的接口。Gather size 通常指的是经过 (Ulysses 机制中的)all-to-all 通信后,输出张量在 `gather_idx` 维度上的大小。该基类定义了任何具体实现都必须提供的 `calculate()` 方法,用于返回整数形式的 gather size 或者 None。
+
+基于此接口,实现了两种具体的策略:`DefaultGatherSizeCalculator` 和 `DynamicGatherSizeCalculator`。前者默认返回 None,意味着使用对齐的Ulysses长序列并行;后者则根据当前批次的注意力掩码序列长度动态计算 gather size。这种设计使得系统能够灵活应对不同场景的需求,尤其是在多模态领域中处理 sequence length 不能被 CP size 整除的情况时尤为重要。
+
+此外,在 `UlyssesContextAttention` 类中,允许用户注入一个 `gather_size_calculator` 实例,使得系统能够灵活地选择不同的 gather size 计算方法,从而适应不同场景的需求。
+
+## 使用场景
+
+“非对齐 Ulysses”功能适用于以下几种典型场景:
+
+- **多模态学习**:当处理图像、视频、文本等多种类型的数据时,由于不同类型数据的序列长度差异较大,难以统一到固定的CP size。
+- **实时数据分析**:在处理流数据时,数据到达的时间不确定,导致每次处理的序列长度也可能不同。
+- **个性化推荐系统**:用户行为数据的序列长度通常各不相同,这种情况下也需要支持非对齐的操作。
+
+## 使用方法
+
+为了利用“非对齐 Ulysses”功能,用户可以根据业务需求传入基于 `GatherSizeCalculator` 基类的自定义 Calculator,或者直接使用预定义的 `DynamicGatherSizeCalculator`。以下是基本步骤:
+
+1. 启动脚本中配置长序列并行大小大于1`--context-parallel-size [int]`。 同时配置`--context-parallel-algo ulysses_cp_algo`。
+2. 创建一个继承自 `GatherSizeCalculator` 的自定义计算器类,并实现 `calculate()` 方法。在初始化 `UlyssesContextAttention` 对象时,通过构造函数参数传入自定义的 `gather_size_calculator` 实例。
+3. 如果不需要复杂的自定义逻辑,可以直接使用 `DynamicGatherSizeCalculator`,它会自动根据当前批次的注意力掩码序列长度计算 gather size。
+
+```python
+# 示例代码
+import megatron.core.parallel_state as ps
+from mindspeed.core.context_parallel.ulysses_context_parallel import UlyssesContextAttention, GatherSizeCalculator, DynamicGatherSizeCalculator
+from your_library import FlashSelfAttention
+
+# 自定义 GatherSizeCalculator
+class CustomGatherSizeCalculator(GatherSizeCalculator):
+ def calculate(self, *args, **kwargs):
+ # 示例逻辑
+ return kwargs.get("gather_size", None)
+
+
+core_attention = FlashSelfAttention()
+# 根据实际情况,使用预定义DynamicGatherSizeCalculator()或自定义CustomGatherSizeCalculator()
+calculator = DynamicGatherSizeCalculator()
+ulysses_attention = UlyssesContextAttention(core_attention, ps.get_context_parallel_group(),
+ gather_size_calculator=calculator)
+
+```
+*说明*:
+“非对齐 Ulysses”长序列并行暂不兼容Ulysses长序列并行KV缓存优化,即启动脚本设置了--context-parallel-kv-cache-policy为full或者half,系统将自动切换回使用对齐的Ulysses长序列并行机制。
+
+## 使用效果
+
+通过引入“非对齐 Ulysses”,系统提升了对不同输入长度的适应能力。这不仅解决了传统 Ulysses 在处理动态或不规则输入序列时遇到的问题,而且保持了良好的扩展能力。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/unaligned_linear.md b/model/train/yoco_moe/docs/features/unaligned_linear.md
new file mode 100644
index 0000000000000000000000000000000000000000..b3df9e3d9ccb84ac0512aa2b667524d5658369bc
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/unaligned_linear.md
@@ -0,0 +1,35 @@
+# unaligned linear 非对齐线性层
+
+## 背景与挑战
+
+类Megatron-LM框架已成为大模型训练的主流方案之一,TP(张量并行 Tensor Parallism)是大模型训练的基本并行范式,该范式在部分场景仍存在不足,例如要求大模型的注意力头数、序列长度要能整除TP,不满足条件将在参数校验中抛出异常;本特性提供了一种注意力头数、序列长度不能整除TP的的解决方案;
+
+## 解决方案
+
+- **序列长度不能整除TP**:和pad方案(将序列长度pad到TP的整数倍)不同,该方案通过序列分配策略来解决,小于 **(seq_len%tp_size)** 的tp卡分配 **(seq_len//tp_size+1)** 序列长度,其他分配 **(seq_len//tp_size)** 序列长度,例如seq_len=1026,tp_size=4, tp0和tp1分配的序列长度为257,tp2和tp3分配的序列长度为256;
+- **注意力头数不能整除TP**:和上述方案类似,小于 **(num_attention_heads%tp_size)** 的tp卡分配 **(num_attention_heads//tp_size+1)** 个注意力头,其他卡分配 **(num_attention_heads//tp_size)** 注意力头,例如num_attention_heads=25,tp_size=4, tp0分配的注意力头为7个,tp1、tp2和tp3分配的注意力头均为6个;值得注意的是,模型的注意力相关权重TP切分和头数相关,假设hidden_size=3200, qkv_weight大小为(9600,3200)[MHA], dense_weight大小为(3200,3200), tp0的qkv权重大小为(2688,3200),dense权重大小为(3200,896), tp1、tp2和tp3的qkv权重大小为(2304, 3200),dense权重大小为(3200,768);GQA的权重切分方案按num_query_groups比例分配,注意;
+
+## 使用场景
+
+- 序列长度不能整除TP、注意力头数不能整除TP
+
+## 使用方法
+
+在模型参数中添加 --unaligned-linear 参数
+
+**注意事项:**
+- 非对齐的线性层会引起各TP的负载不均衡
+- 该特性不支持mc2、2d张量并行、cp特性(要求TP*CP能被注意力头数整除)等
+- 特殊的模型结构需要特殊适配该特性,当前已适配MHA、GQA结构,暂不支持MOE、MLA等结构
+
+**设置训练脚本参数**
+```shell
+# 开启非对齐线性层
+--unaligned-linear \
+```
+
+## 使用效果
+- **补充功能场景**:补充注意力头数、序列长度不能被TP整除的场景。
+- **潜在性能影响**:各TP处理的注意力头数、序列长度不一致,负载不均衡,建议模型结构设计时考虑该情况。
+
+综上所述,该特性是为了完善TP(张量并行)场景下的限制约束,特性本身会带来负载不均衡的性能影响,所以在模型设计和超参优化时注意这一影响。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/variable_seq_lengths.md b/model/train/yoco_moe/docs/features/variable_seq_lengths.md
new file mode 100644
index 0000000000000000000000000000000000000000..408692620264d24d33bb373527e0cc4dc7e866e3
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/variable_seq_lengths.md
@@ -0,0 +1,40 @@
+# PP支持动态形状
+
+## 背景与挑战
+
+在深度学习模型训练中,尤其是涉及多模态任务时,输入数据的序列长度往往不是固定的。对于采用流水线并行(Pipeline Parallelism, PP)策略的模型,
+处理不同长度的序列通常需要将所有序列调整为统一长度,通过填充或截断来实现。这种做法虽然简化了数据处理和模型设计,但会导致计算资源和内存的浪费,特别是在处理较短序列时,因为需要大量的填充。
+**主要挑战:**
+- **内存效率低下**:可能存在大量填充导致内存利用率低。
+- **计算效率低下**:对填充部分进行不必要的计算。
+
+## 解决方案
+
+为了应对上述挑战,我们引入了对动态形状的支持,允许每个微批次中的序列保持其原始长度。此功能通过在发送张量之前,提前通信张量的形状信息,在各个流水线阶段之间同步即将接收的数据形状,确保内存分配和预处理的准确性。
+## 使用场景
+
+- **多变长度文本处理**:如文档分类、机器翻译等任务,其中文本长度差异很大。
+- **增强模型泛化能力**:让模型更好地适应各种长度的输入,从而提高其在实际应用中的表现。
+
+## 使用方法
+
+**注意事项:**
+- 当采用流水线并行策略且序列长度固定时,启用该特性将增加不必要的通信开销,因此不建议使用。
+- 密切监控训练过程中的内存消耗,避免因序列长度变动引起的溢出问题。
+
+**设置训练脚本参数**
+```shell
+# 开启流水线并行, PP >= 2
+--pipeline-model-parallel-size ${PP} \
+# 开启PP支持动态形状
+--variable-seq-lengths
+```
+
+## 使用效果
+
+- **优化资源利用**:与传统方法中所有序列需填充至统一长度相比,本方案通过减少不必要的填充操作,有效节省内存空间,降低计算负载,提高整体性能。
+- **提高灵活性**:该特性赋予模型更强的适应性,使其能够高效处理各种长度的输入数据,进而增强了模型的泛化能力。这对于需要处理变长输入的任务(如文本分类、机器翻译等)尤为重要。
+- **更真实的数据表示**:保留了原始文本的真实长度,有助于模型更准确地捕捉文本特征。
+- **潜在性能影响**:尽管有诸多优点,但在某些情况下(如开启流水线并行,并且原序列为等长或需被截断以保持一致长度时),启用该特性可能会增加复杂度并减慢训练速度。因此,在设计和部署时应综合考虑这些因素,确保系统整体性能最优化。
+
+综上所述,PP支持动态形状是针对特定应用场景的一种有效优化手段,它能够在保证模型性能的同时,显著改善资源利用率和数据处理的灵活性。用户应根据实际情况权衡利弊,决定是否启用这一特性。
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/features/virtual-pipeline-parallel.md b/model/train/yoco_moe/docs/features/virtual-pipeline-parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..5c02b425eb38bcce4d2e4bfa83c49f5935eb80f3
--- /dev/null
+++ b/model/train/yoco_moe/docs/features/virtual-pipeline-parallel.md
@@ -0,0 +1,43 @@
+# 虚拟流水线并行
+
+## 问题分析
+
+Pipedream流水线并行切分粒度过大,运行过程中仍然有许多空泡(bubble),计算资源利用率仍有提高空间。
+
+## 解决方案
+
+将计算进一步细分,减少空泡。
+
+### 解决思路:
+
+在设备数量不变的情况下,分出更多的流水线阶段,以更多的通信量,换取空泡比率降低。
+
+
+
+[原文链接](https://people.eecs.berkeley.edu/~matei/papers/2021/sc_megatron_lm.pdf)
+
+为了方便理解,举一个例子:假设模型层数为16,张量并行大小为1,流水线并行大小为4,虚拟流水线并行大小为2。模型会被分为 4 * 2 = 8 个阶段,每个阶段 16 / 8 = 2 个层。
+
+ Device 0: [1, 2] [9, 10]
+ Device 1: [3, 4] [11, 12]
+ Device 2: [5, 6] [13, 14]
+ Device 3: [7, 8] [15, 16]
+
+前向的顺序为 device 0 -> device 1 -> device 2 -> device 3 -> device 0 -> device 1 -> device 2 -> device 3
+
+## 使用场景
+
+想要进一步减小空泡比率,提升性能
+
+## 使用方法
+
+设置`--num-layers-per-virtual-pipeline-stage N`。表示每个阶段的层数。要求模型的总层数 L % N == 0。要求 `--pipeline-model-parallel-size` > 2。
+
+## 使用效果
+
+空泡比率进一步减小
+
+## 注意事项
+
+Megatron虚拟流水并行vpp影响权重切分方式,保存、加载权重时需保证vpp配置一致,才能正常加载;
+
diff --git a/model/train/yoco_moe/docs/ops/README.md b/model/train/yoco_moe/docs/ops/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e17e98284c0c119c213ee634cad80e0a3b1e66da
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/README.md
@@ -0,0 +1,30 @@
+# How to run the ops?
+
+## previous installation
++ CANN
++ CANN-NNAL(Ascend-Transformer-Boost)
++ torch_npu
+
+## compile and install
+### 1. set the environment variables
+ ```shell
+# Default path, change it if needed.
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+ ```
+#### if use Ascend-Transformer-Boost
+ ```shell
+# Default path, change it if needed.
+source /usr/local/Ascend/nnal/atb/set_env.sh
+ ```
+
+### 2. include head files
+
++ newest torch_npu
++ newest cann
+
+### 3. install scripts
+```shell
+python3 setup.py build
+python3 setup.py bdist_wheel
+pip3 install dist/*.whl --force-reinstall
+```
diff --git a/model/train/yoco_moe/docs/ops/ffn.md b/model/train/yoco_moe/docs/ops/ffn.md
new file mode 100644
index 0000000000000000000000000000000000000000..261d765f01f4c4b0b9c59171c2758a9c92a097b3
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/ffn.md
@@ -0,0 +1,230 @@
+# ffn对外接口(只支持前向)
+
+npu_ffn(Tensor x, Tensor weight1, Tensor weight2, str activation, *, Tensor? expert_tokens=None,
+ Tensor? expert_tokens_index=None, Tensor? bias1=None, Tensor? bias2=None, Tensor? scale=None,
+ Tensor? offset=None, Tensor? deq_scale1=None, Tensor? deq_scale2=None, Tensor? antiquant_scale1=None,
+ Tensor? antiquant_scale2=None, Tensor? antiquant_offset1=None, Tensor? antiquant_offset2=None,
+ int? inner_precise=None, ScalarType? output_dtype=None) -> Tensor
+
+计算逻辑:
+ - **非量化场景:**
+
+ $$
+ y=activation(x * W1 + b1) * W2 + b2
+ $$
+
+ - **量化场景:**
+
+ $$
+ y=((activation((x * W1 + b1) * deqScale1) * scale + offset) * W2 + b2) * deqScale2
+ $$
+
+ - **伪量化场景:**
+
+ $$
+ y=activation(x * ((W1 + antiquantOffset1) * antiquantScale1) + b1) * ((W2 + antiquantOffset2) * antiquantScale2) + b2
+ $$
+
+**说明:**
+ 激活层为geglu/swiglu/reglu时,性能使能需要满足门槛要求,即整网中FFN结构所对应的小算子中vector耗时30us且占比10%以上的用例方可尝试FFN融合算子;或在不知道小算子性能的情况下,尝试使能FFN,若性能劣化则不使能FFN。
+
+## 非量化场景:
+输入:
+- x:必选输入,公式中的输入x,数据类型int8, float16, bfloat16,支持输入的维度最少是2维[M, K1],最多是8维
+- weight1: 必选输入,专家的权重数据,公式中的W1,数据类型int4, int8, float16, bfloat16,输入在有/无专家时分别为[E, K1, N1]/[K1, N1]
+- weight2: 必选输入,专家的权重数据,公式中的W2,数据类型int4, int8, float16, bfloat16,输入在有/无专家时分别为[E, K2, N2]/[K2, N2]
+ **说明:**
+ M表示token个数,对应transform中的BS(B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度);K1表示第一组matmul的输入通道数,对应transform中的H(Head-Size)表示隐藏层的大小);N1表示第一组matmul的输出通道数;K2表示第二组matmul的输入通道数;N2表示第二组matmul的输出通道数,对应transform中的H;E表示有专家场景的专家数。
+- activation: 必选输入,代表使用的激活函数,公式中的activation,当前支持fastgelu/gelu/relu/silu以及geglu/swiglu/reglu
+- expert_tokens: 可选输入,数据类型int64
+- expert_tokens_index:可选输入,数据类型int64
+ **说明:**
+ 不能同时输入expert_tokens和expert_tokens_index
+ expert_tokens,expert_tokens_index,若不为空时可支持的最大长度为256个
+- bias1: 可选输入,权重数据修正值,公式中的b1,数据类型int32, float16, float32,输入在有/无专家时分别为[E, N1]/[N1]
+- bias2: 可选输入,权重数据修正值,公式中的b2,数据类型int32, float16, float32,输入在有/无专家时分别为[E, N2]/[N2]
+- inner_precise:可选输入,表示高精度或者高性能选择,数据类型支持int64, 该参数仅对float16生效,bfloat16和int8不区分高精度和高性能。
+ - innerPrecise为0时,代表开启高精度模式,算子内部采用float32数据类型计算
+ - innerPrecise为1时,代表高性能模式
+
+输出:
+- y:必选输出,数据类型float16, bfloat16
+
+## 全量化场景:
+输入:
+- x:必选输入,公式中的输入x,数据类型int8, float16, bfloat16,支持输入的维度最少是2维[M, K1],最多是8维
+- weight1: 必选输入,专家的权重数据,公式中的W1,数据类型int4, int8, float16, bfloat16,输入在有/无专家时分别为[E, K1, N1]/[K1, N1]
+- weight2: 必选输入,专家的权重数据,公式中的W2,数据类型int4, int8, float16, bfloat16,输入在有/无专家时分别为[E, K2, N2]/[K2, N2]
+ **说明:**
+ M表示token个数,对应transform中的BS(B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度);K1表示第一组matmul的输入通道数,对应transform中的H(Head-Size)表示隐藏层的大小);N1表示第一组matmul的输出通道数;K2表示第二组matmul的输入通道数;N2表示第二组matmul的输出通道数,对应transform中的H;E表示有专家场景的专家数。
+- activation: 必选输入,代表使用的激活函数,公式中的activation,当前支持fastgelu/gelu/relu/silu以及geglu/swiglu/reglu
+- expert_tokens: 可选输入,数据类型int64
+- expert_tokens_index:可选输入,数据类型int64
+ **说明:**
+ 不能同时输入expert_tokens和expert_tokens_index
+ expert_tokens,expert_tokens_index,若不为空时可支持的最大长度为256个
+- bias1: 可选输入,权重数据修正值,公式中的b1,数据类型int32, float16, float32,输入在有/无专家时分别为[E, N1]/[N1]
+- bias2: 可选输入,权重数据修正值,公式中的b2,数据类型int32, float16, float32,输入在有/无专家时分别为[E, N2]/[N2]
+- scale: 可选输入,量化参数,量化缩放系数,数据类型float32,per-tensor下输入在有/无专家时均为一维向量,输入元素个数在有/无专家时分别为[E]/[1];per-channel下输入在有/无专家时为二维向量/一维向量,输入元素个数在有/无专家时分别为[E, N1]/[N1]
+- offset: 可选输入,量化参数,量化偏移量,数据类型float32,一维向量,输入元素个数在有/无专家时分别为[E]/[1]
+- deq_scale1:可选输入,量化参数,第一组matmul的反量化缩放系数,数据类型uint64, int64, float32, bfloat16,输入在有/无专家时分别为[E, N1]/[N1]
+- deq_scale2:可选输入,量化参数,第二组matmul的反量化缩放系数,数据类型uint64, int64, float32, bfloat16,输入在有/无专家时分别为[E, N2]/[N2]
+- inner_precise:可选输入,表示高精度或者高性能选择,数据类型支持int64, 该参数仅对float16生效,bfloat16和int8不区分高精度和高性能。
+ - innerPrecise为0时,代表开启高精度模式,算子内部采用float32数据类型计算
+ - innerPrecise为1时,代表高性能模式
+- output_dtype:可选输入,表示输出y的数据类型,为空时输出y的数据类型为float16,不为空时支持float16, bfloat16
+
+输出:
+- y:必选输出,数据类型float16, bfloat16
+
+## 伪量化场景:
+输入:
+- x:必选输入,公式中的输入x,数据类型int8, float16, bfloat16,支持输入的维度最少是2维[M, K1],最多是8维
+- weight1: 必选输入,专家的权重数据,公式中的W1,数据类型int4, int8, float16, bfloat16,输入在有/无专家时分别为[E, K1, N1]/[K1, N1]
+- weight2: 必选输入,专家的权重数据,公式中的W2,数据类型int4, int8, float16, bfloat16,输入在有/无专家时分别为[E, K2, N2]/[K2, N2]
+ **说明:**
+ M表示token个数,对应transform中的BS(B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度);K1表示第一组matmul的输入通道数,对应transform中的H(Head-Size)表示隐藏层的大小);N1表示第一组matmul的输出通道数;K2表示第二组matmul的输入通道数;N2表示第二组matmul的输出通道数,对应transform中的H;E表示有专家场景的专家数。
+- activation: 必选输入,代表使用的激活函数,公式中的activation,当前支持fastgelu/gelu/relu/silu以及geglu/swiglu/reglu
+- expert_tokens: 可选输入,代表各专家的token数,数据类型int64
+- expert_tokens_index:可选输入,代表各专家的token数,数据类型int64
+ **说明:**
+ 不能同时输入expert_tokens和expert_tokens_index
+ expert_tokens,expert_tokens_index,若不为空时可支持的最大长度为256个
+- bias1: 可选输入,权重数据修正值,公式中的b1,数据类型int32, float16, float32,输入在有/无专家时分别为[E, N1]/[N1]
+- bias2: 可选输入,权重数据修正值,公式中的b2,数据类型int32, float16, float32,输入在有/无专家时分别为[E, N2]/[N2]
+- antiquant_scale1: 可选输入,伪量化参数,第一组matmul的缩放系数,数据类型float16, bfloat16,per-channel下输入在有/无专家时分别为[E, N1]/[N1],per-in-group下输入在有/无专家时分别为[E, G, N1]/[G, N1]
+- antiquant_scale2: 可选输入,伪量化参数,第二组matmul的缩放系数,数据类型float16, bfloat16,per-channel下输入在有/无专家时分别为[E, N2]/[N2],per-in-group下输入在有/无专家时分别为[E, G, N2]/[G, N2]
+- antiquant_offset1: 可选输入,伪量化参数,第一组matmul的偏移量,数据类型float16, bfloat16,per-channel下输入在有/无专家时分别为[E, N1]/[N1],per-in-group下输入在有/无专家时分别为[E, G, N1]/[G, N1]
+- antiquant_offset2: 可选输入,伪量化参数,第二组matmul的偏移量,数据类型float16, bfloat16,per-channel下输入在有/无专家时分别为[E, N2]/[N2],per-in-group下输入在有/无专家时分别为[E, G, N2]/[G, N2]
+ **说明:**
+ G表示伪量化per-in-group场景下,antiquantOffsetOptional、antiquantScaleOptional的组数。
+- inner_precise:可选输入,表示高精度或者高性能选择,数据类型支持int64, 该参数仅对float16生效,bfloat16和int8不区分高精度和高性能。
+ - innerPrecise为0时,代表开启高精度模式,算子内部采用float32数据类型计算
+ - innerPrecise为1时,代表高性能模式
+
+输出:
+- y:必选输出,数据类型float16, bfloat16
+
+## 约束与限制
+
+- 有专家时,专家数据的总数需要与x的M保持一致。
+- 激活层为geglu/swiglu/reglu时,仅支持无专家分组时的float16高性能场景(float16场景指类型为aclTensor的必选参数数据类型都为float16的场景),且N1=2\*K2。
+- 激活层为gelu/fastgelu/relu/silu时,支持有专家或无专家分组的float16高精度及高性能场景,bfloat16场景,量化场景及伪量化场景,且N1=K2。
+- 非量化场景不能输入量化参数和伪量化参数,量化场景不能输入伪量化参数,伪量化场景不能输入量化参数。
+- 量化场景参数类型:x为int8、weight为int8、bias为int32、scale为float32、offset为float32,其余参数类型根据y不同分两种情况:
+ - y为float16,deqScale支持数据类型:uint64、int64、float32。
+ - y为bfloat16,deqScale支持数据类型:bfloat16。
+ - 要求deqScale1与deqScale2的数据类型保持一致。
+- 量化场景支持scale的per-channel模式参数类型:x为int8、weight为int8、bias为int32、scale为float32、offset为float32,其余参数类型根据y不同分两种情况:
+ - y为float16,deqScale支持数据类型:uint64、int64。
+ - y为bfloat16,deqScale支持数据类型:bfloat16。
+ - 要求deqScale1与deqScale2的数据类型保持一致。
+- 伪量化场景支持两种不同参数类型:
+ - y为float16、x为float16、bias为float16,antiquantScale为float16、antiquantOffset为float16,weight支持数据类型int8和int4。
+ - y为bfloat16、x为bfloat16、bias为float32,antiquantScale为bfloat16、antiquantOffset为bfloat16,weight支持数据类型int8和int4。
+- 当weight1/weight2的数据类型为int4时,其shape最后一维必须为偶数。
+- 伪量化场景,per-in-group下,antiquantScale1和antiquantOffset1中的K1需要能整除组数G,antiquantScale2和antiquantOffset2中的K2需要能整除组数G。
+- 伪量化场景,per-in-group下目前只支持weight是int4数据类型的场景。
+- innerPrecise参数在bfloat16非量化场景,只能配置为0;float16非量化场景,可以配置为0或者1;量化或者伪量化场景,0和1都可配置,但是配置后不生效。
+
+## npu_ffn 接口的调用方式
+```python
+import os
+import torch
+import torch_npu
+import torchair as tng
+from torchair.configs.compiler_config import CompilerConfig
+from mindspeed.ops import ffn
+
+os.environ["ENABLE_ACLNN"] = "true"
+config = CompilerConfig()
+npu_backend = tng.get_npu_backend(compiler_config=config)
+
+ACTIVATION_TYPE = 'relu'
+
+EXPERT_TOKENS_MODE_NONE = 0
+EXPERT_TOKENS_MODE_TOKENS = 1
+EXPERT_TOKENS_MODE_TOKENS_INDEX = 2
+
+
+class FFNModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, weight1, weight2, activation, expert_tokens=None, expert_tokens_index=None,
+ bias1=None, bias2=None, scale=None, offset=None, deq_scale1=None, deq_scale2=None,
+ antiquant_scale1=None, antiquant_scale2=None, antiquant_offset1=None, antiquant_offset2=None,
+ inner_precise=0):
+ return ffn.npu_ffn(x, weight1, weight2, activation,
+ expert_tokens=expert_tokens, expert_tokens_index=expert_tokens_index,
+ bias1=bias1, bias2=bias2, inner_precise=inner_precise)
+
+
+def test_ffn(tokens_mode, is_graph_mode=True):
+ M = 512
+ K1 = 256
+ N1 = 1024
+ K2 = N1
+ N2 = K1
+
+ dtype = torch.float16
+ bias_dtype = torch.float16 if dtype == torch.float16 else torch.float32
+
+ expert_tokens = None
+ expert_tokens_index = None
+
+ if tokens_mode == EXPERT_TOKENS_MODE_NONE:
+ x = torch.empty(M, K1, dtype=dtype).uniform_(-1.0, 1.0)
+ weight1 = torch.empty(K1, N1, dtype=dtype).uniform_(-0.1, 0.1)
+ weight2 = torch.empty(K2, N2, dtype=dtype).uniform_(-0.1, 0.1)
+ bias1 = torch.empty(N1, dtype=bias_dtype).uniform_(-0.1, 0.1)
+ bias2 = torch.empty(N2, dtype=bias_dtype).uniform_(-0.1, 0.1)
+ elif tokens_mode == EXPERT_TOKENS_MODE_TOKENS:
+ E = 8
+ x = torch.empty(M, K1, dtype=dtype).uniform_(-1.0, 1.0)
+ weight1 = torch.empty(E, K1, N1, dtype=dtype).uniform_(-0.1, 0.1)
+ weight2 = torch.empty(E, K2, N2, dtype=dtype).uniform_(-0.1, 0.1)
+ bias1 = torch.empty(E, N1, dtype=bias_dtype).uniform_(-0.1, 0.1)
+ bias2 = torch.empty(E, N2, dtype=bias_dtype).uniform_(-0.1, 0.1)
+ expert_tokens = [64, 64, 64, 64, 64, 64, 64, 64]
+ expert_tokens = torch.tensor(expert_tokens, dtype=torch.int64)
+ elif tokens_mode == EXPERT_TOKENS_MODE_TOKENS_INDEX:
+ E = 8
+ x = torch.empty(M, K1, dtype=dtype).uniform_(-1.0, 1.0)
+ weight1 = torch.empty(E, K1, N1, dtype=dtype).uniform_(-0.1, 0.1)
+ weight2 = torch.empty(E, K2, N2, dtype=dtype).uniform_(-0.1, 0.1)
+ bias1 = torch.empty(E, N1, dtype=bias_dtype).uniform_(-0.1, 0.1)
+ bias2 = torch.empty(E, N2, dtype=bias_dtype).uniform_(-0.1, 0.1)
+ expert_tokens_index = [64, 128, 192, 256, 320, 384, 448, 512]
+ expert_tokens_index = torch.tensor(expert_tokens_index, dtype=torch.int64)
+
+ x = x.npu()
+ weight1 = weight1.npu()
+ weight2 = weight2.npu()
+ bias1 = bias1.npu()
+ bias2 = bias2.npu()
+
+ if expert_tokens != None:
+ expert_tokens = expert_tokens.npu()
+ if expert_tokens_index != None:
+ expert_tokens_index = expert_tokens_index.npu()
+
+ if is_graph_mode:
+ model = FFNModel().npu()
+ model = torch.compile(model, backend=npu_backend, dynamic=True)
+ y = model(x, weight1, weight2, ACTIVATION_TYPE, expert_tokens=expert_tokens,
+ expert_tokens_index=expert_tokens_index, bias1=bias1, bias2=bias2)
+ else:
+ y = ffn.npu_ffn(x, weight1, weight2, ACTIVATION_TYPE, expert_tokens=expert_tokens,
+ expert_tokens_index=expert_tokens_index, bias1=bias1, bias2=bias2)
+ print('y.shape:', y.shape)
+
+
+if __name__ == '__main__':
+ test_ffn(EXPERT_TOKENS_MODE_NONE, True)
+ test_ffn(EXPERT_TOKENS_MODE_TOKENS, True)
+ test_ffn(EXPERT_TOKENS_MODE_TOKENS_INDEX, True)
+ test_ffn(EXPERT_TOKENS_MODE_NONE, False)
+ test_ffn(EXPERT_TOKENS_MODE_TOKENS, False)
+ test_ffn(EXPERT_TOKENS_MODE_TOKENS_INDEX, False)
+```
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/ops/fusion_attention.md b/model/train/yoco_moe/docs/ops/fusion_attention.md
new file mode 100644
index 0000000000000000000000000000000000000000..d12745fdf8ede4834ab1f5968732772ffbe4e6c3
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/fusion_attention.md
@@ -0,0 +1,146 @@
+# fusion attention 对外接口
+
+### 注意当前若要使用v2版本接口,需要开启`--use-fusion-attn-v2`特性
+
+npu_fusion_attention(
+ query, key, value, head_num,
+ input_layout, *, pse=None,
+ padding_mask=None, atten_mask=None,
+ scale=1., keep_prob=1., pre_tokens=2147483647,
+ next_tokens=2147483647, inner_precise=0, prefix=None,
+ actual_seq_qlen=None, actual_seq_kvlen=None,
+ sparse_mode=0, gen_mask_parallel=True,
+ sync=False, pse_type=1, q_start_idx=None,
+ kv_start_idx=None)
+
+- 计算公式:
+
+ 注意力的正向计算公式如下:
+
+ - pse_type=1时,公式如下:
+
+ $$
+ attention\\_out = Dropout(Softmax(Mask(scale*(pse+query*key^T), atten\\_mask)), keep\\_prob)*value
+ $$
+
+ - pse_type=其他取值时,公式如下:
+
+ $$
+ attention\\_out=Dropout(Softmax(Mask(scale*(query*key^T) + pse),atten\\_mask),keep\\_prob)*value
+ $$
+
+## 前向接口:
+输入:
+- query:必选输入,Device侧的Tensor,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND。
+- key:必选输入,Device侧的Tensor,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND。
+- value:必选输入,Device侧的Tensor,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND。
+- atten_mask:可选输入,数据类型bool,缺省none。在softmax之前drop的mask。
+- pse:可选输入,Device侧的Tensor,可选参数,表示位置编码。数据类型支持FLOAT16、BFLOAT16,数据格式支持ND。非varlen场景支持四维输入,包含BNSS格式、BN1Skv格式、1NSS格式。如果非varlen场景Sq大于1024或varlen场景、每个batch的Sq与Skv等长且是sparse_mode为0、2、3的下三角掩码场景,可使能alibi位置编码压缩,此时只需要输入原始PSE最后1024行进行内存优化,即alibi_compress = ori_pse[:, :, -1024:, :],参数每个batch不相同时,输入BNHSkv(H=1024),每个batch相同时,输入1NHSkv(H=1024)。如果pse_type为2或3的话,需传入数据类型为float32的slope数据,slope数据支持BN或N两种shape。
+- padding_mask:可选输入,Device侧的Tensor,暂不支持该参数。
+- atten_mask:Device侧的Tensor,可选参数,取值为1代表该位不参与计算(不生效),为0代表该位参与计算,数据类型支持BOOL、UINT8,数据格式支持ND格式,输入shape类型支持BNSS格式、B1SS格式、11SS格式、SS格式。varlen场景只支持SS格式,SS分别是maxSq和maxSkv。
+- prefix:Host侧的int array,可选参数,代表prefix稀疏计算场景每个Batch的N值。数据类型支持INT64,数据格式支持ND。
+- actual_seq_qlen:Host侧的int array,可选参数,varlen场景时需要传入此参数。表示query每个S的累加和长度,数据类型支持INT64,数据格式支持ND。
+ 比如真正的S长度列表为:2 2 2 2 2 则actual_seq_qlen传:2 4 6 8 10。
+- actual_seq_kvlen:Host侧的int array,可选参数,varlen场景时需要传入此参数。表示key/value每个S的累加和长度。数据类型支持INT64,数据格式支持ND。
+ 比如真正的S长度列表为:2 2 2 2 2 则actual_seq_kvlen传:2 4 6 8 10。
+- sparse_mode:Host侧的int,表示sparse的模式,可选参数。数据类型支持:INT64,默认值为0,支持配置值为0、1、2、3、4、5、6、7、8。当整网的atten_mask都相同且shape小于2048*2048时,建议使用defaultMask模式,来减少内存使用,
+ 具体可参考昇腾社区说明https://www.hiascend.com/document/detail/zh/Pytorch/60RC1/apiref/apilist/ptaoplist_000448.html。
+- q_start_idx:Host侧的int array,可选参数,长度为1的int类型数组。pse_type配置为2或3时,表示内部生成alibi编码在Sq方向偏移的格数,正数表示0对角线向上移动。缺省值为0,表示不进行偏移。
+- kv_start_idx:Host侧的int array,可选参数,长度为1的int类型数组。pse_type配置为2或3时,表示内部生成alibi编码在Skv方向偏移的格数,正数表示0对角线向左移动。缺省值为0,表示不进行偏移。
+
+输出:
+(Tensor, Tensor, Tensor, Tensor, int, int, int)
+
+- 第1个输出为Tensor,计算公式的最终输出y,数据类型支持:FLOAT16、BFLOAT16。
+- 第2个输出为Tensor,Softmax 计算的Max中间结果,用于反向计算,数据类型支持:FLOAT。
+- 第3个输出为Tensor,Softmax计算的Sum中间结果,用于反向计算,数据类型支持:FLOAT。
+- 第4个输出为Tensor,保留参数,暂未使用。
+- 第5个输出为int,DSA生成dropoutmask中,Philox算法的seed。
+- 第6个输出为int,DSA生成dropoutmask中,Philox算法的offset。
+- 第7个输出为int,DSA生成dropoutmask的长度。
+
+属性:
+- scale:可选属性,Host侧的double,可选参数,代表缩放系数,作为计算流中Muls的scalar值,数据类型支持DOUBLE,默认值为1。
+- pse_type:可选属性,Host侧的int,数据类型支持INT64,默认值为1。支持范围0-3。
+- pse_type配置为0的时候,pse由外部传入,计算流程是先mul scale再add pse。
+- pse_type配置为1的时候,pse由外部传入,计算流程是先add pse再mul scale。
+- pse_type配置为2的时候,pse由内部生成,生成标准alibi位置信息。内部生成的alibi矩阵0线与Q@K^T的左上角对齐。
+- pse_type配置为3的时候,pse由内部生成,生成的alibi位置信息为标准的基础上再做sqrt开平方。内部生成的alibi矩阵0线与Q@K^T的左上角对齐。
+- head_num:必选属性,Host侧的int,代表head个数,数据类型支持INT64。
+- input_layout:必选属性,Host侧的string,代表输入query、key、value的数据排布格式,支持BSH、SBH、BSND、BNSD、TND(actual_seq_qlen/actual_seq_kvlen需传值);后续章节如无特殊说明,S表示query或key、value的sequence length,Sq表示query的sequence length,Skv表示key、value的sequence length,SS表示Sq*Skv
+- keep_prob:可选属性,数据类型float,默认值为1.0。在 softmax 后的保留比例。
+- pre_tokens:可选属性,Host侧的int,用于稀疏计算的参数,可选参数,数据类型支持INT64,默认值为2147483647。
+- next_tokens:可选属性,Host侧的int,用于稀疏计算的参数,可选参数,数据类型支持INT64,默认值为2147483647。
+- inner_precise:可选属性,Host侧的int,用于提升精度,数据类型支持INT64,默认值为0。
+- gen_mask_parallel:debug参数,DSA生成dropout随机数向量mask的控制开关,默认值为True:同AICORE计算并行,False:同AICORE计算串行
+- sync:debug参数,DSA生成dropout随机数向量mask的控制开关,默认值为False:dropout mask异步生成,True:dropout mask同步生成
+
+## 反向接口
+输入:
+- grad:必选输入,数据类型float16, bfloat16,正向attention_out的梯度输入
+
+输出:
+- grad_query:必选输出,数据类型float16, bfloat16
+- grad_key:必选输出,数据类型float16, bfloat16
+- grad_value:必选输出,数据类型float16, bfloat16
+
+
+## 输入限制
+- 输入query、key、value的B:batchsize必须相等,取值范围1~2M。非varlen prefix场景B最大支持2K,varlen prefix场景B最大支持1K。
+- 输入query、key、value、pse的数据类型必须一致。pse_type=2或3的时候例外,此时pse需要传fp32的slope
+- 输入query、key、value的input_layout必须一致。
+- 输入query的N和key/value的N 必须成比例关系,即Nq/Nkv必须是非0整数,Nq取值范围1~256。当Nq/Nkv > 1时,即为GQA,当Nkv=1时,即为MQA。
+- 输入key/value的shape必须一致。
+- 输入query、key、value的S:sequence length,取值范围1~1M。
+- 输入query、key、value的D:head dim,取值范围1~512。
+- sparse_mode为1、2、3、4、5、6、7、8时,应传入对应正确的atten_mask,否则将导致计算结果错误。当atten_mask输入为None时,sparse_mode,pre_tokens,next_tokens参数不生效,固定为全计算。
+- sparse_mode配置为1、2、3、5、6时,用户配置的pre_tokens、next_tokens不会生效。
+- sparse_mode配置为0、4时,须保证atten_mask与pre_tokens、next_tokens的范围一致。
+- keep_prob的取值范围为(0, 1]。
+- varlen场景不支持非压缩prefix,即不支持sparse_mode=5;当Sq>Skv时,prefix的N值取值范围[0, Skv],当Sq<=Skv时,prefix的N值取值范围[Skv-Sq, Skv]。
+- sparse_mode=7或者8时,不支持可选输入pse。
+- varlen场景:
+ atten_mask输入不支持补pad,即atten_mask中不能存在某一行全1的场景。
+ pse仅支持alibi场景,即shape为BNHSkv(H=1024)、1NHSkv(H=1024)两种场景。
+- 支持actual_seq_qlen中某个Batch上的S长度为0;如果存在S为0的情况,不支持pse输入, 假设真实的S长度为[2,2,0,2,2],则传入的actual_seq_qlen为[2,4,4,6,8];不支持某个batch中Sq不为0,但是Skv为0的场景。
+
+## flash_attention_v2 类的调用方式
+
+```python
+import os
+import torch
+import torch_npu
+import numpy as np
+import math
+from mindspeed.ops.fusion_attention_v2 import npu_fusion_attention
+
+batch_size, head_num, seq_size, head_dim = 1, 32, 2048, 128
+input_layout = "BNSD"
+keep_prob = 1.0
+pre_tokens = 65536
+next_tokens = 1
+scale_value = 0.088
+q_scale = 1.0
+
+input_shape = (batch_size, head_num, seq_size, head_dim)
+softmax_shape = (batch_size, head_num, seq_size)
+mask_shape = (1, 1, seq_size, seq_size)
+dtype = torch.float16
+dx = (torch.rand(input_shape).to(dtype) - 0.5)
+query = (torch.rand(input_shape).to(dtype) - 0.5) * 5
+key = (torch.rand(input_shape).to(dtype) - 0.5) * 5
+value = (torch.rand(input_shape).to(dtype) - 0.5) * 5
+atten_mask = np.tri(seq_size, k=-next_tokens) + np.tri(seq_size, k=-pre_tokens).transpose()
+alibi_mask = torch.zeros((batch_size, head_num, seq_size, seq_size)).to(dtype)
+atten_mask = torch.from_numpy(atten_mask).to(torch.bool)
+
+query.requires_grad = True
+key.requires_grad = True
+value.requires_grad = True
+
+#正向接口案例
+result = npu_fusion_attention(query.npu(), key.npu(), value.npu(), head_num, input_layout, atten_mask=atten_mask.npu(), scale=scale_value, keep_prob=keep_prob, pre_tokens=pre_tokens, next_tokens=next_tokens, pse_type=1)
+
+#反向接口案例
+result[0].backward(dx.npu())
+```
diff --git a/model/train/yoco_moe/docs/ops/gmm.md b/model/train/yoco_moe/docs/ops/gmm.md
new file mode 100644
index 0000000000000000000000000000000000000000..3a9597e9e63d71a0743d329f6360d95a39de8121
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/gmm.md
@@ -0,0 +1,112 @@
+# gmm对外接口
+
+npu_gmm(x, weight, *, bias=None, group_list=None, group_type=0, gemm_fusion=False, original_weight=None)
+
+npu_gmm_v2(x, weight, *, bias=None, group_list=None, group_type=0, gemm_fusion=False, original_weight=None)
+
+[npu_gmm_v2]相较于[npu_gmm]接口, group_list的含义不同, 在npu_gmm接口中group_list中数值为分组轴大小的cumsum结果(累积和),npu_gmm_v2接口中group_list中数值为分组轴上每组大小。两个接口的算子性能无差异,使用时可以根据整网中group_list的情况决定,如果前序算子输出的group_list为各group的大小,建议使用npu_gmm_v2接口,因为此时使用npu_gmm接口需要先调用torch.cumsum将group_list转为累积和的形式,带来额外开销。
+
+## 前向接口:
+输入:
+- x:必选输入,为tensor,数据类型float16, bfloat16, float32
+- weight:必选输入,为tensor,数据类型float16, bfloat16, float32
+- bias:可选输入,为tensor,数据类型float16, float32, 默认值为none。训练场景下,仅支持bias为none
+- group_list:可选输入,数据类型list[int64], tensor,默认值为none。不同接口中的数值定义不同,具体如上。
+- group_type:可选输入,数据类型int64,代表需要分组的轴,如矩阵乘为C[m,n]=A[m,k]xB[k,n],则groupType取值-1:不分组,0:m轴分组,1:n轴分组,2:k轴分组,默认值为0。
+- gemm_fusion:可选输入,为bool,数据类型True,False,用于反向累加梯度的时候使能GMM+ADD融合算子,默认值为False。
+- original_weight:可选输入,为tensor,数据类型float16, bfloat16, float32,用于获取view之前的weight的main_grad用于GMM+ADD中梯度累加功能,默认值为None。
+
+输出:
+- y:必选输出,数据类型float16, bfloat16, float32
+
+约束与限制:
+- npu_gmm接口中,group_list必须为非负单调非递减数列,且长度不能为1
+- npu_gmm_v2接口中,group_list必须为非负数列,长度不能为1,且数据类型仅支持tensor
+- 不同group_type支持场景:
+ | group_type | 场景限制 |
+ | :---: | :---: |
+ | 0 | 1. weight中tensor需为3维,x,y中tensor需为2维
2. 必须传group_list,如果调用npu_gmm接口,则最后一个值与x中tensor的第一维相等,如果调用npu_gmm_v2接口,则数值的总和与x中tensor的第一维相等 |
+ | 2 | 1. x,weight中tensor需为2维,y中tensor需为2维
2. 必须传group_list,如果调用npu_gmm接口,则最后一个值与x中tensor的第一维相等,如果调用npu_gmm_v2接口,则数值的总和与x中tensor的第一维相等 |
+- group_type不支持group_type=1的场景,其中昇腾310系列处理器支持转置的场景为group_type为0,x为单tensor,weight为单tensor,y为单tensor。
+- x和weight中每一组tensor的最后一维大小都应小于65536.$x_i$的最后一维指当属性transpose_x为false时$x_i$的K轴或当transpose_x为true时$x_i$的M轴。$weight_i$的最后一维指当属性transpose_weight为false时$weight_i$的N轴或当transpose_weight为true时$weight_i$的K轴。
+- x和weight中每一组tensor的每一维大小在32字节对齐后都应小于int32的最大值2147483647。
+
+## 反向接口
+输入:
+- grad:必选输入,为tensor,数据类型float16, bfloat16, float32
+- x:必选输入,为tensor,数据类型float16, bfloat16, float32
+- weight:必选输入,为tensor,数据类型float16, bfloat16, float32
+- group_list:可选输入,数据类型list[int64]、tensor,默认值为none。数据来自正向输入
+
+输出:
+- grad_x:必选输出,数据类型float16, bfloat16, float32
+- grad_weight:必选输出,数据类型float16, bfloat16, float32
+- grad_bias:当前不支持,默认为none
+
+## gmm 类的调用方式
+
+```python
+import os
+import torch
+import torch_npu
+import numpy as np
+import math
+from mindspeed.ops import gmm
+
+num_expert, seq_len, hidden_dim = 8, 32, 256
+group_list = [1, 3, 6, 10, 15, 21, 28, 32]
+group_type = 0
+
+x_shape = (seq_len, hidden_dim)
+weight_shape = (num_expert, hidden_dim, seq_len)
+dtype = torch.float16
+x = (torch.rand(x_shape).to(dtype) - 0.5)
+weight = (torch.rand(weight_shape).to(dtype) - 0.5)
+
+# 正向接口案例
+x.requires_grad = True
+weight.requires_grad = True
+result = gmm.npu_gmm(x.npu(), weight.npu(), bias=None, group_list=group_list, group_type=group_type)
+
+# 反向接口案例
+result.backward(torch.ones(result.shape).npu())
+
+# weight转置案例
+weight_shape_trans = (num_expert, seq_len, hidden_dim)
+weight_trans = (torch.rand(weight_shape_trans).to(dtype) - 0.5)
+weight_trans.requires_grad = True
+result = gmm.npu_gmm(x.npu(), weight_trans.transpose(-1,-2).npu(), bias=None, group_list=group_list, group_type=group_type)
+```
+
+```python
+import os
+import torch
+import torch_npu
+import numpy as np
+import math
+from mindspeed.ops import gmm
+
+num_expert, seq_len, hidden_dim = 8, 32, 256
+group_list = torch.tensor([1, 3, 3, 4, 5, 6, 7, 4])
+group_type = 0
+
+x_shape = (seq_len, hidden_dim)
+weight_shape = (num_expert, hidden_dim, seq_len)
+dtype = torch.float16
+x = (torch.rand(x_shape).to(dtype) - 0.5)
+weight = (torch.rand(weight_shape).to(dtype) - 0.5)
+
+# 正向接口案例
+x.requires_grad = True
+weight.requires_grad = True
+result = gmm.npu_gmm_v2(x.npu(), weight.npu(), bias=None, group_list=group_list.npu(), group_type=group_type)
+
+# 反向接口案例
+result.backward(torch.ones(result.shape).npu())
+
+# weight转置案例
+weight_shape_trans = (num_expert, seq_len, hidden_dim)
+weight_trans = (torch.rand(weight_shape_trans).to(dtype) - 0.5)
+weight_trans.requires_grad = True
+result = gmm.npu_gmm_v2(x.npu(), weight_trans.transpose(-1,-2).npu(), bias=None, group_list=group_list.npu(), group_type=group_type)
+```
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/ops/lcal_coc.md b/model/train/yoco_moe/docs/ops/lcal_coc.md
new file mode 100644
index 0000000000000000000000000000000000000000..eee78e7d332c7ca0f9f380fb8a57156aad11c78e
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/lcal_coc.md
@@ -0,0 +1,395 @@
+# LCAL_COC对外接口
+
+## MATMUL_ALL_REDUCE接口
+
+```python
+from mindspeed.ops.lcal_functional import coc_ops
+
+coc_ops.matmul_all_reduce(input1, input2, output, bias)
+```
+
+### 接口功能
+
+该接口对输入的左右矩阵进行Matmul操作,并对其结果进行All-Reduce通信,最后加上bias(如果bias不为None)。将最终结果赋值到output内存区域中。
+
+### 接口输入输出
+
+假设Matmul操作对应的shape为[m, k]和[k, n]:
+
+接口输入:
+- input1:左矩阵(必选输入,数据类型float16/bfloat16,shape只支持二维,不支持转置,\[m,k\]);
+- input2:右矩阵(必选输入,数据类型float16/bfloat16,shape只支持二维,支持转置,\[k,n\]/\[n,k\]);
+- output:输出矩阵,需要提前申请内存作为接口的输入(必选输入,数据类型float16/bfloat16,shape只支持二维,\[m,n\]);
+- bias:偏置向量(可选输入,数据类型float16/bfloat16,shape支持\[1, n\]);
+
+接口输出:
+- 无
+
+### 使用案例
+
+```python
+import torch
+import torch_npu
+import torch.multiprocessing as mp
+import torch.distributed as dist
+from torch_npu.contrib import transfer_to_npu
+import megatron.core.parallel_state as ps
+
+
+def initialize_model_parallel(
+ tensor_model_parallel_size=1,
+ pipeline_model_parallel_size=1,
+ virtual_pipeline_model_parallel_size=None,
+ pipeline_model_parallel_split_rank=None,
+ context_parallel_size=1,
+):
+ ps.destroy_model_parallel()
+ ps.initialize_model_parallel(
+ tensor_model_parallel_size=tensor_model_parallel_size,
+ pipeline_model_parallel_size=pipeline_model_parallel_size,
+ virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
+ pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank,
+ context_parallel_size=context_parallel_size,
+ )
+
+
+def test_coc_matmul_all_reduce(rank, world_size, master_ip, master_port):
+ torch_npu.npu.set_device(rank)
+ init_method = 'tcp://' + master_ip + ':' + master_port
+ dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=init_method)
+ initialize_model_parallel(world_size)
+ from mindspeed.ops.lcal_functional import coc_ops
+
+ m, k, n = 2048, 4096, 8192
+ dtype = torch.float16
+ input1 = torch.rand(m, k, dtype=dtype, device=torch.npu.current_device())
+ input2 = torch.rand(k, n, dtype=dtype, device=torch.npu.current_device())
+ bias = torch.rand(1, n, dtype=dtype, device=torch.npu.current_device())
+ output = torch.zeros(m, n, dtype=dtype, device=torch.npu.current_device())
+ coc_ops.matmul_all_reduce(input1, input2, output, bias)
+ torch.npu.synchronize()
+ print(output)
+
+if __name__ == "__main__":
+ world_size = 8
+ master_ip = "127.0.0.1"
+ master_port = "50001"
+ mp.spawn(test_coc_matmul_all_reduce, args=(world_size, master_ip, master_port), nprocs=world_size)
+```
+
+
+## ALL_GATHER_MATMUL接口
+
+```python
+from mindspeed.ops.lcal_functional import coc_ops
+
+coc_ops.all_gather_matmul(input1, input2, output, bias)
+````
+
+### 接口功能
+
+该接口对输入的左矩阵进行All-Gather操作,然后将其与右矩阵做Matmul操作,最后加上bias(如果bias不为None)。将最终结果赋值到output内存区域中。
+
+### 接口输入输出
+
+假设Matmul操作对应的shape为[m, k]和[k, n](m必须为world_size的倍数):
+
+接口输入:
+- input1:左矩阵(必选输入,数据类型float16/bfloat16,shape只支持二维,不支持转置,\[m // world_size,k\]);
+- input2:右矩阵(必选输入,数据类型float16/bfloat16,shape只支持二维,支持转置,\[k,n\]/\[n,k\]);
+- output:输出矩阵,需要提前申请内存作为接口的输入(必选输入,数据类型float16/bfloat16,shape只支持二维,\[m,n\]);
+- bias:偏置向量(可选输入,数据类型float16/bfloat16,shape支持\[1, n\]);
+
+接口输出:
+- 无
+
+### 使用案例
+
+```python
+import torch
+import torch_npu
+import torch.multiprocessing as mp
+import torch.distributed as dist
+from torch_npu.contrib import transfer_to_npu
+import megatron.core.parallel_state as ps
+
+
+def initialize_model_parallel(
+ tensor_model_parallel_size=1,
+ pipeline_model_parallel_size=1,
+ virtual_pipeline_model_parallel_size=None,
+ pipeline_model_parallel_split_rank=None,
+ context_parallel_size=1,
+):
+ ps.destroy_model_parallel()
+ ps.initialize_model_parallel(
+ tensor_model_parallel_size=tensor_model_parallel_size,
+ pipeline_model_parallel_size=pipeline_model_parallel_size,
+ virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
+ pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank,
+ context_parallel_size=context_parallel_size,
+ )
+
+
+def test_coc_all_gather_matmul(rank, world_size, master_ip, master_port):
+ torch_npu.npu.set_device(rank)
+ init_method = 'tcp://' + master_ip + ':' + master_port
+ dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=init_method)
+ initialize_model_parallel(world_size)
+ from mindspeed.ops.lcal_functional import coc_ops
+
+ m, k, n = 2048, 4096, 8192
+ dtype = torch.float16
+ input1 = torch.rand(m // world_size, k, dtype=dtype, device=torch.npu.current_device())
+ input2 = torch.rand(k, n, dtype=dtype, device=torch.npu.current_device())
+ bias = torch.rand(1, n, dtype=dtype, device=torch.npu.current_device())
+ output = torch.zeros(m, n, dtype=dtype, device=torch.npu.current_device())
+ coc_ops.all_gather_matmul(input1, input2, output, bias)
+ torch.npu.synchronize()
+ print(output)
+
+
+if __name__ == "__main__":
+ world_size = 8
+ master_ip = "127.0.0.1"
+ master_port = "50001"
+ mp.spawn(test_coc_all_gather_matmul, args=(world_size, master_ip, master_port), nprocs=world_size)
+```
+
+
+## ALL_GATHER_MATMUL_V2接口
+
+```python
+from mindspeed.ops.lcal_functional import coc_ops
+
+coc_ops.all_gather_matmul_v2(input1, input2, output, comm_output, bias)
+```
+
+### 接口功能
+
+该接口对输入的左矩阵进行All-Gather操作,然后将其与右矩阵做Matmul操作,最后加上bias(如果bias不为None)。将最终结果赋值到output内存区域中,并将左矩阵进行All-Gather操作后得到的结果赋值到comm_output内存区域中。
+
+### 接口输入输出
+
+假设Matmul操作对应的shape为[m, k]和[k, n](m必须为world_size的倍数):
+
+接口输入:
+- input1:左矩阵(必选输入,数据类型float16/bfloat16,shape只支持二维,不支持转置,\[m // world_size,k\]);
+- input2:右矩阵(必选输入,数据类型float16/bfloat16,shape只支持二维,支持转置,\[k,n\]/\[n,k\]);
+- output:输出矩阵,需要提前申请内存作为接口的输入(必选输入,数据类型float16/bfloat16,shape只支持二维,\[m,n\]);
+- comm_output:输出矩阵,需要提前申请内存作为接口的输入(必选输入,数据类型float16/bfloat16,shape只支持二维,\[m,k\]);
+- bias:偏置向量(可选输入,数据类型float16/bfloat16,shape支持\[1, n\]);
+
+接口输出:
+- 无
+
+### 使用案例
+
+```python
+import torch
+import torch_npu
+import torch.multiprocessing as mp
+import torch.distributed as dist
+from torch_npu.contrib import transfer_to_npu
+import megatron.core.parallel_state as ps
+
+
+def initialize_model_parallel(
+ tensor_model_parallel_size=1,
+ pipeline_model_parallel_size=1,
+ virtual_pipeline_model_parallel_size=None,
+ pipeline_model_parallel_split_rank=None,
+ context_parallel_size=1,
+):
+ ps.destroy_model_parallel()
+ ps.initialize_model_parallel(
+ tensor_model_parallel_size=tensor_model_parallel_size,
+ pipeline_model_parallel_size=pipeline_model_parallel_size,
+ virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
+ pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank,
+ context_parallel_size=context_parallel_size,
+ )
+
+
+def test_coc_all_gather_matmul_v2(rank, world_size, master_ip, master_port):
+ torch_npu.npu.set_device(rank)
+ init_method = 'tcp://' + master_ip + ':' + master_port
+ dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=init_method)
+ initialize_model_parallel(world_size)
+ from mindspeed.ops.lcal_functional import coc_ops
+
+ m, k, n = 2048, 4096, 8192
+ dtype = torch.float16
+ input1 = torch.rand(m // world_size, k, dtype=dtype, device=torch.npu.current_device())
+ input2 = torch.rand(k, n, dtype=dtype, device=torch.npu.current_device())
+ bias = torch.rand(1, n, dtype=dtype, device=torch.npu.current_device())
+ output = torch.zeros(m, n, dtype=dtype, device=torch.npu.current_device())
+ comm_output = torch.zeros(m, k, dtype=dtype, device=torch.npu.current_device())
+ coc_ops.all_gather_matmul_v2(input1, input2, output, comm_output, bias)
+ torch.npu.synchronize()
+ print(output)
+
+
+if __name__ == "__main__":
+ world_size = 8
+ master_ip = "127.0.0.1"
+ master_port = "50001"
+ mp.spawn(test_coc_all_gather_matmul_v2, args=(world_size, master_ip, master_port), nprocs=world_size)
+```
+
+## MATMUL_REDUCE_SCATTER接口
+
+```python
+from mindspeed.ops.lcal_functional import coc_ops
+
+coc_ops.matmul_reduce_scatter(input1, input2, output, bias)
+````
+
+### 接口功能
+
+该接口对输入的左右矩阵进行Matmul操作,并对其结果进行Reduce-Scatter通信,最后加上bias(如果bias不为None)。将最终结果赋值到output内存区域中。
+
+### 接口输入输出
+
+假设Matmul操作对应的shape为[m, k]和[k, n](m必须为world_size的倍数):
+
+接口输入:
+- input1:左矩阵(必选输入,数据类型float16/bfloat16,shape只支持二维,不支持转置,\[m,k\]);
+- input2:右矩阵(必选输入,数据类型float16/bfloat16,shape只支持二维,支持转置,\[k,n\]/\[n,k\]);
+- output:输出矩阵,需要提前申请内存作为接口的输入(必选输入,数据类型float16/bfloat16,shape只支持二维,\[m // world_size,n\]);
+- bias:偏置向量(可选输入,数据类型float16/bfloat16,shape支持\[1, n\]);
+
+接口输出:
+- 无
+
+### 使用方法
+
+```python
+import torch
+import torch_npu
+import torch.multiprocessing as mp
+import torch.distributed as dist
+from torch_npu.contrib import transfer_to_npu
+import megatron.core.parallel_state as ps
+
+
+def initialize_model_parallel(
+ tensor_model_parallel_size=1,
+ pipeline_model_parallel_size=1,
+ virtual_pipeline_model_parallel_size=None,
+ pipeline_model_parallel_split_rank=None,
+ context_parallel_size=1,
+):
+ ps.destroy_model_parallel()
+ ps.initialize_model_parallel(
+ tensor_model_parallel_size=tensor_model_parallel_size,
+ pipeline_model_parallel_size=pipeline_model_parallel_size,
+ virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
+ pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank,
+ context_parallel_size=context_parallel_size,
+ )
+
+
+def test_coc_matmul_reduce_scatter(rank, world_size, master_ip, master_port):
+ torch_npu.npu.set_device(rank)
+ init_method = 'tcp://' + master_ip + ':' + master_port
+ dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=init_method)
+ initialize_model_parallel(world_size)
+ from mindspeed.ops.lcal_functional import coc_ops
+
+ m, k, n = 2048, 4096, 8192
+ dtype = torch.float16
+ input1 = torch.rand(m, k, dtype=dtype, device=torch.npu.current_device())
+ input2 = torch.rand(k, n, dtype=dtype, device=torch.npu.current_device())
+ bias = torch.rand(1, n, dtype=dtype, device=torch.npu.current_device())
+ output = torch.zeros(m // world_size, n, dtype=dtype, device=torch.npu.current_device())
+ coc_ops.matmul_reduce_scatter(input1, input2, output, bias)
+ torch.npu.synchronize()
+ print(output)
+
+
+if __name__ == "__main__":
+ world_size = 8
+ master_ip = "127.0.0.1"
+ master_port = "50001"
+ mp.spawn(test_coc_matmul_reduce_scatter, args=(world_size, master_ip, master_port), nprocs=world_size)
+```
+
+
+## PURE_MATMUL接口
+
+```python
+from mindspeed.ops.lcal_functional import coc_ops
+
+coc_ops.pure_matmul(input1, input2, output, bias)
+````
+
+### 接口功能
+
+该接口对输入的左右矩阵进行Lcal Matmul操作,最后加上bias(如果bias不为None)。将最终结果赋值到output内存区域中。
+
+### 接口输入输出
+
+假设Matmul操作对应的shape为[m, k]和[k, n]:
+
+接口输入:
+- input1:左矩阵(必选输入,数据类型float16/bfloat16,shape只支持二维,不支持转置,\[m,k\]);
+- input2:右矩阵(必选输入,数据类型float16/bfloat16,shape只支持二维,支持转置,\[k,n\]/\[n,k\]);
+- output:输出矩阵,需要提前申请内存作为接口的输入(必选输入,数据类型float16/bfloat16,shape只支持二维,\[m,n\]);
+- bias:偏置向量(可选输入,数据类型float16/bfloat16,shape支持\[1, n\]);
+
+接口输出:
+- 无
+
+### 使用方法
+
+```python
+import torch
+import torch_npu
+import torch.multiprocessing as mp
+import torch.distributed as dist
+from torch_npu.contrib import transfer_to_npu
+import megatron.core.parallel_state as ps
+
+
+def initialize_model_parallel(
+ tensor_model_parallel_size=1,
+ pipeline_model_parallel_size=1,
+ virtual_pipeline_model_parallel_size=None,
+ pipeline_model_parallel_split_rank=None,
+ context_parallel_size=1,
+):
+ ps.destroy_model_parallel()
+ ps.initialize_model_parallel(
+ tensor_model_parallel_size=tensor_model_parallel_size,
+ pipeline_model_parallel_size=pipeline_model_parallel_size,
+ virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
+ pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank,
+ context_parallel_size=context_parallel_size,
+ )
+
+
+def test_coc_pure_matmul(rank, world_size, master_ip, master_port):
+ torch_npu.npu.set_device(rank)
+ init_method = 'tcp://' + master_ip + ':' + master_port
+ dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=init_method)
+ initialize_model_parallel(world_size)
+ from mindspeed.ops.lcal_functional import coc_ops
+
+ m, k, n = 2048, 4096, 8192
+ dtype = torch.float16
+ input1 = torch.rand(m, k, dtype=dtype, device=torch.npu.current_device())
+ input2 = torch.rand(k, n, dtype=dtype, device=torch.npu.current_device())
+ bias = torch.rand(1, n, dtype=dtype, device=torch.npu.current_device())
+ output = torch.zeros(m, n, dtype=dtype, device=torch.npu.current_device())
+ coc_ops.pure_matmul(input1, input2, output, bias)
+ torch.npu.synchronize()
+ print(output)
+
+
+if __name__ == "__main__":
+ world_size = 8
+ master_ip = "127.0.0.1"
+ master_port = "50001"
+ mp.spawn(test_coc_pure_matmul, args=(world_size, master_ip, master_port), nprocs=world_size)
+```
diff --git a/model/train/yoco_moe/docs/ops/npu_all_to_all_all_gather_bmm.md b/model/train/yoco_moe/docs/ops/npu_all_to_all_all_gather_bmm.md
new file mode 100644
index 0000000000000000000000000000000000000000..da9a330e4f089a4993a131c4835584f4abc03743
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/npu_all_to_all_all_gather_bmm.md
@@ -0,0 +1,190 @@
+# npu_alltoall_allgather_bmm对外接口
+```
+def npu_alltoall_allgather_bmm(
+ x: Tensor,
+ weight: Tensor,
+ group_ep: str,
+ group_ep_worldsize: int,
+ group_tp: str,
+ group_tp_worldsize: int,
+ *,
+ bias: Optional[Tensor] = None,
+ shard_type: Optional[int] = 0,
+ act_type: Optional[str] = "None",
+ need_allgather_out: Optional[bool] = False,
+ need_activation_feature: Optional[bool] = False
+) -> (Tensor, Tensor, Tensor):
+
+```
+
+计算逻辑:
+bmm指BatchMatMul,AllToAllAllGahterBatchMatMul算子是实现AllToAll、AllGather集合通信与BatchMatMul计算并行的算子。
+大体计算流程为:AllToAll集合通信-->AllGather集合通信-->BatchMatMul-->激活(可选,可以没有)
+
+计算逻辑如下,其中y1Out y2OutOptional y3OutOptional为输出,x weight bias为输入,activating为激活函数(由act_type决定,当act_type为None时,表示不调用激活函数)
+$$
+ alltoallOut = AllToAll(x)
+$$
+$$
+ y2OutOptional = AllGather(alltoallOut)
+$$
+$$
+ y3OutOptional = BatchMatMul(y2OutOptional, weight, bias)
+$$
+$$
+ y1Out = activating(y3OutOptional)
+$$
+
+## 输入输出及属性说明:
+输入:
+- x:必选输入,Tensor,数据类型支持float16,bfloat16。该输入进行AllToAll、AllGather集合通信,必须为3维,数据格式支持ND,通信后结果作为BatchMatMul计算的左矩阵。
+- weight:必选输入,Tensor,数据类型支持float16, bfloat16,类型需与x保持一致,必须为3维,数据格式支持ND, BatchMatMul计算的右矩阵。
+- bias:可选输入,Tensor,数据类型支持float16, float32。x为float16时,bias需为float16;x为bfloat16时,bias需为float32,必须为两维或三维,数据格式支持ND。BatchMatMul计算的bias。
+
+输出:
+- y1Out:Tensor,数据类型支持float16, bfloat16,仅支持3维。最终计算结果,如果有激活函数则为激活函数的输出,否则为BatchMatMul的输出。数据类型与输入x保持一致。
+- y2OutOptional:Tensor,可选输出,数据类型支持float16, bfloat16,仅支持3维。AllGather的输出,数据类型与输入x保持一致。反向可能需要。
+- y3OutOptional:Tensor,可选输出,数据类型支持float16, bfloat16,仅支持3维。有激活函数时,BatchMatMul的输出,类型与输入x保持一致。
+
+属性:
+- group_ep:必选属性,str。ep通信域名称,专家并行的通信域。
+- group_ep_worldsize:必选属性,int。ep通信域size,支持2/4/8/16/32。
+- group_tp:必选属性,str。tp通信域名称,Tensor并行的通信域。
+- group_tp_worldsize:必选属性,int。tp通信域size,支持2/4/8/16/32。
+- shard_type:可选属性,int,默认值为0,0表示在H维度按tp域进行allgather,1表示在C维度上按tp域进行allgather。
+- act_type:可选属性,str,激活函数类型,默认值为None,表示无激活函数。支持GELU/Silu/FastGELU/Relu/None等。
+- need_allgather_out:是否需要输出allgather后的结果,默认False,表示不需要输出。
+- need_activation_feature:是否需要输出执行激活函数前的结果(BatchMatMul后),默认False,表示不需要输出。仅在act_type不为None的时候有意义。
+
+
+## 输入shape限制
+因为集合通信及BatchMatMul计算所需,输入输出shape需满足以下数学关系:(其中ep=group_ep_worldsize,tp=group_tp_worldsize)
+按H轴进行AllGather场景,shard_type为0时:
+- x: (E, C, H/tp)
+- weight:(E/ep, H, M/tp)
+- bias:支持两维或三维,三维时shape为:(E/ep, 1, M/tp),两维时shape为:(E/ep, M/tp)
+- y1Out:(E/ep, ep\*C, M/tp)
+- y2OutOptional:(E/ep, ep\*C, H)
+- y3OutOptional:(E/ep, ep\*C, M/tp)
+按C轴进行AllGather场景,shard_type为1时:
+- x: (E, C/tp, H);
+- weight:(E/ep, H, M/tp);
+- bias:支持两维或三维,三维时shape为:(E/ep, 1, M/tp),两维时shape为:(E/ep, M/tp)
+- y1Out:(E/ep, ep\*tp\*C/tp, M/tp);
+- y2OutOptional:(E/ep, ep\*tp\*C/tp, H);
+- y3OutOptional:(E/ep, ep\*tp\*C/tp, M/tp)
+
+数据关系说明:
+- 比如x.size(0)等于E,weight.size(0)等于E/ep,则表示,x.size(0) = ep\*weight.size(0),x.size(0)是ep的整数倍;其他关系类似
+- E的取值范围为[2, 512],且E是ep的整数倍;
+- H的取值范围为:[1, 65535],当shard_type为0时,H需为tp的整数倍;
+- M/tp的取值为:[1, 65535];
+- E/ep的取值范围为:[1, 32];
+- ep、tp均仅支持2、4、8、16、32;
+- group_ep和group_tp名称不能相同;
+- C大于0,上限为算子device内存上限,当shard_type为1时,C需为tp的整数倍;
+- 不支持跨超节点,只支持超节点内。
+
+## npu_alltoall_allgather_bmm 的调用示例
+在终端调用命令如下:
+```
+python3 -m torch.distributed.launch --nproc_per_node 8 --master_addr 127.0.0.1 --master_port 29500 demo_test.py
+```
+注:master_addr和master_port参数需用户根据实际情况设置
+
+demo_test.py的示例代码如下:
+```python
+import os
+import pytest
+import torch
+import torch.distributed as dist
+from torch.distributed.distributed_c10d import _get_default_group, ReduceOp
+import torch_npu
+from mindspeed.ops.npu_all_to_all_all_gather_bmm import npu_alltoall_allgather_bmm
+
+world_size = 8
+ep_size = 4
+tp_size = 2
+def setup_ep_tp(rank, tp_size, ep_size, backend_type):
+ # 初始化EP域
+ print("device %d initialize ep group" % rank, flush=True)
+ for i in range(tp_size):
+ ep_ranks = [x + ep_size * i for x in range(ep_size)]
+ ep_group = dist.new_group(backend=backend_type, ranks=ep_ranks)
+ if rank in ep_ranks:
+ ep_group_tmp = ep_group
+ print("device %d initialize tp group" % rank, flush=True)
+ for i in range(ep_size):
+ tp_ranks = [x * ep_size + i for x in range(tp_size)]
+ tp_group = dist.new_group(backend=backend_type, ranks=tp_ranks)
+ if rank in tp_ranks:
+ tp_group_tmp = tp_group
+ return ep_group_tmp, tp_group_tmp
+
+def get_ep_tp_hcomm_info(rank, ep_size, tp_size):
+ ep_group, tp_group = setup_ep_tp(rank, tp_size, ep_size, "hccl")
+ if torch.__version__ > '2.0.1':
+ ep_hcomm_info = ep_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
+ tp_hcomm_info = tp_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
+ else:
+ ep_hcomm_info = ep_group.get_hccl_comm_name(rank)
+ tp_hcomm_info = tp_group.get_hccl_comm_name(rank)
+ return ep_hcomm_info, tp_hcomm_info
+
+if __name__ == '__main__':
+ dtype = torch.float16
+ x_shard_type = 1
+ out_y2_flag = True
+ out_y3_flag = False
+ act_type = "None"
+ transpose_weight = False
+ rank = int(os.environ["LOCAL_RANK"])
+ torch_npu.npu.set_device(rank)
+ dist.init_process_group(backend="hccl", rank=rank, world_size=world_size)
+ ep_group, tp_group = get_ep_tp_hcomm_info(rank, ep_size, tp_size)
+ print(f'current device: {torch_npu.npu.current_device()}, local rank = {rank}, hcomm_info = {ep_group}, {tp_group}')
+ E, C, H, M = 4, 1024, 1024, 8192
+ if x_shard_type == 0:
+ x_shape = (E, C, H / tp_size)
+ elif x_shard_type == 1:
+ x_shape = (E, C / tp_size, H)
+ else:
+ x_shape = (E / ep_size, tp_size * ep_size * C, M / tp_size)
+ weight_shape = (E / ep_size, H, M / tp_size)
+ if transpose_weight == True:
+ weight_shape = (E / ep_size, M / tp_size, H)
+ bias_shape = (E / ep_size, 1, M / tp_size)
+
+ x_shape = tuple(int(item) for item in x_shape)
+ weight_shape = tuple(int(item) for item in weight_shape)
+ bias_shape = tuple(int(item) for item in bias_shape)
+ x = torch.rand(x_shape)
+ weight = torch.rand(weight_shape)
+ bias = torch.rand(bias_shape)
+ x_npu = x.npu().to(dtype)
+ weight_npu = weight.npu().to(dtype)
+ if transpose_weight == True:
+ print(f'!!!!before transpose, weight_npu.size()={weight_npu.size()}')
+ weight_npu = weight_npu.transpose(1, 2)
+ print(f'!!!!after transpose, weight_npu.size()={weight_npu.size()}')
+ print(f'!!!!after transpose, weight_npu.is_contiguous()={weight_npu.is_contiguous()}')
+ bias_npu = bias.npu().to(dtype)
+ # 赋值None可以验证bias为空的场景
+ bias_npu = None
+
+ y_npu = npu_alltoall_allgather_bmm(x_npu,
+ weight_npu,
+ ep_group,
+ ep_size,
+ tp_group,
+ tp_size,
+ bias=bias_npu,
+ shard_type=x_shard_type,
+ act_type=act_type,
+ need_allgather_out=out_y2_flag,
+ need_activation_feature=out_y3_flag)
+ if rank == 0:
+ for i, y in enumerate(y_npu[0]):
+ y.cpu().numpy().tofile(f"./y_{i}.bin")
+
+```
diff --git a/model/train/yoco_moe/docs/ops/npu_apply_fused_ema_adamw.md b/model/train/yoco_moe/docs/ops/npu_apply_fused_ema_adamw.md
new file mode 100644
index 0000000000000000000000000000000000000000..9a4a329f66d070cac06af105d0ffeef508b7d9ea
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/npu_apply_fused_ema_adamw.md
@@ -0,0 +1,91 @@
+# npu_apply_fused_ema_adamw 对外接口
+
+## 接口原型
+```
+npu_apply_fused_ema_adamw(grad, var, m, v, s, step, lr, ema_decay, beta1, beta2, eps, mode, bias_correction, weight_decay)-> var, m, v, s
+```
+npu_apply_fused_ema_adamw接口用于更新fused_ema_adamw优化器中的var(模型参数), m(一阶矩动量), v(二阶矩动量), s(ema模型参数)这四个参数。
+
+```python
+# 接口内部计算逻辑示例如下
+def npu_apply_fused_ema_adamw(grad, var, m, v, s, step, lr, ema_decay,
+ beta1, beta2, eps, mode, bias_correction,
+ weight_decay):
+ beta1_correction = 1 - torch.pow(beta1, step) * bias_correction
+ beta2_correction = 1 - torch.pow(beta2, step) * bias_correction
+ grad_ = grad + weight_decay * var * (1 - mode)
+ m_ = beta1 * m + (1 - beta1) * grad_
+ v_ = beta2 * v + (1 - beta2) * grad_ * grad_
+ next_m = m_ / beta1_correction
+ next_v = v_ / beta2_correction
+ demon = torch.pow(next_v, 0.5) + eps
+ update = next_m / demon + weight_decay * var * mode
+ var_ = var - lr * update
+ s_ = ema_decay * s + (1 - ema_decay) * var_
+ return var_, m_, v_, s_
+```
+
+## 输入:
+- `grad`:必选输入,数据类型为tensor(float32),表示模型参数的梯度。接受任意shape但需保持接口调用时`grad, var, m, v, s`五个入参shape一致。
+- `var`:必选输入,数据类型为tensor(float32),表示模型参数。接受任意shape但需保持接口调用时`grad, var, m, v, s`五个入参shape一致。
+- `m`:必选输入,数据类型为tensor(float32),表示一阶矩动量。接受任意shape但需保持接口调用时`grad, var, m, v, s`五个入参shape一致。
+- `v`:必选输入,数据类型为tensor(float32),表示二阶矩动量。接受任意shape但需保持接口调用时`grad, var, m, v, s`五个入参shape一致。
+- `s`:必选输入,数据类型为tensor(float32),表示ema模型参数。接受任意shape但需保持接口调用时`grad, var, m, v, s`五个入参shape一致。
+- `step`:必选输入,数据类型为tensor(int64),shape:(1,),表示当前为第几步。
+- `lr`:可选属性,数据类型为float32,默认值:1e-3。表示学习率。
+- `ema_decay`:可选属性,数据类型为float32,默认值:0.9999。表示ema衰减超参数。
+- `beta1`:可选属性,数据类型为float32,默认值:0.9。表示一阶矩动量的衰减率。
+- `beta2`:可选属性,数据类型为float32,默认值:0.999。表示二阶矩动量的衰减率。
+- `eps`:可选属性,数据类型为float32,默认值:1e-8。表示一个极小的数。
+- `mode`:可选属性,数据类型为int,默认值:1。取1表示以adamw模式计算,取0表示以adam模式计算。
+- `bias_correction`:可选属性,数据类型为bool,默认值:True。表示是否开启偏置修正。
+- `weight_decay`:可选属性,数据类型为float32,默认值:0.0。表示模型参数的衰减率。
+
+支持的输入数据类型组合:
+
+| 参数名称 | 数据类型 |
+|-----------------|-----------------|
+| grad | tensor(float32) |
+| var | tensor(float32) |
+| m | tensor(float32) |
+| v | tensor(float32) |
+| s | tensor(float32) |
+| step | tensor(int64) |
+| lr | float32 |
+| ema_decay | float32 |
+| beta1 | float32 |
+| beta2 | float32 |
+| eps | float32 |
+| mode | int |
+| bias_correction | bool |
+| weight_decay | float32 |
+
+## 输出:
+- var:必选输出,数据类型为tensor(float32),shape和入参var一致,表示更新后的模型参数。
+- m:必选输出,数据类型为tensor(float32),shape和入参m一致,表示更新后的一阶矩动量。
+- v:必选输出,数据类型为tensor(float32),shape和入参v一致,表示更新后的二阶矩动量。
+- s:必选输出,数据类型为tensor(float32),shape和入参s一致,表示更新后的ema模型参数。
+
+***
+## 调用示例:
+
+- 输入 grad, var, m, v, s, step, lr, ema_decay, beta1, beta2, eps, mode, bias_correction, weight_decay
+- 输出 var, m, v, s
+
+```python
+import torch
+import torch_npu
+from mindspeed.ops.npu_apply_fused_ema_adamw import npu_apply_fused_ema_adamw
+
+
+grad = torch.full((10, 10), 0.5).to(torch.float32).npu()
+var = torch.full((10, 10), 0.5).to(torch.float32).npu()
+m = torch.full((10, 10), 0.9).to(torch.float32).npu()
+v = torch.full((10, 10), 0.9).to(torch.float32).npu()
+s = torch.full((10, 10), 0.5).to(torch.float32).npu()
+step = torch.full((1, ), 1).to(torch.int64).npu()
+lr, ema_decay, beta1, beta2, eps, mode, bias_correction, weight_decay= 1e-8, 0.9999, 0.9999, 0.9999, 1e-8, 1, True, 0.001
+var, m, v, s = npu_apply_fused_ema_adamw(grad, var, m, v, s, step, lr, ema_decay, beta1, beta2, eps, mode,
+ bias_correction, weight_decay)
+
+```
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/ops/npu_bmm_reduce_scatter_all_to_all.md b/model/train/yoco_moe/docs/ops/npu_bmm_reduce_scatter_all_to_all.md
new file mode 100644
index 0000000000000000000000000000000000000000..8c4e3b6d6a358ce06bed8a448188366b46f2b440
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/npu_bmm_reduce_scatter_all_to_all.md
@@ -0,0 +1,187 @@
+# npu_bmm_reducescatter_alltoall对外接口
+```
+def npu_bmm_reducescatter_alltoall(x: Tensor,
+ weight: Tensor,
+ group_ep: str,
+ group_ep_worldsize: int,
+ group_tp: str,
+ group_tp_worldsize: int,
+ *,
+ bias: Optional[Tensor] = None,
+ shard_type: Optional[int] = 0) -> Tensor:
+```
+
+计算逻辑:
+BatchMatMulReduceScatterAllToAll是实现BatchMatMul计算与ReduceScatter、AllToAll集合通信并行的算子。
+大体计算流程为:BatchMatMul计算-->转置(shard_type等于0时需要)-->ReduceScatter集合通信-->Add-->AllToAll集合通信
+
+计算逻辑如下,其中out为最终输出,x weight bias为输入
+$$
+ bmmOut = BatchMatMul(x,weight)
+$$
+$$
+ reduceScatterOut = ReduceScatter(bmmOut)
+$$
+$$
+ addOut = Add(reduceScatterOut, bias)
+$$
+$$
+ out = AllToAll(addOut)
+$$
+
+## 输入输出及属性说明:
+输入:
+- x:必选输入,Tensor,数据类型float16,bfloat16,必须为3维。BatchMatMul计算的左矩阵。
+- weight:必选输入,Tensor,数据类型float16, bfloat16,必须为3维,类型与x保持一致。BatchMatMul计算的右矩阵。
+- bias:可选输入,Tensor,数据类型float16, float32。x为float16时,bias需为float16;x为bfloat16时,bias需为float32。支持两维或三维。BatchMatMul计算的bias。(由于要进行ReduceScatter通信,因此需要在通信之后再Add)。
+
+输出:
+- out:Tensor,数据类型float16, bfloat16,必须为3维。最终计算结果,类型与输入x保持一致。
+
+属性:
+- group_ep:必选属性,str。ep通信域名称,专家并行的通信域。
+- group_ep_worldsize:必选属性,int。ep通信域size,支持2/4/8/16/32。
+- group_tp:必选属性,str。tp通信域名称,Tensor并行的通信域。
+- group_tp_worldsize:必选属性,int。tp通信域size,支持2/4/8/16/32。
+- shard_type:可选属性,int,默认值为0。0表示输出在H维度按tp分片,1表示输出在C维度按tp分片。
+
+
+## 输入限制
+因为集合通信及BatchMatMul计算所需,输入输出shape需满足以下数学关系:(其中ep=group_ep_worldsize,tp=group_tp_worldsize)
+
+按H轴进行ReduceScatter场景,即shard_type为0场景:
+- x: (E/ep, ep\*C, M/tp)
+- weight:(E/ep, M/tp, H)
+- bias:(E/ep, 1, H/tp) 两维时为(E/ep, H/tp)
+- out:(E, C, H/tp)
+
+按C轴进行ReduceScatter场景,即shard_type为1场景:
+- x: (E/ep, ep\*tp\*C/tp, M/tp)
+- weight:(E/ep, M/tp, H)
+- bias:(E/ep, 1, H) 两维时为(E/ep, H)
+- out:(E, C/tp, H)
+
+数据关系说明:
+- 比如x.size(0)等于E/tp,out.size(0)等于E,则表示,out.size(0) = ep\*x.size(0),out.size(0)是ep的整数倍;其他关系类似
+- E的取值范围为[2, 512],且E是ep的整数倍;
+- H的取值范围为:[1, 65535],当shard_type为0时,H需为tp的整数倍;
+- M/tp的取值范围为:[1, 65535];
+- E/ep的取值范围为:[1, 32];
+- ep、tp均仅支持2、4、8、16、32;
+- group_ep和group_tp名称不能相同;
+- C大于0,上限为算子device内存上限,当shard_type为1时,C需为tp的整数倍;
+- 不支持跨超节点,只支持超节点内。
+
+## npu_bmm_reducescatter_alltoall 类的调用示例(待验证)
+在终端调用命令如下:
+```
+python3 -m torch.distributed.launch --nproc_per_node 8 --master_addr 127.0.0.1 --master_port 29500 demo_test.py
+```
+注:master_addr和master_port参数需用户根据实际情况设置,8表示ep_size*tp_size,按实际修改
+
+demo_test.py的示例代码如下:
+```python
+import os
+import pytest
+import torch
+import torch.distributed as dist
+from torch.distributed.distributed_c10d import _get_default_group, ReduceOp
+import torch_npu
+from mindspeed.ops.npu_bmm_reduce_scatter_all_to_all import npu_bmm_reducescatter_alltoall
+
+world_size = 8
+ep_size = 4
+tp_size = 2
+def get_hcomm_info(n, i):
+ default_pg = _get_default_group()
+ if torch.__version__ > '2.0.1':
+ hcomm_info = default_pg._get_backend(torch.device('npu')).get_hccl_comm_name(i)
+ else:
+ hcomm_info = default_pg.get_hccl_comm_name(i)
+ return hcomm_info
+
+def setup_ep_tp(rank, tp_size, ep_size, backend_type):
+ # 初始化EP域
+ print("device %d initialize ep group" % rank, flush=True)
+ for i in range(tp_size):
+ ep_ranks = [x + ep_size * i for x in range(ep_size)]
+ ep_group = dist.new_group(backend=backend_type, ranks=ep_ranks)
+ if rank in ep_ranks:
+ ep_group_tmp = ep_group
+ print("device %d initialize tp group" % rank, flush=True)
+ for i in range(ep_size):
+ tp_ranks = [x * ep_size + i for x in range(tp_size)]
+ tp_group = dist.new_group(backend=backend_type, ranks=tp_ranks)
+ if rank in tp_ranks:
+ tp_group_tmp = tp_group
+ return ep_group_tmp, tp_group_tmp
+
+def get_ep_tp_hcomm_info(rank, ep_size, tp_size):
+ ep_group, tp_group = setup_ep_tp(rank, tp_size, ep_size, "hccl")
+ if torch.__version__ > '2.0.1':
+ ep_hcomm_info = ep_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
+ tp_hcomm_info = tp_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
+ else:
+ ep_hcomm_info = ep_group.get_hccl_comm_name(rank)
+ tp_hcomm_info = tp_group.get_hccl_comm_name(rank)
+ return ep_hcomm_info, tp_hcomm_info
+
+def test_npu_bmm_reducescatter_alltoall(dtype, y_shard_type, transpose_weight):
+ rank = int(os.environ["LOCAL_RANK"])
+ torch_npu.npu.set_device(rank)
+ dist.init_process_group(backend="hccl", rank=rank, world_size=world_size)
+ ep_group, tp_group = get_ep_tp_hcomm_info(rank, ep_size, tp_size)
+ hcomm_info = get_hcomm_info(world_size, rank)
+ print(f'current device: {torch_npu.npu.current_device()}, local rank = {rank}, hcomm_info = {ep_group}, {tp_group}')
+ E, C, H, M = 4, 1024, 1024, 8192
+ if y_shard_type == 0:
+ x_shape = (E / ep_size, ep_size * C, M / tp_size)
+ bias_shape = (E / ep_size, 1, H / tp_size)
+ else:
+ x_shape = (E / ep_size, tp_size * ep_size * C, M / tp_size)
+ bias_shape = (E / ep_size, 1, H)
+ weight_shape = (E / ep_size, M / tp_size, H)
+ if transpose_weight == True:
+ weight_shape = (E / ep_size, H, M / tp_size)
+
+ x_shape = tuple(int(item) for item in x_shape)
+ weight_shape = tuple(int(item) for item in weight_shape)
+ bias_shape = tuple(int(item) for item in bias_shape)
+ x = torch.rand(x_shape)
+ weight = torch.rand(weight_shape)
+ bias = torch.rand(bias_shape)
+ x_npu = x.npu().to(dtype)
+ weight_npu = weight.npu().to(dtype)
+ if transpose_weight == True:
+ print(f'!!!!before transpose, weight_npu.size()={weight_npu.size()}')
+ weight_npu = weight_npu.transpose(1, 2)
+ print(f'!!!!after transpose, weight_npu.size()={weight_npu.size()}')
+ print(f'!!!!after transpose, weight_npu.is_contiguous()={weight_npu.is_contiguous()}')
+ bias_npu = bias.npu().to(dtype)
+
+ y = npu_bmm_reducescatter_alltoall(x_npu,
+ weight_npu,
+ ep_group,
+ ep_size,
+ tp_group,
+ tp_size,
+ bias=bias_npu,
+ shard_type=y_shard_type)
+ print(f'y_shape = {y.size()}')
+ if y_shard_type == 0:
+ assert y.size() == (E, C, int(H / tp_size))
+ else:
+ assert y.size() == (E, C, H)
+ return y
+
+if __name__ == '__main__':
+ dtype = torch.float16
+ shard_type = 1
+ transpose_weight = False
+ y_npu = test_npu_bmm_reducescatter_alltoall(dtype, shard_type, transpose_weight)
+ rank = int(os.environ["LOCAL_RANK"])
+ if rank == 0:
+ for i, y in enumerate(y_npu):
+ y.cpu().numpy().tofile(f"./y_{i}.bin")
+
+```
diff --git a/model/train/yoco_moe/docs/ops/npu_dropout_add_layer_norm.md b/model/train/yoco_moe/docs/ops/npu_dropout_add_layer_norm.md
new file mode 100644
index 0000000000000000000000000000000000000000..5685b546a9ee679d2ec4e4d2c51ee4969df21f23
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/npu_dropout_add_layer_norm.md
@@ -0,0 +1,127 @@
+# npu_dropout_add_layer_norm 对外接口
+```
+# 计算逻辑
+# norm_result = LayerNorm(Dropout(x0 x rowscale x layerscale) + residual)
+def npu_dropout_add_layer_norm(x0,
+ residual,
+ weight,
+ bias,
+ dropout_p,
+ epsilon,
+ rowscale=None,
+ layerscale=None,
+ prenorm=False,
+ residual_in_fp32=False,
+ return_dropout_mask=False)
+
+# 计算逻辑
+# norm_result = RmsNorm(Dropout(x0 x rowscale x layerscale) + residual)
+def npu_dropout_add_rms_norm(x0,
+ residual,
+ weight,
+ bias,
+ dropout_p,
+ epsilon,
+ rowscale=None,
+ layerscale=None,
+ prenorm=False,
+ residual_in_fp32=False,
+ return_dropout_mask=False)
+```
+
+输入:
+- x0:必选输入,shape:(B,S,H)。
+- residual:必选输入,shape:(B,S,H),可输入None。表示残差。
+- weight:必选输入,shape:(H,)。表示归一化处理时的权重参数。
+- bias:必选输入,shape:(H,),数据类型与输入weight一致,可输入None。表示归一化处理时的偏置参数。
+- dropout_p:必选属性,数据类型float。表示Dropout舍弃概率,eval模式下p=0。
+- epsilon:必选属性,数据类型float。归一化处理时,添加到分母中的值,以提高数值稳定性。
+- rowscale:可选输入,shape:(B,S),数据类型与输入x0一致,默认值:None。表示矩阵按行缩放比例。
+- layerscale:可选输入,shape:(H,),数据类型与输入x0一致,默认值:None。表示矩阵按列缩放比例。
+- prenorm:可选属性,数据类型为bool,默认值:False。表示是否返回输出pre_norm_result。
+- residual_in_fp32:可选属性,数据类型为bool,默认值:False。仅在输入residual不为None时有意义。
+- return_dropout_mask:可选属性,数据类型为bool,默认值:False。表示是否返回输出drop_mask。
+
+支持的输入数据类型组合:
+
+| x0 | residual | weight | rowscale | layerscale |
+| ----- | ----- | ----- | ----- | ----- |
+|fp32|fp32|fp32|fp32|fp32 |
+|fp16|fp16|fp16|fp16|fp16 |
+|bf16|bf16|bf16|bf16|bf16 |
+
+
+输出:
+- norm_result:必选输出,数据类型与输入x0一致。
+- pre_norm_result:可选输出,数据类型与输入residual一致。
+- mask_result:可选输出,数据类型为bool。
+
+***
+## 案例一:npu_dropout_add_layer_norm 接口的调用方式
+
+- 输入 x0 和 weight
+- 结果只返回 norm_result
+
+```python
+import torch
+import torch_npu
+
+from mindspeed.ops.dropout_add_layer_norm import npu_dropout_add_layer_norm
+
+
+batch, seq, hidden_size = 6, 60, 1024
+x0 = torch.randn((batch, seq, hidden_size), requires_grad=True).to(torch.float).npu()
+weight = torch.randn((hidden_size, ), requires_grad=True).to(torch.float).npu()
+residual, bias, rowscale, layerscale = None, None, None, None
+dropout_p = 0.0
+epsilon = 1e-5
+prenorm, residual_in_fp32, return_dropout_mask = False, True, False
+
+# 正向接口案例
+norm_result = npu_dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon,
+ rowscale, layerscale, prenorm, residual_in_fp32, return_dropout_mask)
+
+g = torch.randn_like(norm_result)
+norm_result.backward(g)
+x0_grad = x0.grad
+weight_grad = weight.grad
+
+```
+
+## 案例二:npu_dropout_add_layer_norm 接口的调用方式
+- 输入 x0, residual, weight, rowscale, layerscale
+- 结果返回 norm_result, pre_norm_result, mask_result
+
+```python
+import torch
+import torch_npu
+
+from mindspeed.ops.dropout_add_layer_norm import npu_dropout_add_layer_norm
+
+
+batch, seq, hidden_size = 6, 60, 1024
+x0 = torch.randn((batch, seq, hidden_size), requires_grad=True).to(torch.float).npu()
+residual = torch.randn((batch, seq, hidden_size), requires_grad=True).to(torch.float).npu()
+weight = torch.randn((hidden_size, ), requires_grad=True).to(torch.float).npu()
+bias = torch.randn((hidden_size, ), requires_grad=True).to(torch.float).npu()
+rowscale = torch.randn((batch, seq, ), requires_grad=True).to(torch.float).npu()
+layerscale = torch.randn((hidden_size, ), requires_grad=True).to(torch.float).npu()
+dropout_p = 0.0
+epsilon = 1e-5
+prenorm, residual_in_fp32, return_dropout_mask = True, True, True
+
+# 正向接口案例
+norm_result, pre_norm_result, mask_result = npu_dropout_add_layer_norm(x0, residual, weight,
+ bias, dropout_p, epsilon,
+ rowscale, layerscale, prenorm,
+ residual_in_fp32, return_dropout_mask)
+
+g = torch.randn_like(norm_result)
+norm_result.backward(g)
+x0_grad = x0.grad
+residual_grad = residual.grad
+weight_grad = weight.grad
+bias_grad = bias.grad
+rowscale_grad = rowscale.grad
+layerscale_grad = layerscale.grad
+```
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/ops/npu_fused_moe_token_permute.md b/model/train/yoco_moe/docs/ops/npu_fused_moe_token_permute.md
new file mode 100644
index 0000000000000000000000000000000000000000..9ca4166a9db4617b5a734d1edad04578297ebf0e
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/npu_fused_moe_token_permute.md
@@ -0,0 +1,121 @@
+# npu_moe_token_permute对外接口
+
+npu_moe_token_permute(
+ tokens: torch.Tensor,
+ indices: torch.Tensor,
+ num_out_tokens: int = None,
+ padded_mode: bool = False
+)
+
+小算子等价计算逻辑:
+```python
+import torch
+
+def permute_with_padded_tokens(tokens, indices):
+ """Permute the tokens based on the indices, only used in padding mode.
+ The input indices shape is [num_expert, capacity], it indicates which tokens were selected by each expert separately.
+ Args:
+ tokens (torch.Tensor): The input token tensor.
+ indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected tokens for each expert.
+
+ Returns:
+ torch.Tensor: The permuted tensor.
+ torch.Tensor: The sorted_indices corresponding permuted tensor.
+ """
+ permuted_tokens = tokens.index_select(dim=0, index=indices.view(-1))
+
+ return permuted_tokens, indices
+
+
+def permute(tokens, indices, num_out_tokens: int = None, padded_mode: bool = False):
+ """Permute the tokens based on the indices. Token with the same index will be grouped together.
+ The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately.
+ Args:
+ tokens (torch.Tensor): The input token tensor.
+ indices (torch.Tensor): The token to expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk].
+ num_out_tokens (int, optional): The effective output token count, when enabling the capacity factor, should equal the number of tokens not dropped. By default, set to None, meaning no tokens are dropped.
+ padded_mode (bool, optional): If True, indicating the indices are padded to [num_expert, capacity] to denote selected tokens per expert. Defaults to False.
+
+ Returns:
+ torch.Tensor: The permuted tensor.
+ torch.Tensor: The sorted_indices corresponding permuted tensor.
+ """
+ if padded_mode:
+ return permute_with_padded_tokens(tokens, indices)
+
+ if indices.dim() == 1:
+ topk = 1
+ else:
+ topk = indices.size(1)
+ flatten_indices = indices.view(-1)
+ sorted_indices = torch.argsort(flatten_indices, stable=True)
+ sorted_indices1 = torch.argsort(sorted_indices, stable=True)
+
+ if num_out_tokens is not None:
+ sorted_indices = sorted_indices[:num_out_tokens]
+ permuted_tokens = tokens.index_select(0, sorted_indices // topk)
+ return permuted_tokens, sorted_indices1
+```
+
+## 前向接口:
+
+输入:
+
+- tokens:必选输入,2维Tensor,数据类型bfloat16(当前版本tokens仅支持bfloat16)
+- indices: 必选输入,2维Tensor,数据类型int64
+
+输出:
+
+- permuted_tokens:必选输出,2维Tensor,数据类型bfloat16(当前版本permuted_tokens仅支持bfloat16)
+- sorted_indices:必选输出,1维Tensor,数据类型int32(当前版本sorted_indices仅支持int32)
+
+属性:
+
+- num_out_tokens:可选属性,数据类型int64_t,表示有效输出token数
+- padded_mode: 可选属性,数据类型int64_t,如果为 True,则表示索引被填充到 [num_expert,capacity] 以表示每个专家选择的token
+
+
+## 反向接口:
+
+输入:
+
+- grad_permuted_tokens:必选输入,2维Tensor,数据类型bfloat16(当前版本grad_permuted_tokens仅支持bfloat16)
+- sorted_indices:必选输入,2维Tensor,数据类型int32(当前版本sorted_indices1仅支持int32)
+
+输出:
+
+- grad_tokens:必选输出,2维Tensor,数据类型bfloat16(当前版本grad_tokens仅支持bfloat16)
+
+属性:
+
+- num_topK:必选属性,数据类型int64_t,表示每条token输出的专家个数
+- padded_mode:可选属性,数据类型int64_t,表示有效输出token数
+
+
+**备注**:
+1. 目前仅支持padded_mode为False
+2. 目前仅支持bfloat16
+
+
+
+## 案例
+
+```python
+import torch
+import torch_npu
+
+from mindspeed.ops.npu_moe_token_permute import npu_moe_token_permute
+
+dtype = torch.bfloat16
+tokens = torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3], [0, 0, 0]]).npu().to(dtype).requires_grad_(True)
+indices = torch.tensor([[0, 4], [4, 3], [4, 2], [1, 1]]).npu()
+num_out_tokens = indices.numel()
+probs = torch.ones_like(indices) / 2
+probs = probs.npu().to(dtype)
+# 正向接口案例
+permuted_tokens, sorted_indices = npu_moe_token_permute(tokens, indices, num_out_tokens)
+
+# 反向接口案例
+permuted_tokens.backward(torch.ones(permuted_tokens.shape).to(torch.bfloat16).npu())
+
+```
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/ops/npu_fused_moe_token_unpermute.md b/model/train/yoco_moe/docs/ops/npu_fused_moe_token_unpermute.md
new file mode 100644
index 0000000000000000000000000000000000000000..dfb1f7cb7aa496ec0a99af7fe59727c38173afa7
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/npu_fused_moe_token_unpermute.md
@@ -0,0 +1,179 @@
+# npu_moe_token_unpermute对外接口
+
+npu_moe_token_unpermute(
+ permuted_tokens: torch.Tensor,
+ sorted_indices: torch.Tensor,
+ probs: torch.Tensor = None,
+ padded_mode: bool = False,
+ restore_shape: torch.Size = None,
+)
+
+小算子等价计算逻辑:
+```python
+import torch
+
+def unpermute_with_padded_tokens(
+ permuted_tokens: torch.Tensor,
+ indices: torch.Tensor,
+ probs: torch.Tensor,
+ restore_shape: torch.Size,
+) -> torch.Tensor:
+ """
+ Unpermutes a padded permuted tokens based on sorted indices and merges the tokens with their corresponding probabilities.
+
+ This function takes a tensor of permuted tokens and reorders them according to the provided indices. It also combines the tokens with their associated probabilities.
+
+ Parameters:
+ permuted_tokens (torch.Tensor): A 2D tensor containing permuted tokens.
+ indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected tokens for each expert.
+ probs (torch.Tensor): A tensor with the same shape as indices, containing probabilities corresponding to each token.
+ restore_shape (torch.Size): The target shape for the unpermuted tokens tensor.
+
+ Returns:
+ torch.Tensor: A tensor of unpermuted tokens, merged with their probabilities.
+
+ """
+ # Ensure permuted_tokens is 2D
+ assert permuted_tokens.dim() == 2, f"Got {permuted_tokens.dim()}D."
+
+ # Reshape and expand probabilities and indices to match permuted_tokens
+ probs = probs.view(-1).unsqueeze(-1)
+ indices = indices.view(-1, 1).expand(-1, permuted_tokens.shape[1])
+ assert (
+ permuted_tokens.shape == indices.shape
+ ), "Shape mismatch between permuted_tokens and indices."
+
+ # Combine tokens with their probabilities
+ combined_output = probs * permuted_tokens
+
+ # Prepare a tensor of zeros with the desired output shape
+ empty_tokens = torch.zeros(
+ restore_shape,
+ dtype=combined_output.dtype,
+ device=combined_output.device,
+ requires_grad=True,
+ )
+
+ # Scatter the combined tokens back to their original positions
+ unpermuted_tokens = torch.scatter_add(empty_tokens, 0, indices, combined_output)
+
+ return unpermuted_tokens
+
+def unpermute(
+ permuted_tokens: torch.Tensor,
+ sorted_indices: torch.Tensor,
+ probs: torch.Tensor = None,
+ padded_mode: bool = False,
+ restore_shape: torch.Size = None,
+):
+ """Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their corresponding probabilities.
+
+ Args:
+ permuted_tokens (torch.Tensor): The tensor of permuted tokens to be unpermuted.
+ sorted_indices (torch.Tensor): The tensor of sorted indices used to unpermute the tokens.
+ probs (torch.Tensor, optional): The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities.
+ padded_mode (bool, optional): If True, indicating the indices are padded to [num_expert, capacity] to denote selected tokens per expert. Defaults to False.
+ restore_shape (torch.Size, optional): The input shape before permutation, only used in padding mode. Defaults to None.
+
+ Returns:
+ torch.Tensor: The unpermuted tokens, optionally merged with probabilities.
+ """
+ if padded_mode:
+ return unpermute_with_padded_tokens(
+ permuted_tokens, sorted_indices, probs, restore_shape=restore_shape
+ )
+
+ assert sorted_indices.numel() == permuted_tokens.size(0)
+ if probs is not None:
+ # Unpermute and merge the tokens with their probabilities
+ num_unpermuted_tokens = probs.numel()
+ topk = probs.size(1)
+ else:
+ # Unpermute the tokens without merge
+ num_unpermuted_tokens = permuted_tokens.size(0)
+ topk = 1
+
+ unpermuted_tokens = torch.zeros(
+ [num_unpermuted_tokens, permuted_tokens.shape[-1]],
+ dtype=permuted_tokens.dtype,
+ device=permuted_tokens.device,
+ )
+ unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)
+ unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
+ if probs is not None:
+ unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
+ unpermuted_tokens = unpermuted_tokens.sum(dim=1)
+
+ return unpermuted_tokens
+```
+
+## 前向接口:
+
+输入:
+
+- permuted_tokens:必选输入,2维Tensor,数据类型bfloat16(当前版本permuted_tokens仅支持bfloat16)
+- sorted_indices: 必选输入,1维Tensor,数据类型int32(当前版本sorted_indices仅支持int32)
+- probs:可选输入,2维Tensor,数据类型bfloat16(当前版本probs仅支持bfloat16)
+
+输出:
+
+- unpermuted_tokens:必选输出,2维Tensor,数据类型bfloat16(当前版本unpermuted_tokens仅支持bfloat16)
+
+属性:
+
+- padded_mode: 可选属性,数据类型int64_t,如果为 True,则表示索引被填充到 [num_expert,capacity] 以表示每个专家选择的token
+
+
+## 反向接口:
+
+输入:
+
+- permuted_tokens:必选输入,2维Tensor,数据类型bfloat16(当前版本permuted_tokens仅支持bfloat16)
+- grad_unpermuted_tokens:必选输入,2维Tensor,数据类型bfloat16(当前版本grad_permuted_tokens仅支持bfloat16)
+- sorted_indices: 必选输入,1维Tensor,数据类型int32(当前版本sorted_indices仅支持int32)
+- probs:可选输入,2维Tensor,数据类型bfloat16(当前版本probs仅支持bfloat16)
+
+输出:
+
+- grad_permuted_tokens:必选输出,2维Tensor,数据类型bfloat16(当前版本grad_permuted_tokens仅支持bfloat16)
+- grad_probs:必选输出,2维Tensor,数据类型bfloat16(当前版本grad_probs仅支持bfloat16)
+
+属性:
+
+- padded_mode:可选属性,数据类型int64_t,表示有效输出token数
+
+
+**备注**:
+1. 目前仅支持padded_mode为False
+2. 目前仅支持bfloat16
+
+
+## 案例
+
+```python
+import torch
+import torch_npu
+
+from mindspeed.ops.npu_moe_token_unpermute import npu_moe_token_unpermute
+
+dtype = torch.bfloat16
+permuted_tokens = torch.tensor([[1., 1., 1.],
+ [0., 0., 0.],
+ [0., 0., 0.],
+ [3., 3., 3.],
+ [2., 2., 2.],
+ [1., 1., 1.],
+ [2., 2., 2.],
+ [3., 3., 3.]]).npu().to(dtype).requires_grad_(True)
+sorted_indices = torch.tensor([0, 6, 7, 5, 3, 1, 2, 4], dtype=torch.int32).npu()
+indices = torch.tensor([[0, 4], [4, 3], [4, 2], [1, 1]]).npu()
+probs = torch.ones_like(indices) / 2
+probs = probs.npu().to(dtype).requires_grad_(True)
+
+# 正向接口案例
+unpermuted_tokens = npu_moe_token_unpermute(
+ permuted_tokens, sorted_indices, probs=probs)
+
+# 反向接口案例
+unpermuted_tokens.backward(torch.ones(unpermuted_tokens.shape).to(torch.bfloat16).npu())
+```
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/ops/npu_grouped_mat_mul_all_reduce.md b/model/train/yoco_moe/docs/ops/npu_grouped_mat_mul_all_reduce.md
new file mode 100644
index 0000000000000000000000000000000000000000..d4c522ab6df3e1e96bb713684b33e0463a3f24c0
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/npu_grouped_mat_mul_all_reduce.md
@@ -0,0 +1,135 @@
+# npu_grouped_mat_mul_all_reduce对外接口
+```
+def npu_grouped_mat_mul_all_reduce(x: List[torch.Tensor],
+ weight: List[torch.Tensor],
+ hcomm: str,
+ *,
+ bias: Optional[List[torch.Tensor]] = None,
+ group_list: Optional[List[int]] = None,
+ split_item: Optional[int] = 0,
+ reduce_op: str = "sum",
+ comm_turn: int = 0) -> List[torch.Tensor]
+```
+
+计算逻辑:
+GroupedMatMulAllReduce算子是GroupedMatmul算子的多卡通信版本。它可以实现分组矩阵计算,每组矩阵乘的维度大小可以不同,是一种灵活的组合方式。输入数据x和输出数据y均支持切分或不切分模式,可以根据参数split_item来确定是否切分。当x需要切分时,使用group_list参数来描述x的m轴切分配置。本算子增加了AllReduce集合通信操作,可以把矩阵乘任务切分到多张卡上并行计算,然后通过AllReduce集合通信操作把所有卡的计算结果加和到一起,最终完成整个任务。根据输入x、weight和输出y的Tensor数量,本算子可以支持如下四种场景:
+- x、weight、y的tensor数量均等于组数group_num,即每组的数据对应的tensor是独立的。
+- x的tensor数量为1, weight和y的tensor数量等于组数,此时需要通过group_list来说明x在m轴方向上的分组情况。如group_list[0]=10说明x矩阵的前10行参与第一组矩阵乘计算。
+- x、weight的tensor数量均等于组数group_num, y的tensor数量为1,此时每组矩阵乘的结果放在同一个输出tensor中连续存放。
+- x、y的tensor数量均为1,weight的tensor数量等于组数,属于前两种情况的组合。
+
+计算公式为:
+对于每一组矩阵乘任务i: $$y_i = x_i * weight_i + bias_i$$
+切分到n张卡上后,计算形式可表示为:
+
+$$
+y_i = [x_{i1}, x_{i2}, ..., x_{in}] *
+\begin{bmatrix}
+{weight_{i1}} \\
+{weight_{i2}} \\
+{...} \\
+{weight_{in}}
+\end{bmatrix}+\sum^{n}{bias_i/n}
+$$
+
+## 前向接口:
+输入:
+- x:必选输入,List[Tensor],数据类型float16,bfloat16。支持的最大长度为64个。
+- weight:必选输入,List[Tensor],数据类型float16, bfloat16。支持的最大长度为64个。
+- bias:可选输入,List[Tensor],数据类型float16, float32。支持的最大长度为64个。对于实际无bias的场景,可以直接不指定bias参数或设置为None。
+- group_list:可选输入,Optional[List[int64]],缺省none。代表输入和输出M方向的matmul大小分布,支持的最大长度为64个。
+
+输出:
+- y:List[Tensor],数据类型float16, bfloat16。支持的最大长度为64个。
+
+属性:
+- split_item:可选属性,int64。代表输入和输出是否要做tensor切分,0代表输入和输出都不用切分;1代表输入需要切分,输出不需要切分;2代表输入不需要切分,输出需要切分;3代表输入和输出都需要切分。缺省0。
+- hcomm:必选属性,数据类型支持:string。表示通信域名称,Host侧标识列组的字符串。通过Hccl提供的接口获取。
+- reduce_op:可选属性,数据类型支持:string。reduce操作类型。**当前版本仅支持输入"sum"。**
+- comm_turn:可选属性,int64。Host侧的整型,通信数据切分数,即总数据量/单次通信量。**当前版本仅支持输入0。**
+
+
+## 反向接口
+无
+
+## 输入限制
+- List支持的最大长度为64;
+- 属性reduce_op仅支持输入"sum";
+- 属性comm_turn仅支持输入0;
+- 通信支持2、4、8卡。
+- 当split_item取0或2时,x的tensor数量与weight的相同;当split_item取1或3时,x的tensor数量为1。
+- 当split_item取0或2时,group_list为空;当split_item取1或3时,group_list的长度与weight的tensor个数相同。
+- 若bias不是空,其tensor数量要与weight的相同。
+- 满足矩阵乘的m/k/n维度关系。
+- 支持输入元素类型:
+ 1. x中元素类型为float16,weight中元素类型为float16,bias中元素类型为float16,输出y中元素类型为float16;
+
+ 2. x中元素类型为bfloat16,weight中元素类型为bfloat16,bias中元素类型为float32,输出y中元素类型为bfloat16;
+- 输入中须包含通信域hcomm字符串,它需要从torch.distributed包中的接口获取。
+- 当前只支持Pytorch2.1版本。
+
+## npu_grouped_mat_mul_all_reduce 类的调用方式
+
+```python
+import os
+import torch
+import torch_npu
+import torch.multiprocessing as mp
+import torch.distributed as dist
+import numpy as np
+import math
+from mindspeed.ops.npu_grouped_mat_mul_all_reduce import npu_grouped_mat_mul_all_reduce
+
+
+def get_hcomm_info(world_size, rank):
+ dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method='tcp://127.0.0.1:5001')
+ print(f'device_{rank} init_process_group success.')
+ if dist.is_available():
+ from torch.distributed.distributed_c10d import _get_default_group
+ default_pg = _get_default_group()
+ if torch.__version__ > '2.0.1':
+ hcomm_info = default_pg._get_backend(torch.device('npu')).get_hccl_comm_name(rank)
+ else:
+ hcomm_info = default_pg.get_hccl_comm_name(rank)
+ print(f'device_{rank} get_hccl_comm_name success.')
+ return dist, hcomm_info
+
+
+def run(rank, n_device):
+ torch_npu.npu.set_device(rank)
+ dist, hcomm_info = get_hcomm_info(n_device, rank)
+
+ x1 = torch.randn(32, 256, dtype=torch.float16)
+ x1_npu = x1.npu()
+ w1 = torch.randn(256, 32, dtype=torch.float16)
+ w1_npu = w1.npu()
+
+ x2 = torch.randn(64, 128, dtype=torch.float16)
+ x2_npu = x2.npu()
+ w2 = torch.randn(128, 64, dtype=torch.float16)
+ w2_npu = w2.npu()
+
+ group_list = None
+ split_item = 0
+ hccl_group = hcomm_info
+ reduce_op = "sum"
+ comm_turn = 0
+
+ # 正向接口案例
+ x1_npu.requires_grad = False
+ x2_npu.requires_grad = False
+ w1_npu.requires_grad = False
+ w2_npu.requires_grad = False
+ y_npu = npu_grouped_mat_mul_all_reduce([x1_npu, x2_npu], [w1_npu, w2_npu], hccl_group, bias=[],
+ group_list=group_list, split_item=split_item, reduce_op=reduce_op,
+ comm_turn=comm_turn)
+ if rank == 0:
+ for i, y in enumerate(y_npu):
+ y.cpu().numpy().tofile(f"./y_{i}.bin")
+
+
+if __name__ == '__main__':
+ n_device = 2
+ mp.spawn(run, args=(n_device,), nprocs=n_device)
+
+```
diff --git a/model/train/yoco_moe/docs/ops/npu_groupmatmul_add.md b/model/train/yoco_moe/docs/ops/npu_groupmatmul_add.md
new file mode 100644
index 0000000000000000000000000000000000000000..d45cabcfa984b37c4892d41fd6497481ad2f030a
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/npu_groupmatmul_add.md
@@ -0,0 +1,35 @@
+# npu_groupmatmul_add_fp32对外接口(只支持前向)
+输入:
+- x:必选输入,维度[m,k], 数据类型float16, bf16
+- weight:必选输入,维度[k,n], 数据类型float16, bf16
+- groupNum:必选输入, 维度[NumGroup], 数据类型int64
+- output:必选输入,维度[w1,w2],数据类型float32
+
+输出:
+- output:必选输出,数据类型float32
+
+## 案例
+```python
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import pytest
+import torch
+import mindspeed.megatron_adaptor
+import torch_npu
+from mindspeed.ops.gmm import GMMFunction
+from mindspeed.ops.npu_groupmatmul_add import npu_groupmatmul_add_fp32
+
+x = torch.rand((1024, 4096), dtype=torch.float16).npu()
+grads_out = torch.rand((1024, 2048), dtype=torch.float16).npu()
+group_list = torch.tensor([500, 1024], dtype=torch.int64).npu()
+out = torch.rand((8192, 2048), dtype=torch.float32).npu()
+weight = torch.rand((2, 4096, 2048), dtype=torch.float16).npu()
+# 分开算子计算结果
+_, dy, _ = GMMFunction.builder.load().npu_gmm_backward([grads_out], [x], [weight], group_list, 0)
+out_single = out+dy[0].view(*out.shape)
+#融合算子计算结果
+x = x.clone().detach()
+grads_out = grads_out.clone().detach()
+group_list = group_list.clone().detach()
+out = out.clone().detach()
+npu_groupmatmul_add_fp32(x, grads_out, group_list, out)
+```
diff --git a/model/train/yoco_moe/docs/ops/npu_matmul_add.md b/model/train/yoco_moe/docs/ops/npu_matmul_add.md
new file mode 100644
index 0000000000000000000000000000000000000000..1824236513d1a3475552581842e6fe5c48fad97f
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/npu_matmul_add.md
@@ -0,0 +1,24 @@
+# npu_matmul_add_fp32对外接口(只支持前向)
+输入:
+- x:必选输入,数据类型float16, bf16
+- weight:必选输入,数据类型float16, bf16
+- C:必选输入,数据类型float32
+
+输出:
+- output:必选输出,数据类型float32
+
+## 案例
+```python
+ import torch
+ import torch_npu
+ from mindspeed.ops.npu_matmul_add import npu_matmul_add_fp32
+
+ x = torch.rand((4096, 8192),dtype=torch.float16).npu()
+ weight = torch.rand((4096, 8192),dtype=torch.float16).npu()
+ C = torch.rand((8192, 8192),dtype=torch.float32).npu()
+ # 分开算子计算过程
+ product = torch.mm(x.T, weight)
+ result = product + C
+ # 融合算子计算过程
+ npu_matmul_add_fp32(weight, x, C)
+```
diff --git a/model/train/yoco_moe/docs/ops/npu_mm_all_reduce_add_rms_norm.md b/model/train/yoco_moe/docs/ops/npu_mm_all_reduce_add_rms_norm.md
new file mode 100644
index 0000000000000000000000000000000000000000..935ec2e8efb560c6226817d0b5588bb90eb774b0
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/npu_mm_all_reduce_add_rms_norm.md
@@ -0,0 +1,151 @@
+# npu_mm_all_reduce_add_rms_norm对外接口
+
+CLASS MatmulAllReduceAddRmsNorm()
+
+计算逻辑:
+$$
+mmOut = allReduce(x1*x2 + bias)
+$$
+$$
+y = mmOut + residual
+$$
+$$
+normOut = \frac{y}{RMS(y)}*gamma, RMS(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} y_{i}^{2} + epsilon}
+$$
+
+## 非量化场景:
+输入:
+- x1:必选输入,数据类型float16, bfloat16
+- x2:必选输入,数据类型float16, bfloat16
+- residual:必选输入,数据类型float16, bfloat16
+- gamma:必选输入,数据类型float16, bfloat16
+- hcom:必选输入,数据类型string,
+- reduce_op:可选输入,数据类型为string,当前仅支持sum
+- epsilon:可选输入,数据类型为float,缺省情况下为1e-06
+- bias:可选输入,数据类型float16, bfloat16
+- antiquant_scale:可选输入,该场景默认为nullptr
+- antiquant_offset:可选输入,该场景默认为nullptr
+- dequant_scale:可选输入,该场景默认为nullptr
+- antiquant_group_size:可选输入,该场景默认为0
+- comm_turn:可选输入,数据类型为int,缺省情况下为0
+
+输出:
+- y:必选输出,数据类型float16, bfloat16
+- normOut:必选输出,数据类型float16, bfloat16
+
+## 全量化场景
+输入:
+- x1:必选输入,数据类型int8
+- x2:必选输入,数据类型int8
+- residual:必选输入,数据类型float16, bfloat16
+- gamma:必选输入,数据类型float16, bfloat16
+- hcom:必选输入,数据类型string,
+- reduce_op:可选输入,数据类型为string,当前仅支持sum
+- epsilon:可选输入,数据类型为float,缺省情况下为1e-06
+- bias:可选输入,数据类型int32
+- antiquant_scale:可选输入,该场景默认为nullptr
+- antiquant_offset:可选输入,该场景默认为nullptr
+- dequant_scale:可选输入,数据类型int64,uint64,bfloat16
+- antiquant_group_size:可选输入,该场景默认为0
+- comm_turn:可选输入,数据类型为int,缺省情况下为0
+
+输出:
+- y:必选输出,数据类型float16, bfloat16
+- normOut:必选输出,数据类型float16, bfloat16
+
+## 伪量化场景
+输入:
+- x1:必选输入,数据类型float16, bfloat16
+- x2:必选输入,数据类型int8
+- residual:必选输入,数据类型float16, bfloat16
+- gamma:必选输入,数据类型float16, bfloat16
+- hcom:必选输入,数据类型string,
+- reduce_op:可选输入,数据类型为string,当前仅支持sum
+- epsilon:可选输入,数据类型为float,缺省情况下为1e-06
+- bias:可选输入,数据类型float16, bfloat16
+- antiquant_scale:可选输入,数据类型float16, bfloat16
+- antiquant_offset:可选输入,数据类型float16, bfloat16
+- dequant_scale:可选输入,该场景默认为nullptr
+- antiquant_group_size:可选输入,数据类型为int,缺省情况下为0
+- comm_turn:可选输入,数据类型为int,缺省情况下为0
+
+输出:
+- y:必选输出,数据类型float16, bfloat16
+- normOut:必选输出,数据类型float16, bfloat16
+
+## 输入限制
+- ``x2`` 仅支持最后两轴转置情况下的非连续tensor传入,``x1``、``residual``、``gamma`` 等输入仅支持连续tensor
+- 仅支持ND数据格式
+- ``x1`` 支持两维或者三维,其维度为 ``(b, s, k)`` 或者 ``(s, k)``
+- ``x2`` 仅支持两维,其维度为 ``(k, n)``,``x1`` 和 ``x2`` 的轴满足matmul算子入参要求,k轴相等
+- ``bias`` 在非空情况下为1维,其维度为 ``(n)``
+- ``residual`` 仅支持三维,其维度为 ``(b, s, n)``,当 ``x1`` 为两维时,``residual`` 的 ``(b * s)`` 等于 ``x1`` 的 ``s``,当 ``x1`` 为三维时,``residual`` 的 ``(b * s)`` 等于 ``x1`` 的 ``(b * s)``;``residual`` 的最后一维与``x2`` 的最后一维相等
+- ``gamma`` 仅支持一维,其维度为 ``(n)``,``gamma`` 的最后一维与 ``residual`` 的最后一维相等
+- ``reduce_op`` 仅支持 ``sum``
+- 昇腾Atlas A2 AI处理器支持1、2、4、8卡,并且仅支持hccs链路all mesh组网
+- 昇腾Atlas A2 AI处理器支持``(b * s)``,``n``为0的空tensor,不支持``k``为0的空tensor
+- 非量化场景下,``x1``、``x2``、``bias``(若支持)、``residual``、``gamma`` 计算输入的数据类型要一致
+- 昇腾Atlas A2 AI处理器,在非量化场景下,``(b * s)``、``k``、``n``的范围为``[1, 2147483647]``
+- 全量化场景下,若输出 ``residual`` 类型为 ``FLOAT16``,``dequant_scale`` 的类型为 ``INT64``、``UINT64``(需通过 ``torch_npu.npu_trans_quant_param()`` 接口对 ``dequant_scale`` 进行处理);若输出 ``residual`` 类型为 ``BFLOAT16``,``dequant_scale`` 的类型为 ``BFLOAT16``。``dequant_scale`` 满足两种模式:
+ - ``per_tensor`` 模式:``(1,)``
+ - ``per_channel`` 模式:``(1, n)`` 或 ``(n,)``
+- 全量化场景下,``x1``、``x2`` 数据类型为 ``int8``,``bias``(若支持)数据类型为 ``int32``,``residual``、``gamma``计算输入的数据类型要一致。
+- 全量化场景下,``m``大小不超过2147483647,``x1``与``x2``的最后一维大小不超过65535,``x1``的最后一维指``k``,``x2``的最后一维指转置时的``k``或非转置时的``n``。
+- 伪量化场景下,``m``的范围为``[1, 2147483647]``,``k``、``n``的范围为``[1,65535]``
+- 伪量化场景下,``antiquant_scale`` 满足三种模式:
+ - ``per_tensor`` 模式:``(1,)``
+ - ``per_channel`` 模式:``(1, n)`` 或 ``(n,)``
+ - ``per_group`` 模式:``(ceil(k,antiquant_group_size),n)``
+- ``antiquantOffset`` 若非空,shape 与 ``antiquant_scale``一致。
+- 伪量化场景下,``x2`` 的数据类型需为 ``int8``,``x1``、``bias``(若支持)、``residual``、``gamma``、``antiquant_scale``、``antiquant_offset``计算输入的数据类型要一致。
+- 伪量化场景下,``antiquant_group_size`` 取值满足取值范围``[32, min(k-1, INT_MAX)]``且为32倍数。
+- 一个模型中的通算融合MC2算子,仅支持相同通信域。
+
+## npu_mm_all_redcue_add_rms_norm 接口的调用方式
+
+```python
+import torch
+import torch_npu
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from mindspeed.ops.npu_mm_all_reduce_add_rms_norm import npu_mm_all_reduce_add_rms_norm
+
+
+def run_mm_all_reduce_add_rms_norm(rank, world_size, master_ip, master_port, x1_shape, x2_shape, residual_shape,
+ gamma_shape, dtype):
+ torch_npu.npu.set_device(rank)
+ init_method = 'tcp://' + master_ip + ':' + master_port
+ dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=init_method)
+ from torch.distributed.distributed_c10d import _get_default_group
+ default_pg = _get_default_group()
+ if torch.__version__ > '2.0.1':
+ hcom_info = default_pg._get_backend(torch.device('npu')).get_hccl_comm_name(rank)
+ else:
+ hcom_info = default_pg.get_hccl_comm_name(rank)
+
+ x1 = torch.randn(x1_shape, dtype=dtype).npu()
+ x2 = torch.randn(x2_shape, dtype=dtype).npu()
+ residual = torch.randn(residual_shape, dtype=dtype).npu()
+ gamma = torch.randn(gamma_shape, dtype=dtype).npu()
+ epsilon = 0.000001
+ y, norm_out = npu_mm_all_reduce_add_rms_norm(x1=x1, x2=x2, residual=residual, gamma=gamma, hcom=hcom_info,
+ reduce_op='sum', epsilon=epsilon)
+ print("y:", y)
+ print("norm_out:", norm_out)
+
+
+if __name__ == "__main__":
+ worksize = 8
+ master_ip = "127.0.0.1"
+ master_port = '50001'
+ b, s, k, n = 4, 1024, 1024, 8192
+ x1_shape = (b, s, k)
+ x2_shape = (k, n)
+ residual_shape = (b, s, n)
+ gamma_shape = (n)
+ dtype = torch.float16
+
+ mp.spawn(run_mm_all_reduce_add_rms_norm,
+ args=(worksize, master_ip, master_port, x1_shape, x2_shape, residual_shape, gamma_shape, dtype),
+ nprocs=worksize)
+```
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/ops/npu_mm_all_reduce_add_rms_norm_.md b/model/train/yoco_moe/docs/ops/npu_mm_all_reduce_add_rms_norm_.md
new file mode 100644
index 0000000000000000000000000000000000000000..c6b2ea1da236a49348363a02c8f457006401f098
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/npu_mm_all_reduce_add_rms_norm_.md
@@ -0,0 +1,151 @@
+# npu_mm_all_reduce_add_rms_norm_对外接口
+
+CLASS InplaceMatmulAllReduceAddRmsNorm()
+
+计算逻辑:
+$$
+mmOut = allReduce(x1*x2 + bias)
+$$
+$$
+y = mmOut + residual
+$$
+$$
+normOut = \frac{y}{RMS(y)}*gamma, RMS(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} y_{i}^{2} + epsilon}
+$$
+
+## 非量化场景:
+输入:
+- x1:必选输入,数据类型float16, bfloat16
+- x2:必选输入,数据类型float16, bfloat16
+- residual:必选输入,数据类型float16, bfloat16
+- gamma:必选输入,数据类型float16, bfloat16
+- hcom:必选输入,数据类型string,
+- reduce_op:可选输入,数据类型为string,当前仅支持sum
+- epsilon:可选输入,数据类型为float,缺省情况下为1e-06
+- bias:可选输入,数据类型float16, bfloat16
+- antiquant_scale:可选输入,该场景默认为nullptr
+- antiquant_offset:可选输入,该场景默认为nullptr
+- dequant_scale:可选输入,该场景默认为nullptr
+- antiquant_group_size:可选输入,该场景默认为0
+- comm_turn:可选输入,数据类型为int,缺省情况下为0
+
+输出:
+- residual:必选输出,复用residual,数据类型float16, bfloat16
+- normOut:必选输出,数据类型float16, bfloat16
+
+## 全量化场景
+输入:
+- x1:必选输入,数据类型int8
+- x2:必选输入,数据类型int8
+- residual:必选输入,数据类型float16, bfloat16
+- gamma:必选输入,数据类型float16, bfloat16
+- hcom:必选输入,数据类型string,
+- reduce_op:可选输入,数据类型为string,当前仅支持sum
+- epsilon:可选输入,数据类型为float,缺省情况下为1e-06
+- bias:可选输入,数据类型int32
+- antiquant_scale:可选输入,该场景默认为nullptr
+- antiquant_offset:可选输入,该场景默认为nullptr
+- dequant_scale:可选输入,数据类型int64,uint64,bfloat16
+- antiquant_group_size:可选输入,该场景默认为0
+- comm_turn:可选输入,数据类型为int,缺省情况下为0
+
+输出:
+- residual:必选输出,复用residual,数据类型float16, bfloat16
+- normOut:必选输出,数据类型float16, bfloat16
+
+## 伪量化场景
+输入:
+- x1:必选输入,数据类型float16, bfloat16
+- x2:必选输入,数据类型int8
+- residual:必选输入,数据类型float16, bfloat16
+- gamma:必选输入,数据类型float16, bfloat16
+- hcom:必选输入,数据类型string,
+- reduce_op:可选输入,数据类型为string,当前仅支持sum
+- epsilon:可选输入,数据类型为float,缺省情况下为1e-06
+- bias:可选输入,数据类型float16, bfloat16
+- antiquant_scale:可选输入,数据类型float16, bfloat16
+- antiquant_offset:可选输入,数据类型float16, bfloat16
+- dequant_scale:可选输入,该场景默认为nullptr
+- antiquant_group_size:可选输入,数据类型为int,缺省情况下为0
+- comm_turn:可选输入,数据类型为int,缺省情况下为0
+
+输出:
+- residual:必选输出,复用residual,数据类型float16, bfloat16
+- normOut:必选输出,数据类型float16, bfloat16
+
+## 输入限制
+- ``x2`` 仅支持最后两轴转置情况下的非连续tensor传入,``x1``、``residual``、``gamma`` 等输入仅支持连续tensor
+- 仅支持ND数据格式
+- ``x1`` 支持两维或者三维,其维度为 ``(b, s, k)`` 或者 ``(s, k)``
+- ``x2`` 仅支持两维,其维度为 ``(k, n)``,``x1`` 和 ``x2`` 的轴满足matmul算子入参要求,k轴相等
+- ``bias`` 在非空情况下为1维,其维度为 ``(n)``
+- ``residual`` 仅支持三维,其维度为 ``(b, s, n)``,当 ``x1`` 为两维时,``residual`` 的 ``(b * s)`` 等于 ``x1`` 的 ``s``,当 ``x1`` 为三维时,``residual`` 的 ``(b * s)`` 等于 ``x1`` 的 ``(b * s)``;``residual`` 的最后一维与``x2`` 的最后一维相等
+- ``gamma`` 仅支持一维,其维度为 ``(n)``,``gamma`` 的最后一维与 ``residual`` 的最后一维相等
+- ``reduce_op`` 仅支持 ``sum``
+- 昇腾Atlas A2 AI处理器支持1、2、4、8卡,并且仅支持hccs链路all mesh组网
+- 昇腾Atlas A2 AI处理器支持``(b * s)``,``n``为0的空tensor,不支持``k``为0的空tensor
+- 非量化场景下,``x1``、``x2``、``bias``(若支持)、``residual``、``gamma`` 计算输入的数据类型要一致
+- 昇腾Atlas A2 AI处理器,在非量化场景下,``(b * s)``、``k``、``n``的范围为``[1, 2147483647]``
+- 全量化场景下,若输出 ``residual`` 类型为 ``FLOAT16``,``dequant_scale`` 的类型为 ``INT64``、``UINT64``(需通过 ``torch_npu.npu_trans_quant_param()`` 接口对 ``dequant_scale`` 进行处理);若输出 ``residual`` 类型为 ``BFLOAT16``,``dequant_scale`` 的类型为 ``BFLOAT16``。``dequant_scale`` 满足两种模式:
+ - ``per_tensor`` 模式:``(1,)``
+ - ``per_channel`` 模式:``(1, n)`` 或 ``(n,)``
+- 全量化场景下,``x1``、``x2`` 数据类型为 ``int8``,``bias``(若支持)数据类型为 ``int32``,``residual``、``gamma``计算输入的数据类型要一致。
+- 全量化场景下,``m``大小不超过2147483647,``x1``与``x2``的最后一维大小不超过65535,``x1``的最后一维指``k``,``x2``的最后一维指转置时的``k``或非转置时的``n``。
+- 伪量化场景下,``m``的范围为``[1, 2147483647]``,``k``、``n``的范围为``[1,65535]``
+- 伪量化场景下,``antiquant_scale`` 满足三种模式:
+ - ``per_tensor`` 模式:``(1,)``
+ - ``per_channel`` 模式:``(1, n)`` 或 ``(n,)``
+ - ``per_group`` 模式:``(ceil(k,antiquant_group_size),n)``
+- ``antiquantOffset`` 若非空,shape 与 ``antiquant_scale``一致。
+- 伪量化场景下,``x2`` 的数据类型需为 ``int8``,``x1``、``bias``(若支持)、``residual``、``gamma``、``antiquant_scale``、``antiquant_offset``计算输入的数据类型要一致。
+- 伪量化场景下,``antiquant_group_size`` 取值满足取值范围``[32, min(k-1, INT_MAX)]``且为32倍数。
+- 一个模型中的通算融合MC2算子,仅支持相同通信域。
+
+## npu_mm_all_redcue_add_rms_norm 接口的调用方式
+
+```python
+import torch
+import torch_npu
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from mindspeed.ops.npu_mm_all_reduce_add_rms_norm_ import npu_mm_all_reduce_add_rms_norm_
+
+
+def run_mm_all_reduce_add_rms_norm(rank, world_size, master_ip, master_port, x1_shape, x2_shape, residual_shape,
+ gamma_shape, dtype):
+ torch_npu.npu.set_device(rank)
+ init_method = 'tcp://' + master_ip + ':' + master_port
+ dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=init_method)
+ from torch.distributed.distributed_c10d import _get_default_group
+ default_pg = _get_default_group()
+ if torch.__version__ > '2.0.1':
+ hcom_info = default_pg._get_backend(torch.device('npu')).get_hccl_comm_name(rank)
+ else:
+ hcom_info = default_pg.get_hccl_comm_name(rank)
+
+ x1 = torch.randn(x1_shape, dtype=dtype).npu()
+ x2 = torch.randn(x2_shape, dtype=dtype).npu()
+ residual = torch.randn(residual_shape, dtype=dtype).npu()
+ gamma = torch.randn(gamma_shape, dtype=dtype).npu()
+ epsilon = 0.000001
+ y, norm_out = npu_mm_all_reduce_add_rms_norm_(x1=x1, x2=x2, residual=residual, gamma=gamma, hcom=hcom_info,
+ reduce_op='sum', epsilon=epsilon)
+ print("y:", y)
+ print("norm_out:", norm_out)
+
+
+if __name__ == "__main__":
+ worksize = 8
+ master_ip = "127.0.0.1"
+ master_port = '50001'
+ b, s, k, n = 4, 1024, 1024, 8192
+ x1_shape = (b, s, k)
+ x2_shape = (k, n)
+ residual_shape = (b, s, n)
+ gamma_shape = (n)
+ dtype = torch.float16
+
+ mp.spawn(run_mm_all_reduce_add_rms_norm,
+ args=(worksize, master_ip, master_port, x1_shape, x2_shape, residual_shape, gamma_shape, dtype),
+ nprocs=worksize)
+```
diff --git a/model/train/yoco_moe/docs/ops/npu_ring_attention_update.md b/model/train/yoco_moe/docs/ops/npu_ring_attention_update.md
new file mode 100644
index 0000000000000000000000000000000000000000..4eb9a364c1b03dfb4a1497c2f77a7a02d28289ae
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/npu_ring_attention_update.md
@@ -0,0 +1,96 @@
+# npu_ring_attention_update对外接口
+
+npu_ring_attention_update(
+ prev_attn_out: torch.Tensor,
+ prev_softmax_max: torch.Tensor,
+ prev_softmax_sum: torch.Tensor,
+ cur_attn_out: torch.Tensor,
+ cur_softmax_max: torch.Tensor,
+ cur_softmax_sum: torch.Tensor,
+ actual_seq_qlen: torch.Tensor = None,
+ layout: str = "SBH",
+)
+
+小算子等价计算逻辑:
+```python
+import torch
+
+
+def forward_update(prev_attn_out, prev_softmax_max, prev_softmax_sum,
+ cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout='SBH'):
+ # update softmax_max
+ origin_dtype = prev_attn_out.dtype
+ softmax_max = torch.maximum(prev_softmax_max, cur_softmax_max)
+ prev_scale = torch.exp(prev_softmax_max - softmax_max)
+ cur_scale = torch.exp(cur_softmax_max - softmax_max)
+
+ # update softmax_sum
+ prev_softmax_sum_scaled = prev_softmax_sum * prev_scale
+ cur_softmax_sum_scaled = cur_softmax_sum * cur_scale
+ softmax_sum = prev_softmax_sum_scaled + cur_softmax_sum_scaled
+
+ # out updating scale
+ prev_out_scale = prev_softmax_sum_scaled / softmax_sum
+ cur_out_scale = cur_softmax_sum_scaled / softmax_sum
+
+ # [b, n, s, 8] -> [s, b, h]
+ # SBH layout
+ n = prev_out_scale.shape[1]
+ h = prev_attn_out.shape[-1]
+ d = h // n
+ prev_out_scale = prev_out_scale[..., 0].unsqueeze(3).repeat(1, 1, 1, d)
+ prev_out_scale = rearrange(prev_out_scale, 'b n s d -> s b (n d)').contiguous()
+ cur_out_scale = cur_out_scale[..., 0].unsqueeze(3).repeat(1, 1, 1, d)
+ cur_out_scale = rearrange(cur_out_scale, 'b n s d -> s b (n d)').contiguous()
+
+ # update output
+ attn_out = prev_attn_out * prev_out_scale + cur_attn_out * cur_out_scale
+ attn_out = attn_out.to(origin_dtype)
+ return attn_out, softmax_max, softmax_sum
+
+```
+
+## 前向接口:
+
+输入:
+
+- prev_attn_out:必选输入,数据类型torch.bfloat16, torch.float, torch.float16
+- prev_softmax_max: 必选输入,数据类型torch.float
+- prev_softmax_sum: 必选输入,数据类型torch.float
+- cur_attn_out: 必选输入,数据类型torch.bfloat16, torch.float, torch.float16
+- cur_softmax_max: 必选输入,数据类型torch.float
+- cur_softmax_sum: 必选输入,数据类型torch.float
+
+
+输出:
+
+- attn_out:必选输出,数据类型torch.bfloat16, torch.float, torch.float16
+- softmax_max:必选输出,数据类型torch.float
+- softmax_max:必选输出,数据类型torch.float
+
+属性:
+
+- actual_seq_qlen:可选属性,数据类型torch.int64, 数据单调递增,layout为TND的时候使用
+- layout:必选属性,数据类型str
+
+
+
+## 案例
+
+```python
+import torch
+import torch_npu
+from mindspeed.ops.npu_ring_attention_update import npu_ring_attention_update
+
+prev_attn_out = torch.randn(2048, 1, 12, dtype=torch.bfloat16).npu()
+prev_softmax_max = torch.randn(1, 12, 2048, 8, dtype=torch.float32).npu()
+prev_softmax_sum = torch.randn(1, 12, 2048, 8, dtype=torch.float32).npu()
+cur_attn_out = torch.randn(2048, 1, 12, dtype=torch.bfloat16).npu()
+cur_softmax_max = torch.randn(1, 12, 2048, 8, dtype=torch.float32).npu()
+cur_softmax_sum = torch.randn(1, 12, 2048, 8, dtype=torch.float32).npu()
+
+attn_out, softmax_max, softmax_sum = forward_update(prev_attn_out, prev_softmax_max, prev_softmax_sum,
+ cur_attn_out, cur_softmax_max, cur_softmax_sum)
+
+
+```
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/ops/npu_rotary_position_embedding.md b/model/train/yoco_moe/docs/ops/npu_rotary_position_embedding.md
new file mode 100644
index 0000000000000000000000000000000000000000..e1a3d38c2b6ed8b97e33049fd2e936ca7459466c
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/npu_rotary_position_embedding.md
@@ -0,0 +1,107 @@
+# npu_rotary_position_embedding对外接口
+
+npu_rotary_position_embedding(x, cos, sin, mode=0)
+
+小算子等价计算逻辑:
+```python
+import torch
+from einops import rearrange
+
+# mode = 0
+def rotate_half(x):
+ x1, x2 = torch.chunk(x, 2, dim=-1)
+ return torch.cat((-x2, x1), dim=-1)
+
+# mode = 1
+def rotate_interleaved(x):
+ x1 = x[..., ::2]
+ x2 = x[..., 1::2]
+ return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ...(d two)", two=2)
+
+def fused_rotary_position_embedding(x, cos, sin, interleaved=False):
+ if not interleaved:
+ return x * cos + rotate_half(x) * sin
+ else:
+ return x * cos + rotate_interleaved(x) * sin
+```
+
+## 前向接口:
+
+输入:
+
+- x:必选输入,4维Tensor,数据类型float16, bfloat16, float32
+- cos: 必选输入,4维Tensor,数据类型float16, bfloat16, float32
+- sin: 必选输入,4维Tensor,数据类型float16, bfloat16, float32
+
+输出:
+
+- y:必选输出,数据类型float16, bfloat16, float32
+
+属性:
+
+- mode:可选属性,数据类型int64_t,用于选择计算模式,0: rotate_half(GPT-NeoX style); 1: rotate_interleaved(GPT-J style)。缺省为0。
+
+
+## 反向接口:
+
+输入:
+
+- dy:必选输入,4维Tensor,数据类型float16, bfloat16, float32
+- cos:必选输入,4维Tensor,数据类型float16, bfloat16, float32
+- sin:必选输入,4维Tensor,数据类型float16, bfloat16, float32
+- x: 可选输入,4维Tensor,数据类型float16, bfloat16, float32
+
+输出:
+
+- dx:必选输出,4维Tensor,数据类型float16, bfloat16, float32
+- dcos:可选输出,4维Tensor,数据类型float16, bfloat16, float32
+- dsin:可选输出,4维Tensor,数据类型float16, bfloat16, float32
+
+属性:
+
+- mode:可选属性,数据类型int64_t,用于选择计算模式,0: rotate_half(GPT-NeoX style); 1: rotate_interleaved(GPT-J style)。缺省为0。
+
+## 输入约束:
+
+| 输入 | RotateHalf(mode: 0) | RotateInterleaved(mode: 1) |
+| :-: | :- | :- |
+| x | layout支持: BNSD、BSND、SBND;
D < 896,且为2的倍数;
B, N < 1000;
当需要计算cos/sin的反向梯度时,B*N <= 1024 | layout支持: BNSD、BSND、SBND;
B * N < 1000;
D < 896, 且D为2的倍数; |
+| cos | 数据范围:[-1, 1];
对应x layout的支持情况:
x为BNSD: 11SD、B1SD、BNSD;
x为BSND: 1S1D、BS1D、BSND;
x为SBND: S11D、SB1D、SBND. | 数据范围:[-1, 1];
对应x layout的支持情况:
x为BNSD: 11SD;
x为BSND: 1S1D;
x为SBND: S11D.|
+| sin | 同cos | 同cos |
+
+**备注**:
+1. 输入不支持为None;
+2. cos, sin如需计算反向梯度需要同时设置`requires_grad = True`, 若仅设置一个则均不计算;
+3. RotaryHalf(mode=0)模式下,当输入layout是BNSD,且D为非32Bytes对齐时,建议不使用该融合算子(模型启动脚本中不开启`--use-fused-rotary-pos-emb`选项),否则可能出现性能下降。
+
+
+## 案例
+
+```python
+import os
+import torch
+import torch_npu
+import numpy as np
+import math
+from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding
+
+x = np.random.uniform(-2, 2, (4, 8192, 4, 128))
+cos = np.random.uniform(-1, 1, (1, 8192, 1, 128))
+sin = np.random.uniform(-1, 1, (1, 8192, 1, 128))
+
+x_npu = torch.from_numpy(x).float().npu()
+cos_npu = torch.from_numpy(cos).float().npu()
+sin_npu = torch.from_numpy(sin).float().npu()
+
+x_npu.requires_grad = True
+cos_npu.requires_grad = True
+sin_npu.requires_grad = True
+# 正向接口案例
+result = npu_rotary_position_embedding(x_npu, cos_npu, sin_npu, 0)
+
+# 反向接口案例
+result.backward(torch.ones_like(result).npu())
+x_npu.grad
+cos_npu.grad
+sin_npu.grad
+```
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/ops/quant_gmm.md b/model/train/yoco_moe/docs/ops/quant_gmm.md
new file mode 100644
index 0000000000000000000000000000000000000000..4e545eb8aa28740099516500c8bee64e46b5a8f5
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/quant_gmm.md
@@ -0,0 +1,95 @@
+# quant_gmm对外接口
+
+npu_quant_gmm(x, weight, scale, *, offset=None, per_token_scale=None, bias=None, group_list=None,
+output_dtype=None, act_type=0)
+
+npu_quant_gmm_v2(x, weight, scale, *, offset=None, per_token_scale=None, bias=None, group_list=None,
+output_dtype=None, act_type=0)
+
+[npu_quant_gmm_v2]相较于[npu_quant_gmm]接口, group_list的含义不同, 在npu_quant_gmm接口中group_list中数值为分组轴大小的cumsum结果(累积和),npu_quant_gmm_v2接口中group_list中数值为分组轴上每组大小。两个接口的算子性能无差异,使用时可以根据整网中group_list的情况决定,如果前序算子输出的group_list为各group的大小,建议使用npu_quant_gmm_v2接口,因为此时使用npu_quant_gmm接口需要先调用torch.cumsum将group_list转为累积和的形式,带来额外开销。
+
+## 前向接口:
+
+输入:
+
+- x:必选输入,参数为tensor,数据类型int8;
+- weight:必选输入,参数为tensor,数据类型int8;
+- scale:必选输入,参数类型为tensor,数据类型int64,bfloat16,float32;
+- offset:保留参数,当前未使能;
+- per_token_scale:可选参数,参数类型为tensor,数据类型float32,默认值为none;
+- bias:可选输入,参数类型为tensor,数据类型int32, 默认值为none;
+- group_list:可选输入,参数类型为tensor,数据类型int64,默认值为none。不同接口中的数值定义不同,具体见上述接口说明中描述;
+- output_dtype:可选输入,参数类型为torch.dtype,可选值为:torch.int8,torch.bfloat16,torch.float16,用于指定输出数据类型,默认值为None,此时输出类型为torch.float16;
+- act_type:可选参数,参数类型为int,用于指定激活函数类型,默认值为0,支持的激活函数类型如下:
+ - 0:无激活函数;
+ - 1:relu;
+ - 2:gelu_tanh;
+ - 3:gelu_err_func(暂不支持);
+ - 4:fast_gelu;
+ - 5:silu。
+
+输出:
+
+- y:必选输出,数据类型int8, float16, bfloat16。
+
+约束与限制:
+
+- npu_quant_gmm接口中,group_list必须为非负单调非递减数列,且长度不能为1;
+- npu_quant_gmm_v2接口中,group_list必须为非负数列,长度不能为1,且数据类型仅支持tensor;
+- x和weight中每一组tensor的最后一维大小都应小于65536.$x_i$的最后一维指当属性transpose_x为false时$x_i$的K轴或当transpose_x为true时$x_i$的M轴。$weight_i$的最后一维指当属性transpose_weight为false时$weight_i$的N轴或当transpose_weight为true时$weight_i$的K轴;
+- x和weight中每一组tensor的每一维大小在32字节对齐后都应小于int32的最大值2147483647;
+- 当需要输出y数据类型为int8时,指定output_dtype为torch.int8,scale类型为int64,per_token_scale为空,此时只支持act_type=0,即无激活函数;该场景当前仅支持单算子模式,图模式不支持;
+- 当需要输出y数据类型为bfloat16时,output_dtype为torch.bfloat16,scale类型为bfloat16;
+- 当需要输出y数据类型为float16时,output_dtype为torch.float16或者默认参数None,scale类型为float32。
+
+## gmm 类的调用方式
+
+```python
+import os
+import torch
+import torch_npu
+import numpy as np
+import math
+from mindspeed.ops import quant_gmm
+
+num_expert, seq_len, hidden_dim, out_channel = 8, 32, 256, 128
+group_list = torch.tensor([1, 3, 6, 10, 15, 21, 28, 32], dtype=torch.int64).npu()
+
+x = torch.randint(-128, 128, (seq_len, hidden_dim), dtype=torch.int8).npu()
+weight = torch.randint(-128, 128, (num_expert, hidden_dim, out_channel), dtype=torch.int8).npu()
+scale = torch.rand(num_expert, out_channel, dtype=torch.float32).npu()
+per_token_scale = torch.rand(seq_len, dtype=torch.float32).npu()
+
+result = quant_gmm.npu_quant_gmm(x, weight, scale, per_token_scale=per_token_scale,
+ bias=None, group_list=group_list, output_dtype=torch.float16)
+
+# weight转置案例
+weight_trans = torch.randint(-128, 128, (num_expert, out_channel, hidden_dim), dtype=torch.int8).npu()
+result = quant_gmm.npu_quant_gmm(x, weight_trans.transpose(-1,-2), scale, per_token_scale=per_token_scale,
+ bias=None, group_list=group_list, output_dtype=torch.float16)
+```
+
+```python
+import os
+import torch
+import torch_npu
+import numpy as np
+import math
+from mindspeed.ops import quant_gmm
+
+num_expert, seq_len, hidden_dim, out_channel = 8, 32, 256, 128
+group_list = torch.tensor([1, 3, 3, 4, 5, 6, 7, 4], dtype=torch.int64).npu()
+
+x = torch.randint(-128, 128, (seq_len, hidden_dim), dtype=torch.int8).npu()
+weight = torch.randint(-128, 128, (num_expert, hidden_dim, out_channel), dtype=torch.int8).npu()
+scale = torch.rand(num_expert, out_channel, dtype=torch.float32).npu()
+per_token_scale = torch.rand(seq_len, dtype=torch.float32).npu()
+
+result = quant_gmm.npu_quant_gmm_v2(x, weight, scale, per_token_scale=per_token_scale,
+ bias=None, group_list=group_list, output_dtype=torch.float16)
+
+# weight转置案例
+weight_trans = torch.randint(-128, 128, (num_expert, out_channel, hidden_dim), dtype=torch.int8).npu()
+result = quant_gmm.npu_quant_gmm_v2(x, weight_trans.transpose(-1,-2), scale, per_token_scale=per_token_scale,
+ bias=None, group_list=group_list, output_dtype=torch.float16)
+```
diff --git a/model/train/yoco_moe/docs/ops/rms_norm.md b/model/train/yoco_moe/docs/ops/rms_norm.md
new file mode 100644
index 0000000000000000000000000000000000000000..15598e74a9f0664120b30784e180c85fc274e957
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/rms_norm.md
@@ -0,0 +1,43 @@
+# rms_norm对外接口
+## 前向接口:
+输入:
+- x:必选输入,数据类型float16, bfloat16, float32
+- gamma:必选输入,数据类型float16, bfloat16, float32
+
+输出:
+- rstd:中间变量,数据类型float32
+- out:必选输出,数据类型float16, bfloat16, float32
+
+属性:
+- epsilon:可选属性,数据类型float,缺省1e-6。
+
+## 反向接口:
+输入:
+- dy:必选输入,数据类型float16, bfloat16, float32
+- x:必选输入,数据类型float16, bfloat16, float32
+- rstd:必选输入,数据类型float32
+- gamma:必选输入,数据类型float16, bfloat16, float32
+
+输出:
+- dx:必选输出,数据类型float16, bfloat16, float32
+- dgammma:必选输出,数据类型float32
+
+## 案例
+
+```python
+import os
+import torch
+import torch_npu
+import numpy as np
+import math
+from mindspeed.op_builder import RmsNormOpBuilder
+
+x = torch.rand([2, 4]).to(torch.float).npu()
+gamma = torch.rand([4]).to(torch.float).npu()
+# 正向接口案例
+mindspeed_ops = RmsNormOpBuilder().load()
+out = mindspeed_ops.rms_norm(x, gamma)
+# 反向接口案例
+out.backward(torch.ones(out.shape).npu())
+
+```
diff --git a/model/train/yoco_moe/docs/ops/swiglu.md b/model/train/yoco_moe/docs/ops/swiglu.md
new file mode 100644
index 0000000000000000000000000000000000000000..a2dac806144797b5a859779e2db8853fbb5a1e05
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/swiglu.md
@@ -0,0 +1,42 @@
+# swiglu对外接口
+## 前向接口:
+输入:
+- x:必选输入,数据类型float16, bfloat16, float32
+
+输出:
+- y:必选输出,数据类型float16, bfloat16, float32
+
+属性:
+- dim:可选属性,数据类型int32_t,缺省-1。
+
+## 反向接口:
+输入:
+- dy:必选输入,数据类型float16, bfloat16, float32
+- x:必选输入,数据类型float16, bfloat16, float32
+
+输出:
+- dx:必选输出,数据类型float16, bfloat16, float32
+
+属性:
+- dim:可选属性,数据类型int32_t,缺省-1。
+
+## 案例
+```python
+ import os
+ import torch
+ import torch_npu
+ import numpy as np
+ import math
+ from mindspeed.op_builder import SwigluOpBuilder
+
+ x = np.random.uniform(-2, 2, (8192,1,3904))
+ x = torch.from_numpy(x).float().npu()
+ y_grad = np.random.uniform(-2, 2, (8192,1,1952))
+ y_grad = torch.from_numpy(y_grad).float().npu()
+
+ x.requires_grad = True
+ # 正向接口案例
+ mindspeed_ops = SwigluOpBuilder().load()
+ result = mindspeed_ops.swiglu(x, dim=-1)
+ # 反向接口案例
+ result.backward(y_grad)
\ No newline at end of file
diff --git a/model/train/yoco_moe/docs/ops/weight_quant_gmm.md b/model/train/yoco_moe/docs/ops/weight_quant_gmm.md
new file mode 100644
index 0000000000000000000000000000000000000000..767cbba17f9fab451b9ec4424b8b61d206dd350a
--- /dev/null
+++ b/model/train/yoco_moe/docs/ops/weight_quant_gmm.md
@@ -0,0 +1,86 @@
+# weight_quant_gmm对外接口
+
+npu_weight_quant_gmm(x, weight, antiquant_scale, *, antiquant_offset=None, bias=None, group_list=None, act_type=0)
+
+npu_weight_quant_gmm_v2(x, weight, antiquant_scale, *, antiquant_offset=None, bias=None, group_list=None, act_type=0)
+
+[npu_weight_quant_gmm_v2]相较于[npu_weight_quant_gmm]接口,group_list的含义不同,在npu_weight_quant_gmm接口中group_list中数值为分组轴大小的cumsum结果(累积和),npu_weight_quant_gmm_v2接口中group_list中数值为分组轴上每组大小。两个接口的算子性能无差异,使用时可以根据整网中group_list的情况决定,如果前序算子输出的group_list为各group的大小,建议使用npu_weight_quant_gmm_v2接口,因为此时使用npu_weight_quant_gmm接口需要先调用torch.cumsum将group_list转为累积和的形式,带来额外开销。
+
+## 前向接口:
+
+输入:
+
+- x:必选输入,参数为tensor,数据类型float16,bfloat16;
+- weight:必选输入,参数为tensor,数据类型int8;
+- antiquant_scale:必选输入,参数类型为tensor,数据类型float16,bfloat16;
+- antiquant_offset:可选参数,参数类型为tensor,数据类型float16,bfloat16,默认值为none,当前不支持传none;
+- bias:可选输入,参数类型为tensor,数据类型float16,float32,默认值为none;
+- group_list:可选输入,参数类型为tensor,数据类型int64,默认值为none。不同接口中的数值定义不同,具体见上述接口说明中描述;
+- act_type:可选参数,参数类型为int,用于指定激活函数类型,默认值为0,表示无激活函数,当前只支持默认值0;
+
+输出:
+
+- y:必选输出,数据类型float16,bfloat16。
+
+约束与限制:
+
+- npu_weight_quant_gmm接口中,group_list必须为非负单调非递减数列,且长度不能为1;
+- npu_weight_quant_gmm_v2接口中,group_list必须为非负数列,长度不能为1,且数据类型仅支持tensor;
+- x和weight中每一组tensor的最后一维大小都应小于65536.$x_i$的最后一维指当属性transpose_x为false时$x_i$的K轴或当transpose_x为true时$x_i$的M轴。$weight_i$的最后一维指当属性transpose_weight为false时$weight_i$的N轴或当transpose_weight为true时$weight_i$的K轴;
+- x和weight中每一组tensor的每一维大小在32字节对齐后都应小于int32的最大值2147483647;
+- x,antiquant_scale,antiquant_offset,y的数据类型因保持一致
+- 当需要输出y数据类型为bfloat16时,bias类型为float32;
+- 当需要输出y数据类型为float16时,bias类型为float16。
+- 暂不支持计算flops。
+
+## gmm 类的调用方式
+
+```python
+import os
+import torch
+import torch_npu
+import numpy as np
+import math
+from mindspeed.ops import weight_quant_gmm
+
+num_expert, seq_len, hidden_dim, out_channel = 8, 32, 256, 128
+group_list = torch.tensor([1, 3, 6, 10, 15, 21, 28, 32], dtype=torch.int64).npu()
+
+x = torch.rand(seq_len, hidden_dim, dtype=torch.float16).npu()
+weight = torch.randint(-128, 128, (num_expert, hidden_dim, out_channel), dtype=torch.int8).npu()
+antiquant_scale = torch.rand(num_expert, out_channel, dtype=torch.float16).npu()
+antiquant_offset = torch.rand(num_expert, out_channel, dtype=torch.float16).npu()
+
+result = weight_quant_gmm.npu_weight_quant_gmm(x, weight, antiquant_scale, antiquant_offset=antiquant_offset,
+ bias=None, group_list=group_list)
+
+# weight转置案例
+weight_trans = torch.randint(-128, 128, (num_expert, out_channel, hidden_dim), dtype=torch.int8).npu()
+result = weight_quant_gmm.npu_weight_quant_gmm(x, weight_trans.transpose(-1,-2), antiquant_scale,
+ antiquant_offset=antiquant_offset, bias=None, group_list=group_list)
+```
+
+```python
+import os
+import torch
+import torch_npu
+import numpy as np
+import math
+from mindspeed.ops import weight_quant_gmm
+
+num_expert, seq_len, hidden_dim, out_channel = 8, 32, 256, 128
+group_list = torch.tensor([1, 3, 3, 4, 5, 6, 7, 4], dtype=torch.int64).npu()
+
+x = torch.rand(seq_len, hidden_dim, dtype=torch.float16).npu()
+weight = torch.randint(-128, 128, (num_expert, hidden_dim, out_channel), dtype=torch.int8).npu()
+antiquant_scale = torch.rand(num_expert, out_channel, dtype=torch.float16).npu()
+antiquant_offset = torch.rand(num_expert, out_channel, dtype=torch.float16).npu()
+
+result = weight_quant_gmm.npu_weight_quant_gmm_v2(x, weight, antiquant_scale, antiquant_offset=antiquant_offset,
+ bias=None, group_list=group_list)
+
+# weight转置案例
+weight_trans = torch.randint(-128, 128, (num_expert, out_channel, hidden_dim), dtype=torch.int8).npu()
+result = weight_quant_gmm.npu_weight_quant_gmm_v2(x, weight_trans.transpose(-1,-2), antiquant_scale,
+ antiquant_offset=antiquant_offset, bias=None, group_list=group_list)
+```
diff --git a/model/train/yoco_moe/mindspeed/__init__.py b/model/train/yoco_moe/mindspeed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/arguments.py b/model/train/yoco_moe/mindspeed/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..576db4152fc7543599a90808aefbc00c00cd02e6
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/arguments.py
@@ -0,0 +1,1117 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+from dataclasses import make_dataclass, field
+from functools import wraps
+import argparse
+import warnings
+import torch
+from mindspeed.features_manager import FEATURES_LIST
+
+
+def extra_args_provider_decorator(extra_args_provider):
+ @wraps(extra_args_provider)
+ def wrapper(parser):
+ if extra_args_provider is not None:
+ parser = extra_args_provider(parser)
+ parser = process_args(parser)
+ return parser
+
+ return wrapper
+
+
+def parse_args_wrapper(parse_args):
+ @wraps(parse_args)
+ def wrapper(extra_args_provider=None, ignore_unknown_args=False):
+ decorated_provider = extra_args_provider_decorator(extra_args_provider)
+ return parse_args(decorated_provider, ignore_unknown_args)
+
+ return wrapper
+
+
+def process_args(parser):
+ parser.conflict_handler = 'resolve'
+ parser = _add_network_size_args(parser)
+ parser = _add_distributed_args(parser)
+ parser = _add_training_args(parser)
+ parser = _add_data_args(parser)
+ parser = _add_moe_args(parser)
+ parser = _add_cp_args(parser)
+ parser = _add_network_args(parser)
+ parser = _add_algorithm_args(parser)
+ parser = _add_automated_pipeline_args(parser)
+ parser = _add_alibi_args(parser)
+ parser = _add_ndmm_args(parser)
+ parser = _add_2d_tp_args(parser)
+ parser = _add_coc_args(parser)
+ parser = _add_profile_args(parser)
+ parser = _add_auto_parallel_args(parser)
+ parser = _add_deepseek_args(parser)
+ parser = _auto_tuning_args(parser)
+ parser = _add_auto_parallel_mm_args(parser)
+ parser = _add_hccl_group_buffer_args(parser)
+ parser = _add_layerzero_args(parser)
+ parser = _add_dist_train_args(parser)
+
+ for feature in FEATURES_LIST:
+ feature.register_args(parser)
+
+ return parser
+
+
+def _add_deepseek_args(parser):
+ group = parser.add_argument_group(title='deepseek')
+ # deepseek moe arguments
+ group.add_argument('--n-shared-experts', type=int, default=None)
+ # mla arguments
+ group.add_argument('--multi-head-latent-attention', action='store_true', default=False,
+ help='Use Multi-head Latent Attention(MLA)')
+ group.add_argument('--q-lora-rank', type=int, default=None, help='The low rank of q')
+ group.add_argument('--kv-lora-rank', type=int, default=None, help='The low rank of k and v')
+ group.add_argument('--v-head-dim', type=int, default=None, help='The head dim of v')
+ group.add_argument('--qk-rope-head-dim', type=int, default=None, help='The qk head dim for rope')
+ group.add_argument('--qk-nope-head-dim', type=int, default=None, help='The qk head dim for only self-attn')
+ # yarn arguments
+ group.add_argument('--rope-scaling-type', type=str, default=None, choices=['yarn', ],
+ help='Set the rope scaling type, only support "yarn" type now')
+ group.add_argument('--rope-scaling-beta-fast', type=int, default=32, help='Yarn rope: rope beta fast')
+ group.add_argument('--rope-scaling-beta-slow', type=int, default=1, help='Yarn rope: rope beta slow')
+ group.add_argument('--rope-scaling-factor', type=float, default=1.0, help='Yarn rope: rope factor')
+ group.add_argument('--rope-scaling-mscale', type=float, default=1.0, help='Yarn rope: rope mscale')
+ group.add_argument('--rope-scaling-mscale-all-dim', type=float, default=0.0, help='Yarn rope: rope mscale all dim')
+ group.add_argument('--rope-scaling-original-max-position-embeddings', type=int, default=None,
+ help='Yarn rope: rope original max position embeddings')
+ group.add_argument('--moe-hierarchical-alltoallv', action='store_true',
+ help='Reduce communication cost between nodes')
+
+ return parser
+
+
+def _auto_tuning_args(parser):
+ group = parser.add_argument_group(title='auto_tuning')
+
+ group.add_argument('--auto-tuning', action='store_true', help='enable auto tuning')
+ group.add_argument('--auto-tuning-work-dir', type=str, default='./auto_tuning_dir',
+ help="auto tuning working path.")
+ group.add_argument('--auto-tuning-ranks', type=int, default=16, help='the global size of auto tuning')
+ group.add_argument('--auto-tuning-log-level', type=str, default='info', choices=['debug', 'info', 'warning'],
+ help='auto tuning log level, could be debug, info or warning')
+
+ return parser
+
+
+def _add_profile_args(parser):
+ group = parser.add_argument_group(title='profile')
+ group.add_argument("--profile-level", type=str, default='level0',
+ choices=['level0', 'level1', 'level2'],
+ help="Profile level default level0.")
+ group.add_argument("--profile-with-cpu", action='store_true', default=False,
+ help="Profile with cpu info.")
+ group.add_argument("--profile-with-stack", action='store_true', default=False,
+ help="Profile without stack info.")
+ group.add_argument("--profile-with-memory", action='store_true', default=False,
+ help="Profile without memory info.")
+ group.add_argument("--profile-record-shapes", action='store_true', default=False,
+ help="Profile record shape info.")
+ group.add_argument("--profile-save-path", type=str, default='./profile_dir',
+ help="Profile save path.")
+ group.add_argument('--profile-ranks', nargs='+', type=int, default=[-1],
+ help='Global ranks to profile.The default value of -1 means to profile all ranks')
+ return parser
+
+
+def _add_coc_args(parser):
+ group = parser.add_argument_group(title='coc')
+ # ascend mc2 arguments
+ group.add_argument("--use-ascend-mc2", action='store_true',
+ help="Use ascend mc2")
+ # ascend coc arguments
+ group.add_argument("--use-ascend-coc", action='store_true',
+ help="Use ascend coc")
+ group.add_argument('--coc-mode', type=int, default=-1,
+ help='coc-mode: 0=original, 1=rewrite, 2=coc default')
+ group.add_argument('--coc-parallel-num', type=int, default=1,
+ help='coc parallel num')
+ group.add_argument('--coc-fused-kernel', action='store_true',
+ help='use coc fused kernel')
+ return parser
+
+
+def _add_moe_args(parser):
+ group = parser.add_argument_group(title='moe')
+ # deepspeed moe arguments
+ group.add_argument('--moe-model-type', type=str, default='megatron_moe',
+ choices=['deepspeed_moe', 'megatron_moe'], help='moe model type default megatron moe')
+ group.add_argument('--expert-interval', type=int, default=1,
+ help='Use experts in every "expert-interval" layers')
+ group.add_argument('--moe-train-capacity-factor', type=float, default=1.0,
+ help='The capacity of the MoE expert at training time')
+ group.add_argument('--noisy-gate-policy', type=str, default=None, choices=['Jitter', 'RSample', 'None'],
+ help="noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.")
+ group.add_argument('--enable-token-rearrange-opt', action='store_true',
+ help="Use this flag to enable token rearrange optimize")
+ group.add_argument('--no-use-rts',
+ action='store_false', default=False,
+ help='whether to use Random Token Selection.',
+ dest='use_rts')
+ group.add_argument("--moe-no-drop", action='store_true',
+ help="Use no drop policy in moe layer, no tokens will be discarded.")
+ group.add_argument("--moe-dynamic-padding", action='store_true',
+ help="Reducing AllReduce communication under the no drop policy through the sliding window mechanism.")
+ group.add_argument("--moe-use-sinkhorn", action='store_true',
+ help="Use sinkhorn load balancing in the gate.")
+
+ # megatron mcore moe arguments
+ group.add_argument("--moe-tp-extend-ep", action='store_true',
+ help="use tp group to extend experts parallelism"
+ "instead of sharding weight tensor of experts in tp group")
+ group.add_argument("--moe-permutation-async-comm", action='store_true',
+ help="overlap moe permutation 3 all gather communications")
+ group.add_argument("--moe-adaptive-recompute-activation", action='store_true',
+ help="MoE adaptive recompute, avoiding memory imbalance in the early stage.")
+ group.add_argument('--moe-adaptive-recompute-activation-scale', type=float, default=2.0,
+ help='MoE adaptive recompute threshold factor.')
+ group.add_argument("--use-fused-moe-token-permute-and-unpermute", action='store_true',
+ help="Use fused moe permute and unpermute.")
+ group.add_argument("--gemm-gradient-accumulation-fusion", action='store_true',
+ help="Use gradient-accumulation-fusion in gemm.")
+ # moe optimization arguments
+ group.add_argument('--moe-alltoall-overlap-comm', action='store_true', default=False,
+ help='moe_alltoall_overlap_comm')
+ group.add_argument('--moe-allgather-overlap-comm', action='store_true', default=False,
+ help='moe_allgather_overlap_comm')
+ group.add_argument('--moe-experts-pipeline-degree', type=int, default=0,
+ help='Group experts into pipeline stages to overlap computation and communication.')
+ group.add_argument("--moe-zero-memory", type=str, default='disable',
+ choices=['disable', 'level0', 'level1'],
+ help="Save activation memory in moe layer.")
+ group.add_argument('--moe-zero-memory-num-layers', type=int, default=None,
+ help='the number of layers using moe-zero-memory level1'
+ 'in each pp stage.')
+ group.add_argument('--moe-bmm-mc2', action='store_true', default=False,
+ help='moe_bmm_mc2')
+ return parser
+
+
+def _add_cp_args(parser):
+ group = parser.add_argument_group(title='cp parallel')
+ group.add_argument('--context-parallel-algo', type=str, default='ulysses_cp_algo',
+ choices=['ulysses_cp_algo', 'megatron_cp_algo', 'hybrid_cp_algo', 'adaptive_cp_algo',
+ 'hybrid_adaptive_cp_algo'],
+ help='context parallel algorithm')
+ group.add_argument('--ulysses-degree-in-cp', type=int, default=None)
+ group.add_argument('--cp-window-size', type=int, default=1)
+ group.add_argument('--attention-mask-type', type=str, default='causal',
+ choices=['causal', 'general'], help='context parallel attention mask type')
+ group.add_argument('--use-cp-send-recv-overlap', action='store_true',
+ help='use this flag to enable cp send-recv-overlap.')
+ group.add_argument("--use-fused-ring-attention-update", action='store_true',
+ help="Use fused ring attention update.")
+ group.add_argument("--megatron-cp-in-bnsd", action='store_true',
+ help="Megatron CP in bnsd.")
+ group.add_argument('--attention-mask-on-cpu', action='store_true',
+ help='store full attention mask on CPU instead of NPU')
+ group.add_argument('--adaptive-cp-without-coarse', action='store_true',
+ help='does not coarse the attention mask in adaptive_cp feature, only recommended when full'
+ 'sequence length is less than 8K and dynamic attention mask is not feasible')
+ group.add_argument('--adaptive-cp-dynamic-attn-mask', action='store_true',
+ help='if the attention mask is dynamic across batches')
+ group.add_argument('--adaptive-cp-only-reschedule', action='store_true',
+ help='not apply remapping but only rescheduling process in adaptive-cp feature')
+ group.add_argument('--adaptive-cp-manually-set-mask-list', action='store_true',
+ help='manually set pre-cooked attention mask list')
+ group.add_argument('--context-parallel-kv-cache-policy', type=str, default=None,
+ choices=['full', 'half'],
+ help='Selectivity cache K, V in process of cp.'
+ 'Default is None, means not used cache K, V.'
+ 'If para is full, cache all K, V.'
+ 'If para is half, cache only K')
+ group.add_argument('--context-parallel-cache-interval', type=int, default=0,
+ help='Set the interval of cache layers in cp.'
+ 'Default is 0, means cache K, V in all layers.')
+ group.add_argument('--use-ulysses-allgather-kv', action='store_true',
+ help='use this flag to enable allgather kv + repeat all2all q in ulysses cp.')
+ return parser
+
+
+def _add_network_size_args(parser):
+ group = parser.add_argument_group(title='network size')
+ group.add_argument("--use-fused-rmsnorm", action='store_true',
+ help="Use fused rmsnorm.")
+ group.add_argument("--use-fused-swiglu", action='store_true',
+ help="Use fused swiglu.")
+ group.add_argument("--use-fused-rotary-pos-emb", action='store_true',
+ help="Use fused rotary-pos-emb.")
+ return parser
+
+
+def _add_data_args(parser):
+ group = parser.add_argument_group(title='data and dataloader')
+ group.add_argument('--tokenizer-type', type=str,
+ default=None,
+ choices=['BertWordPieceLowerCase',
+ 'BertWordPieceCase',
+ 'GPT2BPETokenizer',
+ 'SentencePieceTokenizer',
+ 'GPTSentencePieceTokenizer',
+ 'Llama2Tokenizer',
+ 'PretrainedFromHF',
+ 'NullTokenizer'],
+ help='What type of tokenizer to use.')
+ group.add_argument("--tokenizer-name-or-path", type=str, default=None,
+ help="Name or path of the huggingface tokenizer.")
+ group.add_argument("--tokenizer-not-use-fast", action='store_false',
+ help="HuggingFace tokenizer not use the fast version.")
+ return parser
+
+
+def _add_distributed_args(parser):
+ group = parser.add_argument_group(title='distributed')
+
+ group.add_argument('--local-rank', type=int, default=None,
+ help='Local rank passed from distributed launcher for torch2.x.')
+ group.add_argument('--param-and-grad-buffer-pad', type=int, default=None,
+ help='Use this argument to ensure that all buckets start at a memory address that is needed-byte. Set 512 for Ascend')
+ group.add_argument('--use-nanopipe', action='store_true',
+ default=False, help='use nano pipeline parallelism for reduce bubble.')
+ group.add_argument('--use-nanopipe-swap', action='store_true',
+ default=False, help='use nano pipeline parallelism with swap for reduce bubble.')
+ group.add_argument('--use-pipe-experts', action='store_true',
+ help='Use this flag to enable pipe moe, overlap all2all and expert')
+ group.add_argument('--disable-gloo-group', action='store_true',
+ help='Replace the communication method of the DP group in the distributed optimizer from gloo to hccl.')
+ group.add_argument('--hccl-slice-size', type=int, default=10 * 1024 * 1024,
+ help='data slice size on each dp rank in distributed optimizer')
+ group.add_argument('--variable-seq-lengths', action='store_true',
+ help='Supports variable sequence lengths across batches/microbatches. Set this if the data '
+ 'loader supports variable sequence length generation across batches/microbatches. Because '
+ 'of the additional communication overhead incurred during pipeline parallelism, it should '
+ 'not be set if the sequence length is constant during training. if sequence length is '
+ 'constant during training.')
+ return parser
+
+
+def _add_training_args(parser):
+
+ group = parser.add_argument_group(title='training')
+
+ group.add_argument('--pre-tockens', type=int, default=65536,
+ help='pre-tockens is used by Flash attention')
+ group.add_argument('--next-tockens', type=int, default=0,
+ help='next-tockens is used by Flash attention')
+ group.add_argument('--shape-order', type=str, default='SBH',
+ choices=['SBH', 'BSH', 'BSND'],
+ help='input shape order used by Flash attention')
+ group.add_argument('--sparse-mode', type=int, default=0,
+ help='To improve performance in different modes of attention mask')
+ group.add_argument('--adaptive-recompute-device-size',
+ type=int, default=-1,
+ help='The memory size for adaptive selective recompute strategy. '
+ 'The default is -1. If this parameter > 0, '
+ 'will activate adaptive selective recompute. ')
+ group.add_argument('--adaptive-recompute-profiling-step',
+ type=int, default=10,
+ help='The profiling step for adaptive selective recompute strategy. '
+ 'The default is 10. If activate adaptive selective recompute, '
+ 'will solve graph after step 10. ')
+ group.add_argument('--adaptive-recompute-device-swap',
+ action='store_true', default=False,
+ help='switch to open adaptive recompute feature. '
+ 'The default is False.')
+ group.add_argument('--enable-recompute-layers-per-pp-rank',
+ action='store_true', default=False,
+ help='If enabled, --recompute-num-layers will mean the number of '
+ 'layers recomputed in each pp rank. Otherwise it means the number '
+ 'of layers recomputed in each vpp rank.')
+ group.add_argument('--recompute-activation-function', action='store_true',
+ help='Recompute the activation function in MLP layers.')
+ group.add_argument('--recompute-activation-function-num-layers', type=int, default=None,
+ help='Can be used together with "--recompute-method block." '
+ 'and "--recompute-num-layers". ')
+ group.add_argument('--recompute-norm', action='store_true',
+ help='Recompute norm in Transformer Layers')
+ group.add_argument('--recompute-norm-num-layers', type=int, default=None,
+ help='Recompute norm num layers, can be used together with activation function recompute. ')
+ group.add_argument('--recompute-in-bubble', action='store_true',
+ help='use bubble to do recompute to reduce memory')
+ group.add_argument('--recompute-in-advance', action='store_true',
+ help='recompute early to reduce bubble and improve training.')
+ group.add_argument('--jit-compile', action='store_true', default=False,
+ help='Setting jit compile mode to True')
+ group.add_argument('--swap-attention', action='store_true', default=False,
+ help='switch to open swap-attention feature.'
+ 'The default is False.')
+ group.add_argument('--swap-modules', type=str, default="input_norm,self_attention,post_attention_norm",
+ help='Swap modules for model. Can be used together with "--swap-attention."')
+ group.add_argument('--adaptive-memory-optimization', action='store_true', default=False,
+ help='Switch to open adaptive memory optimization feature, default is False.')
+ group.add_argument('--use-fusion-attn-v2', action='store_true', default=False,
+ help='use fusion_attention ops version 2')
+ group.add_argument('--pipe-experts-multi-data', type=int, default=1,
+ help='Use multi data to split the input tensor to implement masking when --use-pipe-experts. '
+ 'The default is 1.')
+ group.add_argument('--pipe-experts-multi-stream', action='store_true', default=False,
+ help='Use multi stream to avoid link collision in collective communication when --use-pipe-experts. '
+ 'The default is False.')
+ group.add_argument("--additional-config", help="additional model config file path")
+ group.add_argument('--use-ema', action='store_true', default=False,
+ help='use ema when training')
+ group.add_argument('--use-multiparameter-pipeline-model-parallel', action='store_true', default=False,
+ help='can transfer multi parameters from stage to stage in pipeline model parallel')
+ group.add_argument('--ampipe-degree', type=int, default=1,
+ help='Set Attention MoE pipe(AMPipe) degree, 1 means not enable '
+ 'AMPipe, greater than 1 means enable this feature.')
+ group.add_argument('--ampipe-tp-sp-comm-overlap', action='store_true', default=False,
+ help='enable computation and tp or sp communication overlap in ampipe')
+ group.add_argument('--op-cal-tflops', action='store_true', default=False,
+ help='use for cal mfu and hfu')
+ group.add_argument('--npu-deterministic', action='store_true', default=False,
+ help='enable deterministic computing for npu')
+ group.add_argument('--optimizer-selection', type=str, default='fused_adamw',
+ choices=['fused_adamw', 'fused_torch_adamw', 'fused_ema_adamw'],
+ help='Select from the former fused AdamW optimizer and Torch fused AdamW optimizer')
+ group.add_argument('--ema-decay', type=float, default=0.9999,
+ help='Set ema_decay of fused_ema_adamw optimizer.')
+ return parser
+
+
+def _add_network_args(parser):
+ group = parser.add_argument_group(title='network')
+
+ group.add_argument("--add-qkv-bias", action="store_true", default=False,
+ help='Configuration for the qkv bias.')
+ group.add_argument("--add-dense-bias", action="store_true", default=False,
+ help='Configuration for the dense bias.')
+ group.add_argument("--skip-bias-add", action="store_false", default=True,
+ help='Configuration for the skip bias.')
+ group.add_argument("--noop-layers", type=str,
+ help='Specity the noop layers.')
+ return parser
+
+
+def _add_automated_pipeline_args(parser):
+ group = parser.add_argument_group(title='automated_pipeline_allocation')
+ group.add_argument('--automated-pipeline',
+ action='store_true',
+ help='To enable automated pipeline memory saving process'
+ )
+ group.add_argument('--automated-pipeline-perf',
+ action='store_true',
+ help='To enable automated pipeline performance acceleration process'
+ )
+ group.add_argument('--save-memory-ratio',
+ type=float, default=0.20,
+ help='To set memory saving rate in automated pipeline'
+ )
+ group.add_argument('--num-layer-list',
+ type=str, help='To store the layer policy of automated pipeline'
+ )
+ group.add_argument('--recompute-module-list',
+ type=str, help='To store the recompute policy of automated pipeline'
+ )
+ group.add_argument('--recompute-type',
+ type=int, default=2,
+ help='To store the recompute type of automated pipeline, 0 for mlp block '
+ '1 for attention block and 2 for transformer layer'
+ )
+ group.add_argument('--optimized-mbs-list',
+ type=str,
+ help='To store the optimized mbs policy of automated pipeline performance'
+ )
+ group.add_argument('--mbs-idx',
+ type=int,
+ help='To store the index of mbs list'
+ )
+ group.add_argument('--pp-schedule-list',
+ type=str,
+ help='To store the pipeline schedule policy of automated pipeline performance'
+ )
+ group.add_argument('--optimized-mbs-mode',
+ action='store_false',
+ help='To store the status of optimized mbs in automated pipeline performance'
+ )
+ group.add_argument('--memory-fragmentation',
+ action='store_true', default=False,
+ help='Enable the memory fragmentation feature.')
+ group.add_argument('--smart-swap',
+ action='store_true', default=False, help='Enable the smart swap feature.')
+ return parser
+
+
+def _add_algorithm_args(parser):
+ group = parser.add_argument_group(title='training')
+ group.add_argument('--optimization-level', type=int, choices=[0, 1, 2], default=2,
+ help='0: The minimum patch set for megatron to adapt to NPU,'
+ '1: Affinity optimization (fusion operator, etc.), '
+ '2: Advanced acceleration algorithm')
+ group.add_argument('--reuse-fp32-param', action='store_true',
+ help='The distributed training optimizer frees up '
+ 'param copies of FP32 to save memory.')
+
+ group.add_argument('--optimize-send-recv-comm', action='store_true',
+ help='optimize send_recv communication in pipeline without interleaving.')
+ group.add_argument('--optimize-vpp-send-recv-comm', action='store_true',
+ help='optimize send_recv communication in pipeline with interleaving.')
+ group.add_argument('--enable-zero3', action='store_true', default=False,
+ help='Use this flag to enable zero3, including the segmentation of the parameters, gradients, and optimizers of the row-parallel and column-parallel models, as well as the overlap optimization of the gradient reduce sactter and weight all gather.')
+ return parser
+
+
+def _add_layerzero_args(parser):
+ group = parser.add_argument_group(title='layerzero')
+ group.add_argument('--layerzero', action='store_true', default=False,
+ help='Use this flag to enable layerzero, including the segmentation of the parameters, gradients, and optimizers of the row-parallel and column-parallel models, as well as the overlap optimization of the gradient reduce sactter and weight all gather.')
+ group.add_argument('--layerzero-config', type=str,
+ help='Use this yaml file to config layerzero behaviours')
+ return parser
+
+
+def _add_dist_train_args(parser):
+ group = parser.add_argument_group(title='dist_train')
+ group.add_argument('--dist-train', action='store_true', help='Enable dist-train feature.')
+ return parser
+
+
+def core_transformer_config_from_args_wrapper(fn):
+ @wraps(fn)
+ def wrapper(args):
+ config = fn(args)
+ config.context_parallel_algo = args.context_parallel_algo
+ config.batch_p2p_comm = False
+ if args.use_multiparameter_pipeline_model_parallel:
+ config.deallocate_pipeline_outputs = False
+ return config
+
+ return wrapper
+
+
+def validate_args_wrapper(validate_args):
+ @wraps(validate_args)
+ def wrapper(args, defaults=None):
+ if args.dist_train:
+ if not hasattr(args, 'mm_model'):
+ raise ValueError('DistTrain must work with MindSpeed-MM')
+ from mindspeed.multi_modal.dist_train.config.dist_train_config import validate_configs_world_size, \
+ get_dist_model_config, merge_dist_train_args
+ merge_dist_train_args(args.mm_model)
+ validate_configs_world_size(args)
+ cfg = get_dist_model_config(rank=args.rank)
+ args.world_size = cfg.world_size
+ args.tensor_model_parallel_size = cfg.tensor_model_parallel_size
+ args.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size
+ args.context_parallel_size = cfg.context_parallel_size
+ seq_parallel_enabled = args.sequence_parallel
+
+ if defaults is None:
+ defaults = {}
+ replace_model_type_for_deepspeed_moe = False
+ if args.num_experts:
+ if args.use_ascend_coc:
+ raise AssertionError('coc is not compatible with moe models')
+ if args.use_ascend_mc2:
+ raise AssertionError('mc2 is not compatible with moe models')
+ if args.use_legacy_models:
+ if args.moe_model_type == 'megatron_moe':
+ raise AssertionError('megatron_moe is not compatible with --use-legacy-models')
+ replace_model_type_for_deepspeed_moe = True
+ else:
+ if args.moe_model_type == 'deepspeed_moe':
+ raise AssertionError('deepspeed_moe only support with --use-legacy-models')
+ overlap_param_gather_without_mcore_models = False
+ if args.overlap_param_gather and args.use_legacy_models:
+ args.use_legacy_models = False
+ overlap_param_gather_without_mcore_models = True
+
+ #validate optimizer
+ if args.optimizer_selection == 'fused_adamw':
+ print("[WARNING] The default AdamW optimizer is no longer recommended for new edition, Use the torch fused AdamW optimizer by argument --optimizer-selection fused_torch_adamw")
+ elif args.optimizer_selection == 'fused_ema_adamw':
+ if args.reuse_fp32_param:
+ raise AssertionError('fused_ema_adamw optimizer is not compatible with reuse_fp32_param')
+
+ # validate mla
+ if args.multi_head_latent_attention:
+ if args.kv_lora_rank is None:
+ raise AssertionError('The parameter kv-lora-rank should be set when use multi_head_latent_attention.')
+ elif args.v_head_dim is None:
+ raise AssertionError('The parameter v-head-dim should be set when use multi_head_latent_attention.')
+ elif args.qk_rope_head_dim is None:
+ raise AssertionError(
+ 'The parameter qk-rope-head-dim should be set when use multi_head_latent_attention.')
+ elif args.qk_nope_head_dim is None:
+ raise AssertionError(
+ 'The parameter qk-nope-head-dim should be set when use multi_head_latent_attention.')
+
+ # validate yarn
+ if args.rope_scaling_type == "yarn":
+ if args.rope_scaling_original_max_position_embeddings is None:
+ raise AssertionError('The parameter rope_scaling_original_max_position_embeddings should be set '
+ 'when use yarn.')
+
+ # alibi type [2, 3] is only support FA2
+ if args.alibi_fusion_attn_type in [2, 3]:
+ args.use_fusion_attn_v2 = True
+ if args.use_fusion_attn_v2:
+ args.use_flash_attn = True
+ print("[WARNING] \"use_fusion_attn_v2\" is not recommended. This feature is not officially released.")
+
+ # for vpp assert pp should > 2
+ flag_num_layers_per_virtual_pipeline_stage = None
+ flag_overlap_p2p_comm = False
+ if args.num_layers_per_virtual_pipeline_stage is not None and args.pipeline_model_parallel_size == 2:
+ flag_num_layers_per_virtual_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+ args.num_layers_per_virtual_pipeline_stage = None
+ if args.overlap_p2p_comm:
+ flag_overlap_p2p_comm = True
+
+ # skip validation for deepspeed_moe with CP
+ origin_use_legacy_models = args.use_legacy_models
+ if replace_model_type_for_deepspeed_moe:
+ args.use_legacy_models = False
+ origin_context_parallel_size = args.context_parallel_size
+ args.context_parallel_size = 1
+ original_variable_seq_lengths = args.variable_seq_lengths
+ args = validate_args(args, defaults)
+
+ args.variable_seq_lengths = original_variable_seq_lengths
+ args.context_parallel_size = origin_context_parallel_size
+ model_parallel_size = args.pipeline_model_parallel_size * args.tensor_model_parallel_size
+ args.data_parallel_size = args.world_size // (model_parallel_size * args.context_parallel_size)
+
+ if args.dist_train:
+ from mindspeed.multi_modal.dist_train.config.dist_train_config import get_all_config
+ if args.tensor_model_parallel_size > 1 and seq_parallel_enabled:
+ args.sequence_parallel = True
+ if any(cfg.main_dp for cfg in get_all_config().values()):
+ from mindspeed.multi_modal.dist_train.inner_data_parallel.utils import get_global_data_parallel_size
+ args.data_parallel_size = get_global_data_parallel_size()
+
+ if args.global_batch_size is None:
+ args.global_batch_size = args.micro_batch_size * args.data_parallel_size
+ if args.rank == 0:
+ print('Resetting global batch size to {}'.format(
+ args.global_batch_size), flush=True)
+ if args.optimize_vpp_send_recv_comm and args.num_layers_per_virtual_pipeline_stage is None:
+ raise AssertionError('--optimize-vpp-send-recv-comm can only be used with pipeline with interleaving.')
+
+ if replace_model_type_for_deepspeed_moe:
+ args.use_legacy_models = origin_use_legacy_models
+ if args.enable_zero3:
+ print("[WARNING] zero3 currently does not support model save and load")
+ if args.use_ascend_mc2 or args.reuse_fp32_param or args.recompute_granularity is not None or args.use_pipe_experts:
+ raise AssertionError('zero3 cannot be used together with MC2(--use-ascend-mc2), '
+ 'parameter copy reuse(--reuse-fp32-param),'
+ 'recompute(--recompute-granularity)'
+ 'and pipe_experts(use-pipe-experts)')
+
+ # for vpp assert pp should > 2
+ if flag_num_layers_per_virtual_pipeline_stage is not None and args.pipeline_model_parallel_size == 2:
+ args.num_layers_per_virtual_pipeline_stage = flag_num_layers_per_virtual_pipeline_stage
+ args.overlap_p2p_comm = flag_overlap_p2p_comm
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
+ 'number of layers should be divisible by the pipeline parallel size'
+ num_layers_per_pipeline_stage = args.num_layers // args.transformer_pipeline_model_parallel_size
+ assert num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage == 0, \
+ 'number of layers per pipeline stage must be divisible number of layers per virtual pipeline stage'
+ args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \
+ args.num_layers_per_virtual_pipeline_stage
+
+ # num_layers_per_virtual_pipeline_stage should be meaningful
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ assert num_layers_per_pipeline_stage // args.num_layers_per_virtual_pipeline_stage > 1, \
+ 'considering args of num_layers and pipeline_model_parallel_size, vpp setting should be meaningful'
+
+ # deepspeed dropless does not support pp
+ if args.moe_no_drop and args.pipeline_model_parallel_size > 1:
+ raise AssertionError("--moe-no-drop is not compatible with pp")
+
+ if args.param_and_grad_buffer_pad and args.param_and_grad_buffer_pad <= 0:
+ raise AssertionError('--param-and-grad-buffer-pad must be greater than 0')
+
+ if args.use_fused_rmsnorm:
+ if args.normalization != "RMSNorm":
+ raise AssertionError(
+ '--use-fused-rmsnorm must enable with '
+ '--normalization=RMSNorm, but got normalization'
+ '={}.'.format(args.normalization))
+ if args.use_nd_matmul:
+ raise AssertionError("ND_MatMul is not compatible with fused_rmsnorm.")
+ if args.use_fused_swiglu:
+ if not args.swiglu:
+ raise AssertionError(
+ '--use-fused-swiglu must enable with --swiglu, '
+ 'but --swiglu={}.'.format(args.swiglu))
+ if args.use_fused_rotary_pos_emb:
+ if args.position_embedding_type != 'rope':
+ raise AssertionError(
+ '--use-fused-rotary-pos-emb must enable with'
+ '--position-embedding-type=rope')
+ if args.alibi_fusion_attn_type is not None and args.alibi_fusion_attn_type not in [0, 2, 3]:
+ raise AssertionError('--alibi-fusion-attn-type only support for `0, 2, 3`')
+ if args.reuse_fp32_param and not args.bf16:
+ raise AssertionError('--reuse-fp32-param only support for `bf16`')
+ if args.use_pipe_experts:
+ if args.pipe_experts_multi_data <= 0:
+ raise AssertionError('--pipe-experts-multi-data must greater than 0')
+ if not args.sequence_parallel and args.pipe_experts_multi_stream:
+ raise AssertionError('--pipe-experts-multi-stream can only be used with --sequence-parallel.')
+ local_experts = args.num_experts // args.expert_model_parallel_size
+ if local_experts == 1 and args.pipe_experts_multi_data == 1:
+ print("[WARNING] if local_experts = num_experts // expert_model_parallel_size is equal to 1 "
+ "and --pipe-experts-multi-data is set to 1, "
+ "--use-pipe-experts will be turned off.")
+ args.use_pipe_experts = False
+ if args.moe_alltoall_overlap_comm and not args.moe_token_dispatcher_type == 'alltoall':
+ raise AssertionError('`--moe-alltoall-overlap-comm` only support with `--moe-token-dispatcher-type alltoall`.')
+
+ if args.moe_adaptive_recompute_activation and args.moe_token_dispatcher_type == 'alltoall':
+ raise AssertionError('`--moe-adaptive-recompute-activation` only support with `--moe-token-dispatcher-type allgather`.')
+
+ if args.moe_allgather_overlap_comm and not args.moe_token_dispatcher_type == 'allgather':
+ raise AssertionError('`--moe-allgather-overlap-comm` only support with `--moe-token-dispatcher-type allgather`.')
+
+ if args.moe_alltoall_overlap_comm or args.moe_allgather_overlap_comm:
+ if not args.moe_permutation_async_comm:
+ raise AssertionError('`--moe-alltoall-overlap-comm` and `--moe-allgather-overlap-comm` only support with `--moe-permutation-async-comm`.')
+ if not args.moe_grouped_gemm:
+ raise AssertionError('`--moe-alltoall-overlap-comm` and `--moe-allgather-overlap-comm` only support with `--moe-grouped-gemm`.')
+ if not (args.moe_tp_extend_ep or args.moe_experts_pipeline_degree) and args.moe_alltoall_overlap_comm and args.tensor_model_parallel_size > 1:
+ raise AssertionError('`--moe-alltoall-overlap-comm` do not support tp for now. only support with moe_tp_extend_ep or moe_experts_pipeline_degree when tp > 1.')
+ if args.moe_experts_pipeline_degree:
+ if args.moe_experts_pipeline_degree < 2:
+ raise AssertionError("`--moe-experts-pipeline-degree` should be at least 2. ")
+ if args.moe_experts_pipeline_degree > args.num_experts or args.num_experts % args.moe_experts_pipeline_degree != 0:
+ raise AssertionError("`--moe-experts-pipeline-degree` must smaller than `--num-experts` and `--num-experts` divided by `--moe-experts-pipeline-degree` is an integer.")
+ if args.moe_zero_memory != "disable":
+ raise AssertionError("`--moe-experts-pipeline-degree` is not compatible with `--moe-zero-memory`")
+ if not args.tensor_model_parallel_size or args.tensor_model_parallel_size <= 1:
+ raise AssertionError("`--moe-experts-pipeline-degree` only support when '--tensor-model-parallel-size' is bigger than 1.")
+ if args.expert_model_parallel_size > 1:
+ raise AssertionError("`--moe-experts-pipeline-degree` is not compatible with expert model parallel.")
+ if args.moe_tp_extend_ep:
+ raise AssertionError("`--moe-experts-pipeline-degree` is not compatible with `--moe-tp-extend-ep`.")
+ if args.moe_tp_extend_ep:
+ if args.num_experts % (args.tensor_model_parallel_size * args.expert_model_parallel_size) != 0:
+ raise AssertionError('`--moe-tp-extend-ep` only support when num_experts % ( tp * ep ) == 0')
+ if not (args.moe_permutation_async_comm and args.moe_grouped_gemm):
+ raise AssertionError('`--moe-tp-extend-ep` needs `--moe-permutation-async-comm` and `--moe-grouped-gemm`.')
+ if args.moe_expert_capacity_factor is not None:
+ raise AssertionError('`--moe-tp-extend-ep` only support when moe_expert_capacity_factor is None.')
+ if args.moe_hierarchical_alltoallv:
+ tp = args.tensor_model_parallel_size
+ ep = args.expert_model_parallel_size
+ if ((not args.moe_alltoall_overlap_comm) or (not args.moe_tp_extend_ep) or tp <= 1 or tp > torch.npu.device_count() or
+ ep * tp <= torch.npu.device_count() or args.world_size <= torch.npu.device_count()):
+ raise AssertionError(
+ '`--moe-hierarchical-alltoallv` must have `--moe-alltoall-overlap-comm` on and '
+ '`--moe-tp-extend-ep` on and 1 < tp <= torch.npu.device_count() and cross-device communication')
+ if args.moe_zero_memory_num_layers is not None:
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+ if args.moe_zero_memory_num_layers < 0 or args.moe_zero_memory_num_layers > num_layers_per_pipeline_stage:
+ raise AssertionError('`--moe-zero-memory-num-layers` must be between 0 and num layers per pipeline stage')
+ if args.moe_zero_memory == "disable":
+ raise AssertionError('`--moe-zero-memory` must be enabled when using `--moe-zero-memory-num-layers`')
+ if args.moe_zero_memory != "disable" and args.moe_allgather_overlap_comm:
+ raise AssertionError('`--moe-zero-memory` do not support `--moe-allgather-overlap-comm` for now.')
+ if args.moe_dynamic_padding and not args.moe_no_drop:
+ raise AssertionError('`--moe-dynamic-padding` only support for `--moe-no-drop`.')
+ if args.moe_permutation_async_comm and args.moe_model_type != 'megatron_moe':
+ raise AssertionError('`--moe-permutation-async-comm` only support for megatron core moe.')
+ if args.moe_bmm_mc2:
+ if args.moe_model_type != 'megatron_moe' or not args.moe_token_dispatcher_type == 'alltoall':
+ raise AssertionError('`--moe-bmm-mc2` only support for megatron core moe and dispatcher is alltoall.')
+ if not args.moe_grouped_gemm:
+ raise AssertionError('`--moe-bmm-mc2` only support when `--moe-grouped-gemm` is true.')
+ if args.moe_tp_extend_ep or args.moe_alltoall_overlap_comm:
+ raise AssertionError(
+ '`--moe-bmm-mc2` not support with `--moe-tp-extend-ep` and `--moe-alltoall-overlap-comm`.')
+
+ if args.context_parallel_size > 1 and args.position_embedding_type == 'alibi':
+ assert args.context_parallel_algo == 'megatron_cp_algo', f"alibi only support megatron_cp_algo"
+ if args.context_parallel_size > 1 and args.context_parallel_algo == 'ulysses_cp_algo':
+ assert args.seq_length % args.context_parallel_size == 0, f"sequence length must be divisible by context_parallel_size"
+ head, remainder = divmod(args.num_attention_heads, args.context_parallel_size * args.tensor_model_parallel_size)
+ assert head >= 1 and remainder == 0, f"num_attention_heads must be divisible by context_parallel_size * tensor_model_parallel_size"
+ args.use_flash_attn = True
+ if args.context_parallel_size > 1 and args.context_parallel_algo == 'megatron_cp_algo':
+ assert args.seq_length % (2 * args.context_parallel_size) == 0, f"sequence length must be divisible by 2 * context_parallel_size"
+ if args.position_embedding_type == 'alibi':
+ assert args.alibi_fusion_attn_type in [2, 3] and args.attention_mask_type == 'causal', f"megatron_cp_algo only support alibi type in [2, 3] and attention_mask_type is causal"
+
+ assert args.cp_window_size >= 1 and args.cp_window_size < args.context_parallel_size, f'cp_window_size should in range [1, context_parallel_size) when using double_ring_attention.'
+ n_window, remainder = divmod(args.context_parallel_size, args.cp_window_size)
+ assert n_window >= 1 and remainder == 0, f'context parallel size must be divisible by cp_window_size when using double ring attention.'
+ args.use_flash_attn = True
+ if args.context_parallel_size > 1 and args.context_parallel_algo == 'hybrid_cp_algo':
+ assert args.ulysses_degree_in_cp is not None, "--ulysses-degree-in-cp must be specified in hybrid_cp_algo"
+ ring_degree, remainder = divmod(args.context_parallel_size, args.ulysses_degree_in_cp)
+ assert ring_degree > 1 and remainder == 0, "--ulysses-degree-in-cp must be devisible by --context-parallel-size"
+ args.ring_degree = ring_degree
+
+ head, remainder = divmod(args.num_attention_heads, args.ulysses_degree_in_cp * args.tensor_model_parallel_size)
+ assert head >= 1 and remainder == 0, f"num_attention_heads must be divisible by ulysse-degree-in-cp * tensor_model_parallel_size in hybrid cp"
+
+ assert args.seq_length % (2 * args.context_parallel_size) == 0, f"sequence length must be divisible by 2 * context_parallel_size in hybrid cp"
+
+ assert args.cp_window_size >= 1 and args.cp_window_size < ring_degree, f'cp_window_size should be in range [1, ring_degree) when using double ring attention with hybrid context parallelism.'
+ n_window, remainder = divmod(ring_degree, args.cp_window_size)
+ assert n_window >= 1 and remainder == 0, f'ring_degree should be divisible by cp_window_size when using double ring with hybrid context parallelism.'
+ args.use_flash_attn = True
+
+ if args.context_parallel_size > 1 and args.context_parallel_algo == 'adaptive_cp_algo':
+ assert args.seq_length % args.context_parallel_size == 0, f"sequence length must be divisible by context_parallel_size"
+ args.use_flash_attn = True
+ if args.context_parallel_size > 1 and args.context_parallel_algo == 'hybrid_adaptive_cp_algo':
+ assert args.ulysses_degree_in_cp is not None, "--ulysses-degree-in-cp must be specified in hybrid_adaptive_cp_algo"
+ ring_degree, remainder = divmod(args.context_parallel_size, args.ulysses_degree_in_cp)
+ assert ring_degree > 1 and remainder == 0, "--ulysses-degree-in-cp must be devisible by --context-parallel-size"
+ head, remainder = divmod(args.num_attention_heads, args.ulysses_degree_in_cp * args.tensor_model_parallel_size)
+ assert head >= 1 and remainder == 0, f"num_attention_heads must be divisible by ulysse-degree-in-cp * tensor_model_parallel_size in hybrid cp"
+ assert args.seq_length % args.context_parallel_size == 0, f"sequence length must be divisible by context_parallel_size in hybrid cp"
+ args.use_flash_attn = True
+
+ # Mandatory modification to SBH, subsequent abandonment of other formats such as BSH,BSND
+ if args.shape_order != 'SBH':
+ args.shape_order = 'SBH'
+ if overlap_param_gather_without_mcore_models:
+ args.use_legacy_models = True
+ if args.transformer_impl == 'transformer_engine':
+ args.transformer_impl = 'local'
+ if args.fp8:
+ raise AssertionError('NPU not supported FP8.')
+ if args.tp_comm_overlap:
+ args.tp_comm_overlap = False
+ if args.recompute_method == "uniform":
+ assert not args.recompute_activation_function, \
+ 'uniform recomputation is not compatible ' \
+ 'with activation function recomputation '
+ assert not args.recompute_norm, \
+ 'uniform recomputation is not compatible ' \
+ 'with norm recomputation '
+ if args.recompute_activation_function and args.recompute_granularity == "selective":
+ raise AssertionError('--recompute-activation-function is not compatible with selective recomputation')
+ adaptive_recompute_enable = args.adaptive_recompute_device_size > 0 or args.adaptive_recompute_device_swap
+ if args.recompute_norm and args.recompute_granularity == "selective":
+ raise AssertionError('--recompute-norm is not compatible with selective recomputation')
+ if args.recompute_norm and args.use_legacy_models:
+ raise AssertionError('--recompute-norm is only supported with mcore models')
+ if args.use_nanopipe and not args.use_legacy_models:
+ raise AssertionError('--use-nanopipe is not available with mcore models')
+ if args.adaptive_recompute_device_swap and not args.use_legacy_models:
+ raise AssertionError('--adaptive-recompute-device-swap is not available with mcore models')
+ if adaptive_recompute_enable:
+ assert args.recompute_granularity is None and args.recompute_method is None, \
+ 'adaptive selective recompute is not compatible with ' \
+ 'recompute_granularity and recompute_method. '
+ assert not args.recompute_activation_function, \
+ 'adaptive selective recompute is not compatible ' \
+ 'with activation function recomputation '
+ assert not args.swap_attention, 'adaptive selective recompute is not compatible with swap_attention feature'
+ assert not args.recompute_in_advance and not args.recompute_in_bubble, 'adaptive selective recompute ' \
+ 'is not compatible with ripipe schedule'
+ assert not args.memory_fragmentation, \
+ 'adaptive selective recompute is not compatible with memory fragmentation'
+ if args.memory_fragmentation:
+ assert not args.use_fused_rotary_pos_emb, \
+ 'memory fragmentation is not compatible with use_fused_rotary_pos_emb'
+ if args.smart_swap:
+ assert not adaptive_recompute_enable, 'smart swap is not compatible with adaptive selective recompute'
+ assert not args.memory_fragmentation, 'smart swap is not compatible with memory fragmentation'
+ if args.adaptive_memory_optimization:
+ assert args.ampipe_degree <= 1, 'adaptive memory optimization is not compatible with ampipe'
+ assert not adaptive_recompute_enable, 'adaptive memory optimization is not compatible with adaptive recomputing'
+ assert args.recompute_granularity is None and args.recompute_method is None, \
+ 'adaptive memory optimization is not compatible with recompute_granularity or recompute_method'
+ assert not args.recompute_activation_function, \
+ 'adaptive memory optimization is not compatible with recompute_activation_function'
+ assert not args.swap_attention, 'adaptive memory optimization is not compatible with swap_attention feature'
+ assert not args.recompute_in_bubble, 'adaptive memory optimization is not compatible with recompute_in_bubble'
+ assert not args.memory_fragmentation, \
+ 'adaptive memory optimization is not compatible with memory_fragmentation'
+ if args.use_flash_attn:
+ assert args.sparse_mode == 0 or args.sparse_mode == 2, f"Only supports sparse modes 0 and 2"
+ args.create_attention_mask_in_dataloader = False
+ if args.automated_pipeline:
+ if args.recompute_activation_function:
+ print("[WARNING] disable activation function recomputation when enabling automated pipeline")
+ args.recompute_activation_function = False
+ if args.recompute_granularity is not None or args.recompute_method is not None:
+ print("[WARNING] disable recompute granularity and recompute method when enabling automated pipeline")
+ args.recompute_granularity = None
+ args.recompute_method = None
+ if args.noop_layers:
+ print("[WARNING] disable noop_layers when enabling automated pipeline")
+ args.noop_layers = None
+ if args.automated_pipeline_perf:
+ if args.automated_pipeline:
+ print("[WARNING] disable automated pipeline when enabling automated pipeline performance version")
+ args.automated_pipeline = False
+ if args.num_layers_per_virtual_pipeline_stage is not None:
+ raise AssertionError('automated pipeline performance is temporarily incompatible with virtual pipeline')
+ if args.use_ascend_mc2:
+ if args.use_ascend_coc:
+ raise AssertionError('--mc2 and coc can not be used together')
+ if args.use_nd_matmul:
+ if args.normalization == 'LayerNorm':
+ raise AssertionError('ND_MatMul is temporarily incompatible with LayerNorm')
+ if args.load is not None or args.pretrained_checkpoint is not None:
+ raise AssertionError('ND_MatMul does not support loading weights for training temporarily')
+ if args.tensor_model_parallel_size % args.nd1_dim1_size != 0:
+ raise AssertionError('tensor_model_parallel_size must be divisible by nd1_dim1_size')
+ if args.tensor_model_parallel_size % args.nd2_dim1_size != 0:
+ raise AssertionError('tensor_model_parallel_size must be divisible by nd2_dim1_size')
+
+ args.reduce_recompute_for_last_chunk = False
+ if args.recompute_in_advance:
+ args.reduce_recompute_for_last_chunk = True
+ if args.recompute_method == "uniform":
+ raise AssertionError('recompute_in_advance does not support uniform recompute_method')
+ if not args.recompute_num_layers and not args.adaptive_memory_optimization:
+ raise AssertionError('recompute_num_layers can not be None or 0 when using recompute_in_advance')
+ if args.pipeline_model_parallel_size <= 1 or args.num_layers_per_virtual_pipeline_stage is None:
+ raise AssertionError('recompute_in_advance only support pipelining with interleaving')
+ if args.num_layers_per_virtual_pipeline_stage != 1:
+ args.recompute_in_advance = False
+ if args.recompute_in_bubble:
+ if args.recompute_num_layers:
+ raise AssertionError('recompute_num_layers must be None or 0 when using recompute_in_bubble')
+ if args.pipeline_model_parallel_size <= 1 or args.num_layers_per_virtual_pipeline_stage is None:
+ raise AssertionError('recompute_in_bubble only support pipelining with interleaving')
+ if not args.swap_attention:
+ # Following is a trick to realize bubble recomputation. We first enable all recomputation,
+ # and then disable recomputation for all layers except the ones chosen for bubble recomputation.
+ args.recompute_granularity = "full"
+ args.recompute_method = "block"
+ if args.enable_recompute_layers_per_pp_rank:
+ args.recompute_num_layers = args.num_layers // args.pipeline_model_parallel_size
+ else:
+ args.recompute_num_layers = args.num_layers_per_virtual_pipeline_stage
+ if isinstance(args.noop_layers, str):
+ noop_layers = set()
+ for x in args.noop_layers.split(','):
+ if int(x) >= args.num_layers or int(x) < 0:
+ raise AssertionError(f'each element in args.noop_layers({args.noop_layers}) should bigger or equal '
+ f'to 0 and smaller than args.num_layers({args.num_layers})')
+ noop_layers.add(int(x))
+ args.noop_layers = noop_layers
+
+ if args.ampipe_degree > 1:
+ assert args.use_flash_attn, "ampipe only supports flash attention, please enable '--use-flash-attn'."
+ assert args.num_experts is not None, "ampipe only supports MoE model."
+ assert args.expert_model_parallel_size > 1, "ampipe only supports expert_model_parallel_size > 1"
+ assert args.moe_model_type == 'deepspeed_moe', "ampipe only supports deepspeed_moe."
+ assert not args.use_ascend_mc2, "ampipe does't supports ascend mc2 for now."
+ assert not args.add_bias_linear, "ampipe does't supports bias linear for now."
+ assert not args.overlap_grad_reduce, "ampipe does't supports overlap_grad_reduce for now."
+ assert not args.overlap_param_gather, "ampipe does't supports overlap_param_gather for now."
+ assert not args.use_nanopipe, "ampipe does't supports use_nanopipe for now."
+ assert not args.recompute_in_bubble, "ampipe does't supports ripipe recompute_in_bubble for now."
+ assert not args.recompute_in_advance, "ampipe does't supports ripipe recompute_in_advance for now."
+ assert not args.adaptive_recompute_device_swap, "ampipe does't supports ripipe recompute_in_advance for now."
+ if args.sequence_parallel:
+ assert args.seq_length % (args.ampipe_degree * args.tensor_model_parallel_size) == 0, \
+ "sequence length must be divisible by ampipe_degree * tensor_model_parallel_size"
+ if args.context_parallel_size > 1:
+ assert args.context_parallel_algo == 'megatron_cp_algo', "ampipe only supports megatron_cp_algo"
+ assert args.ampipe_degree == 2, "ampipe only supports ampipe_degree=2 when context_parallel_size>1"
+ slice_size, remainder = divmod(args.seq_length, 2 * args.ampipe_degree * args.context_parallel_size)
+ assert remainder == 0, \
+ "sequence length must be divisible by 2 * ampipe_degree * context_parallel_size"
+ if args.sequence_parallel:
+ assert slice_size % (args.tensor_model_parallel_size) == 0, \
+ "sequence length must be divisible by 2 * ampipe_degree * context_parallel_size * tensor_model_parallel_size"
+ if args.use_pipe_experts:
+ if args.pipe_experts_multi_data % args.ampipe_degree != 0:
+ print("[WARNING] if pipe_experts_multi_data isn't divisible by ampipe_degree "
+ "--use-pipe-experts will be turned off.")
+ args.use_pipe_experts = False
+ args.pipe_experts_multi_stream = False
+ args.pipe_experts_multi_data = 1
+ if args.tp_2d:
+ if args.sequence_parallel:
+ raise AssertionError('2d tp does not support sequence parallel')
+ if args.use_fused_rmsnorm:
+ raise AssertionError('2d tp does not support fused rmsnorm')
+ if args.use_nanopipe:
+ raise AssertionError('tp-2d does not support nano-pipe')
+ if args.ampipe_degree > 1:
+ raise AssertionError('tp-2d does not support ampipe')
+ if args.context_parallel_algo not in ['megatron_cp_algo', 'ulysses_cp_algo']:
+ raise AssertionError('tp-2d now only support megatron_cp_algo or ulysses_cp_algo')
+ if args.use_ascend_coc:
+ raise AssertionError('tp-2d does not support ascend coc')
+ if args.tensor_model_parallel_size // args.tp_x != args.tp_y:
+ raise AssertionError('need satisfy tp = tp_x * tp_y')
+ if args.expert_model_parallel_size > 1:
+ if args.moe_token_dispatcher_type != "allgather":
+ raise AssertionError('2d tp only support allgather megatron-moe now')
+
+ if args.expert_interval <= 0 or args.expert_interval > args.num_layers:
+ raise AssertionError("--expert-interval must be between 1 and num layers")
+ if args.moe_train_capacity_factor <= 0.0:
+ raise AssertionError("--moe-train-capacity-factor must be greater than 0.0")
+
+ if args.gemm_gradient_accumulation_fusion:
+ if not args.moe_grouped_gemm:
+ raise AssertionError('`--gemm-gradient-accumulation-fusion` only support with `--moe-grouped-gemm`.')
+
+ if args.use_legacy_models:
+ if args.overlap_param_gather and args.reuse_fp32_param:
+ raise AssertionError('In legacy, `overlap_param_gather` does not support `reuse_fp32_param`.')
+
+ if args.fp16:
+ args.gradient_accumulation_fusion = False
+ warnings.warn("Unsupported gradient fp16 bf16 for gradient accumulation fusion")
+
+ if args.reset_attention_mask and args.attention_mask_type == 'causal':
+ assert args.context_parallel_algo == 'megatron_cp_algo', 'accelerated eod reset mode only support ring attention'
+
+ if args.context_parallel_kv_cache_policy:
+ if args.context_parallel_size == 1:
+ raise AssertionError(
+ 'context parallel size must larger than 1 when --context-parallel-kv-cache-policy is set.')
+ if not args.use_flash_attn:
+ raise AssertionError(
+ '--context-parallel-kv-cache-policy only support use flash attention.'
+ )
+
+ if args.context_parallel_cache_interval != 0:
+ if not args.context_parallel_kv_cache_policy:
+ raise AssertionError(
+ '--context-parallel-cache-interval only can be used when --context-parallel-kv-cache-policy is set.'
+ )
+ if args.context_parallel_cache_interval >= args.num_layers:
+ raise AssertionError(
+ '--context-parallel-cache-interval should be smaller than the number of layers.'
+ )
+ if args.context_parallel_cache_interval < 0:
+ raise AssertionError(
+ '--context-parallel-cache-interval cannot be negative number.'
+ )
+
+ if args.use_ulysses_allgather_kv:
+ if args.context_parallel_size == 1:
+ raise AssertionError(
+ 'context parallel size must larger than 1 when --use-ulysses-allgather-kv is set.')
+ if args.context_parallel_algo != 'ulysses_cp_algo':
+ raise AssertionError(
+ '--context_parallel-algo should be ulysses_cp_algo when using --use-ulysses-allgather-kv.'
+ )
+ if not args.group_query_attention:
+ raise AssertionError(
+ '--use-ulysses-allgather-kv needs to enable --group-query-attention.'
+ )
+
+ from megatron.training.arguments import _print_args
+ _print_args('arguments', args, True)
+
+ for feature in FEATURES_LIST:
+ if feature.optimization_level <= args.optimization_level and \
+ (getattr(args, feature.feature_name, None) or feature.default_patches):
+ feature.pre_validate_args(args)
+ feature.validate_args(args)
+ feature.post_validate_args(args)
+
+ return args
+
+ return wrapper
+
+
+def add_parser_argument_choices_value(parser, argument_name, value):
+ if parser._actions:
+ for action in parser._actions:
+ if isinstance(action, argparse._ArgumentGroup):
+ add_parser_argument_choices_value(action, argument_name)
+ elif isinstance(action, argparse.Action) and argument_name in action.option_strings:
+ action.choices.append(value)
+
+
+def _add_alibi_args(parser):
+ add_parser_argument_choices_value(parser, "--position-embedding-type", 'alibi')
+
+ group = parser.add_argument_group(title='alibi')
+ group.add_argument('--square-alibi-mask',
+ action='store_true',
+ default=False,
+ help='attention mask of alibi is squared')
+ group.add_argument('--fill-neg-inf',
+ action='store_true',
+ default=False,
+ help='fill alibi with negative inf')
+
+ group.add_argument('--alibi-fusion-attn-type',
+ type=int,
+ help='alibi pse type, support for 0,2,3')
+
+ group.add_argument('--alibi-diagonal-opposite',
+ action='store_true',
+ default=False,
+ help='make alibi diagonal opposite')
+
+ return parser
+
+
+def _add_ndmm_args(parser):
+ group = parser.add_argument_group(title='ndmm')
+ group.add_argument('--use-nd-matmul', action='store_true', default=False,
+ help='use use-nd-matmul to replace megatron-style tensor parallel')
+ group.add_argument('--nd1-dim1-size', type=int, default=1,
+ help='Dim1 of the first nd matmul when use-3d-matmul is True')
+ group.add_argument('--nd2-dim1-size', type=int, default=1,
+ help='Dim1 of the second nd matmul when use-3d-matmul is True')
+ return parser
+
+
+def _add_auto_parallel_args(parser):
+ group = parser.add_argument_group(title='auto_parallel')
+ group.add_argument('--auto-parallel', action='store_true',
+ help='enable automatic parallelism with auto-parallel')
+ group.add_argument('--nnodes', type=int, default=1, help='the number of node in the cluster')
+ group.add_argument('--nproc-per-node', type=int, default=8, help='the number of NPU on each node')
+ group.add_argument('--master-addr', type=str, default=None, help='the ip-address of master node')
+ group.add_argument('--master-port', type=str, default=None, help='the ip-port of master node')
+ group.add_argument('--node-rank', type=int, default=0,
+ help='the rank of nodes in the cluster, starting from 0 and increment by 1')
+ group.add_argument('--profile-operator', action='store_true', help='')
+ group.add_argument('--profile-memory', action='store_true', help='')
+ group.add_argument('--prof-file', type=str, default=None, help='')
+ return parser
+
+
+def _add_auto_parallel_mm_args(parser):
+ group = parser.add_argument_group(title='auto_parallel_mm')
+ group.add_argument('--auto-parallel-mm', action='store_true', default=False,
+ help='enable multimode automated parallel policy search')
+ group.add_argument('--auto-parallel-profile', action='store_true', default=False,
+ help='multimode performance sampling')
+
+ return parser
+
+
+def _add_2d_tp_args(parser):
+ group = parser.add_argument_group(title='2d-tp')
+ group.add_argument('--tp-2d', action='store_true', default=False,
+ help='use use-2d-tp to replace megatron-style tensor parallel')
+ group.add_argument('--tp-x', type=int, default=1,
+ help='the fist dim tensor parallel size for Linear')
+ group.add_argument('--tp-y', type=int, default=1,
+ help='the second dim tensor parallel size for Linear')
+ group.add_argument('--enable-overlap-ag-with-matmul', action='store_true', default=False,
+ help='use enable-overlap-ag-with-matmul to overlap all-gather with matmul')
+ group.add_argument('--enable-overlap-matmul-with-rs', action='store_true', default=False,
+ help='use enable-overlap-matmul-with-rs to overlap matmul with reduce-scatter')
+ group.add_argument('--enable-backward-overlap-ag-with-matmul', action='store_true', default=False,
+ help='use enable-backward-overlap-ag-with-matmul to overlap all-gather with matmul in backward')
+ return parser
+
+
+def _add_hccl_group_buffer_args(parser):
+ group = parser.add_argument_group(title='hccl-group-buffer')
+ group.add_argument('--hccl-group-buffer', type=str, default=None,
+ help='the hccl buffer for group')
+ group.add_argument('--hccl-group-buffer-adaptive', action='store_true', default=False,
+ help='the hccl buffer for group adaptively')
+ group.add_argument('--hccl-ep-group-buffer-adaptive-factor', type=float, default=-1.0,
+ help='the ep group buffer factor')
+ return parser
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/__init__.py b/model/train/yoco_moe/mindspeed/auto_tuning/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/auto_tuning.py b/model/train/yoco_moe/mindspeed/auto_tuning/auto_tuning.py
new file mode 100644
index 0000000000000000000000000000000000000000..e27e129e75eb8323f885eb22ede4e2bd302e4fd7
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/auto_tuning.py
@@ -0,0 +1,152 @@
+import json
+import logging
+import os
+import stat
+import time
+import pickle
+from argparse import Namespace
+
+from mindspeed.auto_tuning.utils.logger import init_logger, get_logger
+from mindspeed.auto_tuning.module.hardware import Hardware
+from mindspeed.auto_tuning.module.memory.memory_modeling import MemoryModeling
+from mindspeed.auto_tuning.module.model_performance import ModelPerformance
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_node_parse import GatherNodeProfiling
+from mindspeed.auto_tuning.module.search.search_engine import search_demo
+from mindspeed.auto_tuning.utils.runner.model_executor import ExecutorFlag, ModelExecutor
+from mindspeed.auto_tuning.utils.runner.torchrun_runner import TorchRunRunner
+from mindspeed.auto_tuning.config.model_config import ModelConfig
+from mindspeed.auto_tuning.config.generate_profiling_configs import generate_profiling_configs
+from mindspeed.auto_tuning.utils.utils import get_prof_dir
+from mindspeed.auto_tuning.utils.restricted_unpickler import restricted_loads
+
+
+def auto_tuning(args: Namespace, working_dir: str):
+ init_logger(args.auto_tuning_log_level)
+ logger = get_logger("main")
+ start_time = time.time()
+ executor = ModelExecutor(TorchRunRunner())
+
+ # Force refresh model args just in case model has been modified after previous run.
+ logger.info("<==========Begin to parse args==========>")
+ executor.execute(working_dir, flag=ExecutorFlag.PARSE_ARGS)
+ hardware_parse_path = os.path.join(working_dir, Hardware.HARDWARE_PARSE_FILENAME)
+ args_parse_path = os.path.join(working_dir, ModelConfig.ARGS_PARSE_FILENAME)
+ try:
+ with open(hardware_parse_path, mode="rb") as file:
+ hardware: Hardware = restricted_loads(file) # type: ignore
+ with open(args_parse_path, mode="rb") as file:
+ model_config: ModelConfig = restricted_loads(file) # type: ignore
+ except pickle.UnpicklingError as e:
+ logger.error(f"Incorrect pickle format. UnpicklingError: {e}")
+ raise e
+ Hardware().load(hardware)
+ model_config.disable_cp_flag = False
+ logger.info("<==========Finished parsing args==========>")
+
+ # Memory modeling
+ MemoryModeling.set_model_cfg(model_config)
+ static_list, dynamic_list = MemoryModeling.generate_mem_modeling_profiling_list()
+ logger.info("<==========Begin to profile static memory==========>")
+ for cfg, filename in static_list:
+ if not os.path.exists(os.path.join(working_dir, filename)):
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ pkl_filename = os.path.join(working_dir, f'ootb_{Hardware().node_rank}.pkl')
+ with os.fdopen(os.open(pkl_filename, flags, mode=mode), 'wb') as f:
+ pickle.dump(cfg, f)
+ executor.execute(working_dir, output_filename=filename, cfg=cfg, flag=ExecutorFlag.PARSE_MODEL)
+ logger.info("<==========Finished profiling static memory==========>")
+ logger.info("<==========Begin to profile dynamic memory==========>")
+ for cfg in dynamic_list:
+ path = os.path.join(working_dir, get_prof_dir(cfg))
+ if not os.path.exists(path):
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ pkl_filename = os.path.join(working_dir, f'ootb_{Hardware().node_rank}.pkl')
+ with os.fdopen(os.open(pkl_filename, flags, mode=mode), 'wb') as f:
+ pickle.dump(cfg, f)
+ executor.execute(working_dir, output_filename=path, cfg=cfg, flag=ExecutorFlag.PROFILE)
+ logger.info("<==========Finished profiling dynamic memory==========>")
+ MemoryModeling.modeling(working_dir)
+ model_parser_end_time = time.time()
+ logger.info("Model parser cost time: %sms", str((model_parser_end_time - start_time) * 1000))
+
+ hardware_config = Hardware()
+ profiling_cfg_list = generate_profiling_configs(model_config)
+
+ logger.info("profile_cfgs (tp, pp, dp, cp, ep, #layers, seq_len):")
+ logger.info(",".join(
+ str((cfg.tp,
+ cfg.pp,
+ cfg.dp,
+ cfg.cp,
+ cfg.ep,
+ cfg.num_layers,
+ cfg.seq_length))
+ for cfg in profiling_cfg_list))
+
+ generate_profiling_config_end_time = time.time()
+
+ profiling_results = []
+ logger.info("<==========Begin profiling==========>")
+ logger.info("This process will run the script and get some profiling results.")
+ logger.info("Please wait for a while.")
+ count = 1
+ for profiling_cfg in profiling_cfg_list:
+ # tracking the order of profiling all over the list
+ logger.info('<==========the %s/%s loop==========>', str(count), str(len(profiling_cfg_list)))
+ logger.info("profile_db_configs (tp, pp, dp, cp, ep, #layers, seq_len):")
+ logger.info(str([profiling_cfg.tp,
+ profiling_cfg.pp,
+ profiling_cfg.dp,
+ profiling_cfg.cp,
+ profiling_cfg.ep,
+ profiling_cfg.num_layers,
+ profiling_cfg.seq_length]))
+ res_dir = f"{working_dir}/{get_prof_dir(profiling_cfg)}"
+ if not os.path.exists(res_dir):
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ pkl_filename = os.path.join(working_dir, f'ootb_{Hardware().node_rank}.pkl')
+ with os.fdopen(os.open(pkl_filename, flags, mode=mode), 'wb') as f:
+ pickle.dump(profiling_cfg, f)
+ executor.execute(working_dir, output_filename=res_dir, cfg=profiling_cfg, flag=ExecutorFlag.PROFILE)
+
+ profiling_node_parse = GatherNodeProfiling(res_dir)
+ profiling_res = profiling_node_parse.fuse_node_pkl()
+
+ profiling_results.append([profiling_cfg, profiling_res])
+ count += 1
+
+ profiling_and_parser_end_time = time.time()
+
+ # Performance Modeling
+ model_performance = ModelPerformance(hardware_config, model_config, working_dir)
+ model_performance.get_profiling_info(profiling_results)
+
+ final_cfgs, unsampled_profiling = search_demo(model_config=model_config,
+ perf_obj_function=model_performance.performance,
+ working_dir=working_dir)
+ logger.info("model config is that:\n%s", str(model_config))
+ logger.info("hardware config is that:\n%s", str(hardware_config))
+
+ search_cfg_end_time = time.time()
+ logger.info(">>>>>> Generate profiling config cost time: %sms",
+ str((generate_profiling_config_end_time - model_parser_end_time) * 1000))
+ logger.info(">>>>>> Profiling and parser cost time: %sms",
+ str((profiling_and_parser_end_time - generate_profiling_config_end_time) * 1000))
+ logger.info(">>>>>> Search_cfg cost time: %sms",
+ str((search_cfg_end_time - profiling_and_parser_end_time) * 1000))
+ logger.info(">>>>>> Total cost time: %sms",
+ str((search_cfg_end_time - start_time) * 1000))
+
+ logger.info("<==========Final config generated==========>")
+ logger.info("The recommended configs are:")
+ for i, final_cfg in enumerate(final_cfgs):
+ if final_cfg:
+ logger.info("<==========Top #%s config==========>", str(i))
+ if logger.getEffectiveLevel() == logging.DEBUG:
+ logger.debug("\n%s", str(final_cfg))
+ else:
+ logger.info("\n%s", ModelConfig.__str__(final_cfg))
+ logger.info("<==========Launch training==========>")
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/config/__init__.py b/model/train/yoco_moe/mindspeed/auto_tuning/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/config/generate_profiling_configs.py b/model/train/yoco_moe/mindspeed/auto_tuning/config/generate_profiling_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..76dd8286c647dff4f53b6beb156cfdd084ee2e02
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/config/generate_profiling_configs.py
@@ -0,0 +1,103 @@
+from typing import List
+from dataclasses import replace
+
+from mindspeed.auto_tuning.module.hardware import Hardware
+from mindspeed.auto_tuning.config.model_config import ModelConfig
+from mindspeed.auto_tuning.config.search_config import SearchConfig
+from mindspeed.auto_tuning.utils.utils import get_tp_for_profiling, get_seq_length_for_profiling
+
+
+def generate_profiling_configs(model_cfg: ModelConfig) -> List[SearchConfig]:
+ profile_cfgs: List[SearchConfig] = list()
+
+ base_cfg = SearchConfig()
+ base_cfg.copy_from_config(model_cfg)
+ base_cfg.tensor_model_parallel_size = get_tp_for_profiling()
+ base_cfg.context_parallel_size = 1
+ base_cfg.pipeline_model_parallel_size = 1
+ base_cfg.seq_length = get_seq_length_for_profiling(model_cfg)
+ if model_cfg.is_moe():
+ base_cfg.num_experts = 4
+ base_cfg.expert_model_parallel_size = 4
+ bi_tp = base_cfg.tp * 2
+ if "910B" in Hardware().device_type and base_cfg.tp == 8:
+ bi_tp = 4
+
+ if "910_9" in Hardware().device_type and base_cfg.tp == 8:
+ bi_tp = 16
+
+ # base config
+ # 4dp
+ profile_cfgs.append(base_cfg)
+
+ # 4dp mc2
+ gen_cfg_mc2 = replace(base_cfg, use_ascend_mc2=True)
+ profile_cfgs.append(gen_cfg_mc2)
+
+ # 2dp 2tp
+ gen_cfg = replace(base_cfg)
+ gen_cfg.tensor_model_parallel_size = bi_tp
+ if model_cfg.is_moe():
+ gen_cfg.expert_model_parallel_size = 2
+ profile_cfgs.append(gen_cfg)
+
+ # 2dp 2tp mc2
+ gen_cfg_mc2 = replace(gen_cfg, use_ascend_mc2=True)
+ profile_cfgs.append(gen_cfg_mc2)
+
+ # 2dp 2pp
+ gen_cfg = replace(base_cfg)
+ gen_cfg.pipeline_model_parallel_size = 2
+ if model_cfg.is_moe():
+ gen_cfg.expert_model_parallel_size = 2
+ profile_cfgs.append(gen_cfg)
+
+ # CP config
+ if not model_cfg.disable_cp_flag:
+ # 4cp
+ gen_cfg = replace(base_cfg)
+ gen_cfg.context_parallel_size = 4
+ if gen_cfg.seq_length // gen_cfg.cp >= 2 * 1024:
+ profile_cfgs.append(gen_cfg)
+
+ # 2cp
+ gen_cfg = replace(base_cfg)
+ gen_cfg.context_parallel_size = 2
+ if model_cfg.is_moe():
+ gen_cfg.expert_model_parallel_size = 2
+ if gen_cfg.seq_length // gen_cfg.cp >= 2 * 1024:
+ profile_cfgs.append(gen_cfg)
+
+ # roce cp
+ gen_cfg = replace(base_cfg)
+ gen_cfg.context_parallel_size = 2
+ gen_cfg.tensor_model_parallel_size = bi_tp
+ if model_cfg.is_moe():
+ gen_cfg.expert_model_parallel_size = 2
+ if gen_cfg.seq_length // gen_cfg.cp >= 2 * 1024:
+ profile_cfgs.append(gen_cfg)
+
+ # MLP config
+ if model_cfg.is_moe():
+ gen_cfg = replace(base_cfg)
+ gen_cfg.expert_model_parallel_size = 1
+ gen_cfg.pipeline_model_parallel_size = 1
+ profile_cfgs.append(gen_cfg)
+
+ gen_cfg_pp2 = replace(gen_cfg)
+ gen_cfg_pp2.pipeline_model_parallel_size = 2
+ profile_cfgs.append(gen_cfg_pp2)
+
+ # half-seq
+ gen_cfg = replace(base_cfg)
+ if model_cfg.is_moe():
+ gen_cfg.expert_model_parallel_size = 1
+ gen_cfg.seq_length = base_cfg.seq_length // 2
+ if gen_cfg.seq_length < 2 * 1024:
+ gen_cfg.seq_length = gen_cfg.seq_length * 4
+
+ for cfg in profile_cfgs:
+ cfg.prepare_for_profiling()
+ cfg.num_layers = cfg.pp
+
+ return profile_cfgs
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/config/model_config.py b/model/train/yoco_moe/mindspeed/auto_tuning/config/model_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..609629e48a919a861f11b453b629e90613500e19
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/config/model_config.py
@@ -0,0 +1,156 @@
+from typing import List, Optional
+from dataclasses import dataclass
+
+from mindspeed.auto_tuning.utils.dtype import DTYPE
+
+
+@dataclass
+class ModelConfig:
+ ARGS_PARSE_FILENAME = "auto_tuning_model_args.json"
+ # Set all parameter defaults to None, so that errors will occur when calculations are performed with
+ # unresolved parameters, reflect issues in time.
+ # Parallel configs
+ tensor_model_parallel_size: int = None # type: ignore
+ context_parallel_size: int = None # type: ignore
+ pipeline_model_parallel_size: int = None # type: ignore
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
+ data_parallel_size: int = None # type: ignore
+ sequence_parallel: bool = None # type: ignore
+ use_distributed_optimizer: bool = None # type: ignore
+ global_batch_size: int = None # type: ignore
+ micro_batch_size: int = None # type: ignore
+
+ # Model configs
+ num_layers: int = None # type: ignore
+ num_attention_heads: int = None # type: ignore
+ hidden_size: int = None # type: ignore
+ ffn_hidden_size: int = None # type: ignore
+ add_bias_linear: bool = None # type: ignore
+ swiglu: bool = None # type: ignore
+ fp16: bool = None # type: ignore
+ bf16: bool = None # type: ignore
+ use_ascend_mc2: bool = None # type: ignore
+
+ # Data configs
+ seq_length: int = None # type: ignore
+
+ # MoE configs
+ num_experts: Optional[int] = None
+ moe_router_topk: Optional[int] = None
+ moe_train_capacity_factor: Optional[float] = None
+ expert_model_parallel_size: Optional[int] = None
+ enable_token_rearrange_opt: bool = None # type: ignore
+
+ # Memory configs
+ recompute_granularity: Optional[str] = None
+ recompute_method: Optional[str] = None
+ recompute_num_layers: Optional[int] = None
+ use_flash_attn: bool = None # type: ignore
+ adaptive_recompute_device_swap: bool = None # type: ignore
+
+ # Train configs
+ train_iters: int = None # type: ignore
+ profile: bool = None # type: ignore
+ profile_step_start: int = None # type: ignore
+ profile_step_end: int = None # type: ignore
+ profile_ranks: List[int] = None # type: ignore
+ profile_level: str = None # type: ignore
+ profile_with_cpu: bool = None # type: ignore
+ profile_with_stack: bool = None # type: ignore
+ profile_with_memory: bool = None # type: ignore
+ profile_record_shapes: bool = None # type: ignore
+
+ # World Size
+ global_world_size: int = None # type: ignore
+
+ # JIT
+ jit_compile: bool = None # type: ignore
+
+ # Flags
+ disable_cp_flag: bool = False
+
+ def __str__(self) -> str:
+ rt = list()
+ rt.append(f"{'Data Parallel Size':<30}{str(self.dp):<40}")
+ rt.append(f"{'Tensor Parallel Size':<30}{str(self.tp):<40}")
+ rt.append(f"{'Pipeline Parallel Size':<30}{str(self.pp):<40}")
+ rt.append(f"{'Virtual Pipeline Size':<30}{str(self.vpp):<40}")
+ rt.append(f"{'Context Parallel Size':<30}{str(self.cp):<40}")
+ rt.append(f"{'Expert Parallel Size':<30}{str(self.ep):<40}")
+ rt.append(f"{'ZeRO1':<30}{str(self.zero1):<40}")
+ rt.append(f"{'MC2':<30}{str(self.use_ascend_mc2):<40}")
+ rt.append(f"{'Token Rearrange':<30}{str(self.enable_token_rearrange_opt):<40}")
+ rt.append(f"{'Micro Batch Size':<30}{str(self.mbs):<40}")
+ rt.append(f"{'Recompute layer':<30}{str(self.re_layer):<40}")
+ return "\n".join(rt)
+
+ @property
+ def tp(self) -> int:
+ return self.tensor_model_parallel_size
+
+ @property
+ def cp(self) -> int:
+ return self.context_parallel_size
+
+ @property
+ def pp(self) -> int:
+ return self.pipeline_model_parallel_size
+
+ @property
+ def layers_per_vpp(self) -> Optional[int]:
+ return self.num_layers_per_virtual_pipeline_stage
+
+ @property
+ def vpp(self) -> Optional[int]:
+ if self.num_layers_per_virtual_pipeline_stage:
+ return self.num_layers // (self.pp * self.num_layers_per_virtual_pipeline_stage)
+ return None
+
+ @property
+ def dp(self) -> int:
+ return self.data_parallel_size
+
+ @property
+ def ep(self) -> Optional[int]:
+ return self.expert_model_parallel_size or 1
+
+ @property
+ def zero1(self) -> bool:
+ return self.use_distributed_optimizer
+
+ @property
+ def gbs(self) -> int:
+ return self.global_batch_size
+
+ @property
+ def mbs(self) -> int:
+ return self.micro_batch_size
+
+ @property
+ def adaptive_recompute(self) -> bool:
+ return self.adaptive_recompute_device_swap
+
+ @property
+ def re_layer(self) -> Optional[int]:
+ return self.recompute_num_layers
+
+ @property
+ def num_micro_batches(self) -> int:
+ return self.global_batch_size // self.micro_batch_size
+
+ @property
+ def dtype(self) -> DTYPE:
+ if self.fp16:
+ return DTYPE.fp16
+ elif self.bf16:
+ return DTYPE.bf16
+ return DTYPE.fp32
+
+ def is_full_recompute(self) -> bool:
+ return self.recompute_granularity is not None and \
+ self.recompute_granularity == "full" and \
+ self.recompute_method is not None and \
+ self.recompute_method == "block"
+
+ def is_moe(self) -> bool:
+ return self.num_experts is not None
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/config/search_config.py b/model/train/yoco_moe/mindspeed/auto_tuning/config/search_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..405d92db9e981f56c74ff8447fe7b814ccb65dba
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/config/search_config.py
@@ -0,0 +1,60 @@
+from typing import Optional
+from dataclasses import dataclass
+
+from mindspeed.auto_tuning.config.model_config import ModelConfig
+from mindspeed.auto_tuning.module.hardware import Hardware
+
+
+@dataclass
+class SearchConfig(ModelConfig):
+ memory: Optional[float] = None
+ performance: Optional[float] = None
+
+ def __str__(self) -> str:
+ rt = list()
+ if self.performance:
+ rt.append(f"{'Performance':<30}{str(self.performance):<40}")
+ if self.memory:
+ rt.append(f"{'Memory':<30}{str(self.memory):<40}")
+ return super().__str__() + "\n" + "\n".join(rt)
+
+ def copy_from_config(self, cfg: ModelConfig) -> None:
+ for k, v in vars(cfg).items():
+ if k in self.__dict__:
+ self.__dict__[k] = v
+
+ def prepare_for_profiling(self) -> None:
+ self.use_distributed_optimizer = True
+ self.recompute_granularity = "full"
+ self.recompute_method = "block"
+ self.adaptive_recompute_device_swap = False
+ self.global_world_size = Hardware().num_devices
+ self.micro_batch_size = 1
+
+ self.normalize()
+ self.global_batch_size = self.dp * self.pp * self.mbs
+
+ self.train_iters = 10
+ self.profile = True
+ self.profile_step_start = 8
+ self.profile_step_end = 9
+ self.profile_ranks = list(range(Hardware().num_devices))
+ self.profile_level = "level1"
+ self.profile_with_cpu = True
+ self.profile_with_stack = False
+ self.profile_with_memory = True
+ self.profile_record_shapes = True
+
+ def normalize(self) -> None:
+ self.data_parallel_size = self.global_world_size // \
+ (self.tp * self.cp * self.pp)
+
+ if self.is_moe():
+ self.enable_token_rearrange_opt = True
+
+ if self.adaptive_recompute_device_swap:
+ self.recompute_granularity = None
+ self.recompute_method = None
+ self.recompute_num_layers = None
+ elif self.is_full_recompute():
+ self.recompute_num_layers = self.num_layers // self.pp
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/mindspeed_adaptor.py b/model/train/yoco_moe/mindspeed/auto_tuning/mindspeed_adaptor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eaa89d64582576c545490984f47dc85ebbcb00a
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/mindspeed_adaptor.py
@@ -0,0 +1,167 @@
+from typing import Dict, List, Optional, Tuple
+import os
+import stat
+from argparse import Namespace
+
+import pickle
+from torch.nn import Module
+import torch.distributed as dist
+
+from mindspeed.auto_tuning.utils.logger import get_logger
+from mindspeed.auto_tuning.utils.restricted_unpickler import restricted_loads
+from mindspeed.auto_tuning.module.hardware import Hardware
+from mindspeed.auto_tuning.module.memory.model_param import ModelParam
+from mindspeed.auto_tuning.config.model_config import ModelConfig
+
+
+_logger = get_logger("MindSpeedAdaptor")
+
+
+class MindSpeedAdaptor:
+
+ def __new__(cls):
+ raise NotImplementedError("MindSpeedAdaptor is a static class.")
+
+ @staticmethod
+ def get_hardware(working_dir: str = str()) -> Hardware:
+ import acl
+ from .utils.mem_utils import mem_b_to_mb
+
+ device_type = acl.get_soc_name()
+
+ devices_per_node, _ = acl.rt.get_device_count()
+
+ num_nodes = dist.get_world_size() // devices_per_node
+ device_rank = dist.get_rank()
+ node_rank = device_rank // devices_per_node
+ device_id = device_rank % devices_per_node
+ acl.rt.set_device(device_id)
+ _, memory_limit, _ = acl.rt.get_mem_info(1)
+ acl.rt.reset_device(device_id)
+
+ host_ip = os.environ.get("MASTER_ADDR", None)
+
+ if device_rank == 0:
+ import getpass
+ user_name = getpass.getuser()
+
+ object_list = [user_name]
+ else:
+ object_list = [None]
+
+ dist.broadcast_object_list(object_list)
+ user_name: str = object_list[0] # type: ignore
+
+ hardware = Hardware()
+ hardware.device_type = device_type
+ hardware.host_ip = host_ip
+ hardware.user_name = user_name
+ hardware.memory_limit = mem_b_to_mb(memory_limit) - 2 * 1024
+ hardware.devices_per_node = devices_per_node
+ hardware.num_nodes = num_nodes
+ hardware.node_rank = node_rank
+
+ if working_dir and device_id == 0:
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ hardware_filename = os.path.join(working_dir, Hardware.HARDWARE_PARSE_FILENAME)
+ with os.fdopen(os.open(hardware_filename, flags, mode=mode), 'wb') as f:
+ pickle.dump(hardware, f)
+
+ return hardware
+
+ @staticmethod
+ def get_model_args(args: Namespace, hardware: Hardware, working_dir: str) -> ModelConfig:
+ model_config = ModelConfig()
+ for arg_name, arg_value in vars(args).items():
+ if arg_name in model_config.__dict__:
+ model_config.__dict__[arg_name] = arg_value
+ model_config.global_world_size = args.auto_tuning_ranks
+
+ if dist.get_rank() % hardware.devices_per_node == 0:
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ model_config_filename = os.path.join(working_dir, ModelConfig.ARGS_PARSE_FILENAME)
+ with os.fdopen(os.open(model_config_filename, flags, mode=mode), 'wb') as f:
+ pickle.dump(model_config, f)
+
+ return model_config
+
+ @staticmethod
+ def get_model_params(model: List[Module],
+ pipeline_model_parallel_rank: int,
+ hardware: Hardware,
+ output_path: str
+ ) -> List[ModelParam]:
+ model_params: List[ModelParam] = list()
+
+ def traverse_module_layers(module: Module, prefix: str):
+ new_prefix = f"{prefix}{module.__class__.__name__}."
+
+ if all(False for _ in module.children()):
+ for param_name, param in module.named_parameters():
+ model_params.append(ModelParam(f"{new_prefix}{param_name}", param.numel()))
+ return
+
+ for sub_module in module.children():
+ traverse_module_layers(sub_module, new_prefix)
+
+ for module in model:
+ traverse_module_layers(module, str())
+
+ total_model_params = [None] * dist.get_world_size()
+ dist.all_gather_object(total_model_params, (pipeline_model_parallel_rank, model_params))
+ if dist.get_rank() % hardware.devices_per_node == 0:
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ with os.fdopen(os.open(output_path, flags, mode=mode), 'wb') as f:
+ pickle.dump(total_model_params, f)
+
+ return model_params
+
+ @staticmethod
+ def set_argv(argv: List[str], input_path: str) -> List[str]:
+ with open(input_path, mode="rb") as file:
+ try:
+ modified_argv: Tuple[Dict[str, Optional[str]], Dict[str, Optional[str]]] = \
+ restricted_loads(file) # type: ignore
+ except pickle.UnpicklingError as e:
+ _logger.warning(f"Incorrect pickle format. UnpicklingError: {e}")
+ raise e
+
+ enabled_argv, disabled_argv = modified_argv
+
+ for arg_name, arg_value in enabled_argv.items():
+ # Flag args
+ if arg_name == "--profile-ranks" and arg_value:
+ argv.extend([arg_name, *[s.strip() for s in arg_value.strip("[]").split(",")]])
+ continue
+ if arg_value is None:
+ try:
+ argv.index(arg_name)
+ except ValueError:
+ argv.append(arg_name)
+ # Non-flag args
+ else:
+ try:
+ argv[argv.index(arg_name) + 1] = arg_value
+ except ValueError:
+ argv.extend([arg_name, arg_value])
+
+ for arg_name, arg_value in disabled_argv.items():
+ # Flag args
+ if arg_value is None:
+ try:
+ argv.pop(argv.index(arg_name))
+ except ValueError:
+ continue
+ # Non-flag args
+ else:
+ try:
+ i = argv.index(arg_name)
+ argv.pop(i)
+ argv.pop(i)
+ except ValueError:
+ continue
+
+ return argv
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/__init__.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/__init__.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0f01850ef74d1d1847242c38885322bc1f95b10
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication.py
@@ -0,0 +1,107 @@
+from mindspeed.auto_tuning.module.communication import communication_profile
+from mindspeed.auto_tuning.module.communication.communication_model_tp import TpModel
+from mindspeed.auto_tuning.module.communication.communication_model_cp import CpModel
+from mindspeed.auto_tuning.module.communication.communication_model_dp import DpModel
+from mindspeed.auto_tuning.module.communication.communication_model_pp import PpModel
+from mindspeed.auto_tuning.module.communication.communication_model_ep import EpModel
+from mindspeed.auto_tuning.module.communication.communication_model_mc2 import Mc2Model
+
+
+class Communication(object):
+ """Communication modeling."""
+
+ def __init__(self, hardware=None, model_cfg=None):
+ self.hardware = hardware
+ self.model_cfg = model_cfg
+
+ self.hccs_dev_num_910_9 = 384
+ self.hccs_dev_num_910b = 8
+ self.hccs_dev_num = 0
+ if "910_9" in self.hardware.device_type:
+ self.hccs_dev_num = self.hccs_dev_num_910_9
+ if "910B" in self.hardware.device_type:
+ self.hccs_dev_num = self.hccs_dev_num_910b
+
+ self.tp_model = TpModel(self.hccs_dev_num)
+ self.cp_model = CpModel(self.hccs_dev_num)
+ self.dp_model = DpModel(self.hccs_dev_num)
+ self.pp_model = PpModel(self.hccs_dev_num)
+ self.ep_model = EpModel(self.hccs_dev_num)
+ self.mc2_model = Mc2Model(self.hccs_dev_num)
+
+ self.config_list = []
+
+ def communication_modeling(self, profiling_results):
+ self.adapt_to_profile_info(profiling_results)
+ self.info_to_modeling()
+
+ def adapt_to_profile_info(self, profiling_results):
+ for index, (config, model) in enumerate(profiling_results):
+ # Reads profile information in a group of configuration files.
+ total_profile_time_info = communication_profile.TotalProfileTimeInfo()
+
+ self.config_list.append(config)
+
+ self.get_profile_info(model, total_profile_time_info, config, profiling_results, index)
+ # Now force to run only one floor
+
+ if config.use_ascend_mc2:
+ self.mc2_model.get_comm_info_list(
+ total_profile_time_info.mc2_profile_time_info, config)
+ else:
+ self.tp_model.get_comm_info_list(
+ total_profile_time_info.tp_profile_time_info, config)
+ self.dp_model.get_comm_info_list(
+ total_profile_time_info.dp_profile_time_info, config)
+ self.cp_model.get_comm_info_list(
+ total_profile_time_info.cp_profile_time_info, config)
+ self.ep_model.get_comm_info_list(
+ total_profile_time_info.ep_profile_time_info, config)
+ self.pp_model.get_comm_info_list(
+ total_profile_time_info.pp_profile_time_info, config)
+
+ def info_to_modeling(self):
+ self.tp_model.modeling()
+ self.tp_model.print_modeling(self.config_list)
+ self.mc2_model.modeling()
+ self.mc2_model.print_modeling(self.config_list)
+ self.dp_model.modeling()
+ self.dp_model.print_modeling(self.config_list)
+ self.cp_model.modeling()
+ self.cp_model.print_modeling(self.config_list)
+ self.ep_model.modeling()
+ self.ep_model.print_modeling(self.config_list)
+ self.pp_model.modeling()
+ self.pp_model.print_modeling(self.config_list)
+
+ def get_profile_info(self, model, total_profile_time_info, config, profiling_results, index):
+ tensor_hcom_info = model.tensor_parallel_comm
+ data_hcom_info = model.data_parallel_comm
+ pipeline_hcom_info = model.pipeline_parallel_comm
+ context_hcom_info = model.context_parallel_comm
+ expert_hcom_info = model.expert_parallel_comm
+ if config.use_ascend_mc2:
+ self.mc2_model.get_communication_info_from_profile(total_profile_time_info.mc2_profile_time_info,
+ profiling_results,
+ index)
+ for stage_id, stage_id_tensor_hcom_info in enumerate(tensor_hcom_info):
+ # ["tp_x"] regression
+ if stage_id == 0 and len(tensor_hcom_info) > stage_id:
+ self.tp_model.get_communication_info_from_profile(
+ total_profile_time_info.tp_profile_time_info, tensor_hcom_info[stage_id])
+ # para_list.cp_x regression
+ if stage_id == 0 and len(context_hcom_info) > stage_id:
+ self.cp_model.get_communication_info_from_profile(
+ total_profile_time_info.cp_profile_time_info, context_hcom_info[stage_id], model, config.cp)
+ if config.pp > 1:
+ if stage_id == 0 and len(pipeline_hcom_info) > stage_id:
+ self.pp_model.get_communication_info_from_profile(
+ total_profile_time_info.pp_profile_time_info, pipeline_hcom_info[stage_id], config.pp)
+ # para_list.dp_x regression
+ if stage_id == len(tensor_hcom_info) - 1 and len(data_hcom_info) > stage_id:
+ self.dp_model.get_communication_info_from_profile(
+ total_profile_time_info.dp_profile_time_info, data_hcom_info[stage_id])
+ # para_list.ep_x regression
+ if stage_id == 0 and len(expert_hcom_info) > stage_id:
+ self.ep_model.get_communication_info_from_profile(
+ total_profile_time_info.ep_profile_time_info, expert_hcom_info[stage_id])
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e652f072874dbd3d37afaf7c9865f9f51ed2b4da
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model.py
@@ -0,0 +1,196 @@
+import abc
+from mindspeed.auto_tuning.module.operator.operator_shape_cal import linear_regression
+from mindspeed.auto_tuning.utils.logger import get_logger
+
+
+class CommunicationList():
+ def __init__(self):
+ self.roce_x_list = []
+ self.roce_time_list = []
+ self.hccs_x_list = []
+ self.hccs_time_list = []
+ self.cross_x_list = []
+ self.cross_y_list = []
+ self.cross_time_list = []
+
+ self.roce_w = 0
+ self.roce_b = 0
+ self.hccs_w = 0
+ self.hccs_b = 0
+ self.cross_list = (0, 0)
+
+ def append_roce(self, iv_list, time):
+ self.roce_x_list.append([iv_list[0]])
+ self.roce_time_list.append([time])
+ self.hccs_x_list.append([None])
+ self.hccs_time_list.append([None])
+ self.cross_x_list.append([None])
+ self.cross_y_list.append([None])
+ self.cross_time_list.append([None])
+
+ def append_hccs(self, iv_list, time):
+ self.roce_x_list.append([None])
+ self.roce_time_list.append([None])
+ self.hccs_x_list.append([iv_list[0]])
+ self.hccs_time_list.append([time])
+ self.cross_x_list.append([None])
+ self.cross_y_list.append([None])
+ self.cross_time_list.append([None])
+
+ def append_cross(self, iv_list, time):
+ self.roce_x_list.append([None])
+ self.roce_time_list.append([None])
+ self.hccs_x_list.append([None])
+ self.hccs_time_list.append([None])
+ self.cross_x_list.append([iv_list[1]])
+ self.cross_y_list.append([iv_list[2]])
+ self.cross_time_list.append([time])
+
+ def cal_roce(self, iv_list):
+ return self.roce_w * iv_list[0] + self.roce_b
+
+ def cal_hccs(self, iv_list):
+ return self.hccs_w * iv_list[0] + self.hccs_b
+
+ def cal_cross(self, iv_list):
+ return self.hccs_w * iv_list[1] + self.hccs_b + self.roce_w * iv_list[2] + self.roce_b
+
+ def modeling(self):
+ lists = (
+ self.hccs_x_list,
+ self.hccs_time_list,
+ self.roce_x_list,
+ self.roce_time_list,
+ self.cross_x_list,
+ self.cross_y_list,
+ self.cross_time_list
+ )
+ (hccs_x_cal, hccs_time_cal), (roce_x_cal, roce_time_cal), (cross_x_cal,
+ cross_time_cal) = self.get_hccs_roce_list(lists)
+ if roce_x_cal:
+ self.roce_w, self.roce_b = self.linear_x_y(roce_x_cal, roce_time_cal)
+ if hccs_x_cal:
+ self.hccs_w, self.hccs_b = self.linear_x_y(hccs_x_cal, hccs_time_cal)
+
+ def get_hccs_roce_list(self, lists):
+ hccs_x_list = []
+ hccs_y_list = []
+ roce_x_list = []
+ roce_y_list = []
+ cross_x_list = []
+ cross_y_list = []
+ for i, x_index in enumerate(lists[0]):
+ if lists[0][i] != [None]:
+ hccs_x_list.append(lists[0][i])
+ hccs_y_list.append(lists[1][i])
+ elif lists[2][i] != [None]:
+ roce_x_list.append(lists[2][i])
+ roce_y_list.append(lists[3][i])
+ else:
+ cross_x_list.append([lists[4][i][0] / lists[5][i][0]])
+ cross_y_list.append([lists[6][i][0] / lists[5][i][0]])
+ hccs_lists = (hccs_x_list, hccs_y_list)
+ roce_lists = (roce_x_list, roce_y_list)
+ cross_lists = (cross_x_list, cross_y_list)
+ re_hccs_lists = self.add_origin_whith_single_point(hccs_lists)
+ re_roce_lists = self.add_origin_whith_single_point(roce_lists)
+ re_cross_lists = self.add_origin_whith_single_point(cross_lists)
+
+ return re_hccs_lists, re_roce_lists, re_cross_lists
+
+ @classmethod
+ def add_origin_whith_single_point(cls, lists):
+ last = None
+ for item in lists[0]:
+ if last:
+ if item != last:
+ last = None
+ break
+ else:
+ last = item
+ listres = lists
+ if last:
+ listres = [[], []]
+ listres[0].append(lists[0][0])
+ listres[1].append(lists[1][0])
+ if len(listres[0]) == 1:
+ listres[0].append([0])
+ listres[1].append([0])
+ return listres
+
+ @classmethod
+ def linear_x_y(cls, list1, list2):
+ w, b = 0, 0
+ if len(list1) > 0:
+ w, b = linear_regression(list1, list2) if list1 else (0, 0)
+ return w, b
+
+
+class CommunicationModel:
+ def __init__(self, hccs_dev_num):
+ self.comm = CommunicationList()
+ self.main_domain = Domain(hccs_dev_num)
+ self.hccs_dev_num = hccs_dev_num
+ self.logger = get_logger("Communication")
+
+ @abc.abstractmethod
+ def get_communication_info_from_profile(self, hcom_info_tage_id):
+ pass
+
+ @abc.abstractmethod
+ def get_comm_info_list(self, profile_info):
+ pass
+
+ @abc.abstractmethod
+ def modeling(self):
+ pass
+
+ @abc.abstractmethod
+ def print_modeling(self):
+ pass
+
+
+class Domain:
+ def __init__(self, hccs_dev_num):
+ self.max_domain = 0
+ self.min_domain = 0
+ self.roce_comm_exist = False
+ self.hccs_comm_exist = False
+ self.cross_comm_exist = False
+ self.hccs_dev_num = hccs_dev_num
+
+ def is_hccs_domain(self):
+ return self.max_domain <= self.hccs_dev_num
+
+ def is_cross_domain(self):
+ return self.min_domain < self.hccs_dev_num < self.max_domain
+
+ def is_roce_domain(self):
+ return not (self.is_hccs_domain() or self.is_hccs_domain())
+
+ def append_method_for_domain(self):
+ if self.is_hccs_domain():
+ self.hccs_comm_exist = True
+ return "append_hccs"
+ if self.is_cross_domain():
+ self.cross_comm_exist = True
+ return "append_cross"
+ self.roce_comm_exist = True
+ return "append_roce"
+
+ def append_time_in_domain(self, communication_list, iv_list, time):
+ method_for_domain = self.append_method_for_domain()
+ append_domain = getattr(communication_list, method_for_domain)
+ append_domain(iv_list, time)
+
+ def cal_method_for_domain(self):
+ if self.is_hccs_domain():
+ return "cal_hccs"
+ if self.is_cross_domain():
+ return "cal_cross"
+ return "cal_roce"
+
+ def cal_time_in_domain(self, communication_list, iv_list):
+ method_for_domain = self.cal_method_for_domain()
+ cal_domain = getattr(communication_list, method_for_domain)
+ return cal_domain(iv_list)
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_cp.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_cp.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dbe0cb8d5b3fb336d34404e702c7276dc1d3ff2
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_cp.py
@@ -0,0 +1,228 @@
+from mindspeed.auto_tuning.module.communication.communication_model import CommunicationModel
+_GLOBAL_ATTN_FORWARD_KERNEL_NAMES = [
+ "aclnnFlashAttentionScore_FlashAttentionScore_FlashAttentionScore"
+]
+_GLOBAL_ATTN_BACKWARD_KERNEL_NAMES = [
+ "aclnnFlashAttentionScoreGrad_FlashAttentionScoreGrad_FlashAttentionScoreGrad"
+]
+
+
+class CpModel(CommunicationModel):
+ def __init__(self, hccs_dev_num):
+ super(CpModel, self).__init__(hccs_dev_num)
+ # Profile Modeling Data Information Table
+ self.cp_vector_x = []
+ self.cp_vector_time = []
+ self.cp_attn_x = []
+ self.cp_attn_time = []
+ self.cp_attn_bw_x = []
+ self.cp_attn_bw_time = []
+
+ self.cp_attn_w = 0
+ self.cp_attn_b = 0
+ self.cp_attn_bw_w = 0
+ self.cp_attn_bw_b = 0
+ self.cp_vector_w = 0
+ self.cp_vector_b = 0
+
+ def get_communication_info_from_profile(self, cp_profile_time_info, hcom_info_tage_id, model, cp):
+ cp_profile_time_info.total_comm_time += hcom_info_tage_id.total_time_ms
+ cp_profile_time_info.wait_comm_time += hcom_info_tage_id.wait_time_ms
+ cp_profile_time_info.attn_cp_time, cp_profile_time_info.attn_cpbw_time = \
+ self.get_vectortime_from_profiling(model, cp)
+ cp_profile_time_info.vector_cp_time += hcom_info_tage_id.vector_time_ms
+
+ def get_comm_info_list(self, cp_profile_time_info, config):
+ tp = config.tp
+ cp = config.cp
+ pp = config.pp
+ dp = config.dp
+ s = config.seq_length / 1000
+
+ # CP's communication volume is CP-1 times the forward KV, backward KV, and dKV per machine.
+ if cp > 1:
+ # Here we consider only the attention of communication hiding, with forward CP-1 and backward CP.
+ self.cp_attn_x.append([s / tp / cp * (cp - 1) / cp])
+ self.cp_attn_time.append([cp_profile_time_info.attn_cp_time])
+ self.cp_attn_bw_x.append([s / tp / cp])
+ self.cp_attn_bw_time.append([cp_profile_time_info.attn_cpbw_time])
+ self.cp_vector_time.append([cp_profile_time_info.vector_cp_time])
+ if cp - 2 < 0:
+ self.cp_vector_x.append([0])
+ else:
+ self.cp_vector_x.append([cp - 2])
+
+ comm_x = (cp - 1) * s / (tp * cp) * pp
+ comm_time = cp_profile_time_info.total_comm_time
+
+ K = cp * tp / self.hccs_dev_num
+ comm_y = (K) * s / (tp * cp) * pp
+ comm_z = (K - 1) * s / (tp * cp) * pp
+ iv_list = [comm_x, comm_y, comm_z]
+ self.main_domain.max_domain = cp * tp
+ self.main_domain.min_domain = tp
+ self.main_domain.append_time_in_domain(self.comm, iv_list, comm_time)
+
+ def modeling(self):
+ # traffic of model
+ self.comm.modeling()
+
+ # overlap
+ self.cp_attn_w, self.cp_attn_b = self.comm.linear_x_y(
+ self.cp_attn_x, self.cp_attn_time)
+ self.cp_attn_bw_w, self.cp_attn_bw_b = self.comm.linear_x_y(
+ self.cp_attn_bw_x, self.cp_attn_bw_time)
+ self.cp_vector_w, self.cp_vector_b = self.comm.linear_x_y(
+ self.cp_vector_x, self.cp_vector_time)
+
+ def print_modeling(self, config_list):
+ self.logger.debug(f"****************** cp(ms) ***********************")
+ if self.main_domain.roce_comm_exist:
+ self.logger.debug(f"roce")
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8}\t{7:<8}"
+ self.logger.debug(tplt.format('No', 'tp', 'dp', 'pp', 'cp', 'ep', 'cp_time', 'cp_x',
+ chr(12288)))
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8.2f}\t{7:<8}"
+ index = 0
+ for i, _ in enumerate(config_list):
+ if config_list[i].cp > 1:
+ if self.comm.roce_x_list[index][0]:
+ self.logger.debug(tplt.format(i, config_list[i].tp, config_list[i].dp, config_list[i].pp,
+ config_list[i].cp, config_list[i].ep,
+ self.comm.roce_time_list[index][0], self.comm.roce_x_list[index][0],
+ chr(12288)))
+ index += 1
+ self.logger.debug(f"--------------")
+ tplt = "{0:<9}\t{1:<9}"
+ self.logger.debug(tplt.format('cp_w,', 'cp_b', chr(12288)))
+ self.logger.debug(tplt.format(round(self.comm.roce_w, 3), round(self.comm.roce_b, 3),
+ chr(12288)))
+ self.logger.debug(f"-------------")
+ if self.main_domain.hccs_comm_exist:
+ self.logger.debug(f"hccs")
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8}\t{7:<8}"
+ self.logger.debug(tplt.format('No', 'tp', 'dp', 'pp', 'cp', 'ep', 'cp_time', 'cp_x',
+ chr(12288)))
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8.2f}\t{7:<8}"
+ index = 0
+ for i, _ in enumerate(config_list):
+ if config_list[i].cp > 1:
+ if self.comm.hccs_x_list[index][0]:
+ self.logger.debug(tplt.format(i, config_list[i].tp, config_list[i].dp, config_list[i].pp,
+ config_list[i].cp, config_list[i].ep,
+ self.comm.hccs_time_list[index][0], self.comm.hccs_x_list[index][0],
+ chr(12288)))
+ index += 1
+ self.logger.debug(f"-----------")
+ tplt = "{0:<9}\t{1:<9}"
+ self.logger.debug(tplt.format('cp_HCCS_w,', 'cp_HCCS_b', chr(12288)))
+ self.logger.debug(tplt.format(round(self.comm.hccs_w, 3), round(self.comm.hccs_b, 3),
+ chr(12288)))
+ self.logger.debug(f"-----------")
+
+ if self.main_domain.cross_comm_exist:
+ self.logger.debug(f"cross")
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8}\t{7:<8}\t{8:<8}"
+ self.logger.debug(tplt.format('No', 'tp', 'dp', 'pp', 'cp',
+ 'ep', 'cp_time', 'cp_cross_x', 'cp_cross_y', chr(12288)))
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8.2f}\t{7:<8.2f}\t{8:<8.2f}"
+ index = 0
+ for i, _ in enumerate(config_list):
+ if config_list[i].cp > 1:
+ if self.comm.cross_x_list[index][0]:
+ self.logger.debug(tplt.format(i, config_list[i].tp, config_list[i].dp, config_list[i].pp,
+ config_list[i].cp, config_list[i].ep,
+ self.comm.cross_time_list[index][0],
+ self.comm.cross_x_list[index][0], self.comm.cross_y_list[index][0],
+ chr(12288)))
+ index += 1
+ self.logger.debug(f"-----------")
+ tplt = "{0:<9}\t{1:<9}"
+ self.logger.debug(tplt.format(round(self.comm.hccs_w, 3), round(self.comm.roce_w, 3),
+ chr(12288)))
+ self.logger.debug(f"-----------")
+
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8}\t{7:<8}\t{8:<8}\t{9:<8}\t{10:<8}\t{11:<8}"
+ self.logger.debug(tplt.format('No', 'tp', 'dp', 'pp', 'cp', 'ep', 'attn_x',
+ 'attention', 'attn_bw_x', 'attn_bw', 'vector_x', 'vector_time', chr(12288)))
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8.2f}\t{7:<8.2f}\t{8:<8.2f}\t{9:<8.2f}\t{10:<8.2f}\t{11:<8.2f}"
+ index = 0
+ for i, _ in enumerate(config_list):
+ if config_list[i].cp > 1:
+ self.logger.debug(tplt.format(i, config_list[i].tp, config_list[i].dp, config_list[i].pp,
+ config_list[i].cp, config_list[i].ep,
+ self.cp_attn_x[index][0], self.cp_attn_time[index][0],
+ self.cp_attn_bw_x[index][0], self.cp_attn_bw_time[index][0],
+ self.cp_vector_x[index][0], self.cp_vector_time[index][0], chr(12288)))
+ index += 1
+ self.logger.debug(f"-----------")
+ tplt = "{0:<9}\t{1:<9}\t{2:<9}\t{3:<9}\t{4:<9}\t{5:<9}"
+ self.logger.debug(tplt.format('attn_w,', 'attn_b', 'attn_bw_w',
+ 'attn_bw_b', 'vector_w', 'vector_b', chr(12288)))
+ self.logger.debug(tplt.format(round(self.cp_attn_w, 3), round(self.cp_attn_b, 3),
+ round(self.cp_attn_bw_w, 3), round(self.cp_attn_bw_b, 3),
+ round(self.cp_vector_w, 3), round(
+ self.cp_vector_b, 3),
+ chr(12288)))
+ self.logger.debug(f"\n\n\n")
+ return
+
+
+ def get_vectortime_from_profiling(self, model, cp):
+ attn_list = []
+ attn_re_list = []
+ attn_gb_list = []
+ profile_info = model
+ attention = 0.0
+ attn_bw = 0.0
+ for item in profile_info.forward.operator_info[0]:
+ if item.name in _GLOBAL_ATTN_FORWARD_KERNEL_NAMES and len(attn_list) < cp - 1:
+ attn_list.append(item)
+ attention += float(item.duration_us)
+ for item in profile_info.backward.operator_info[0]:
+ if item.name in _GLOBAL_ATTN_FORWARD_KERNEL_NAMES and len(attn_re_list) < cp - 1:
+ attn_re_list.append(item)
+ attention += float(item.duration_us)
+ if item.name in _GLOBAL_ATTN_BACKWARD_KERNEL_NAMES and len(attn_gb_list) < cp:
+ attn_gb_list.append(item)
+ attn_bw += float(item.duration_us)
+ # Attention, one of them is shadowed. attn_bw needs to be calculated.
+ attention = attention / 1000
+ attn_bw = attn_bw / 1000
+ return attention, attn_bw
+
+ def performance(self, search_cfg):
+ tp = search_cfg.tensor_model_parallel_size
+ pp = search_cfg.pipeline_model_parallel_size
+ cp = search_cfg.context_parallel_size
+ s = search_cfg.seq_length / 1000
+ cp_time = 0.0
+ comm_x = (cp - 1) * s / (tp * cp) * pp
+ K = cp * tp / self.hccs_dev_num
+ comm_y = (K) * s / (tp * cp) * pp
+ comm_z = (K - 1) * s / (tp * cp) * pp
+ iv_list = [comm_x, comm_y, comm_z]
+ self.main_domain.max_domain = cp * tp
+ self.main_domain.min_domain = tp
+ if cp > 1:
+ comm_time = self.main_domain.cal_time_in_domain(self.comm, iv_list)
+
+ attn_time = self.cp_attn_w * (s / tp / cp * (cp - 1) / cp) + self.cp_attn_b
+ attn_bw_time = self.cp_attn_bw_w * (s / tp / cp) + self.cp_attn_bw_b
+ # Attention and attn_bw need to be considered separately.
+ cp_time1 = comm_time / 2 - attn_time * pp
+ if cp_time1 < 0:
+ cp_time1 = 0
+ cp_time2 = comm_time / 2 - attn_bw_time * pp
+ if cp_time2 < 0:
+ cp_time2 = 0
+ cp_time = cp_time1 + cp_time2
+ if cp > 2:
+ cp_vector_time = self.cp_vector_w * (cp - 2) + self.cp_vector_b
+ cp_time = cp_time - cp_vector_time
+ self.logger.debug('cp_time:{}, attn_time:{}, attn_bw_time:{}, '
+ 'cp_vector_time:{}'.format(cp_time, attn_time, attn_bw_time, cp_vector_time))
+ if cp_time < 0:
+ cp_time = 0.0
+ self.logger.debug(f'The communication time of the CP is the waiting time.')
+ return cp_time
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_dp.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_dp.py
new file mode 100644
index 0000000000000000000000000000000000000000..6210566315c1094a4488a996ca5f9984da07047f
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_dp.py
@@ -0,0 +1,226 @@
+from mindspeed.auto_tuning.module.communication.communication_model \
+ import CommunicationModel, CommunicationList, Domain
+
+
+class DpModel(CommunicationModel):
+ def __init__(self, hccs_dev_num):
+ super(DpModel, self).__init__(hccs_dev_num)
+ # Profile modeling data table
+
+ self.attention = CommunicationList()
+ self.attention_reducescatter = CommunicationList()
+ self.attention_allgather = CommunicationList()
+
+ self.mlp_domain = Domain(hccs_dev_num)
+ self.zero_comm = CommunicationList()
+ self.zero = CommunicationList()
+ self.zero_reducescatter = CommunicationList()
+ self.zero_allgather = CommunicationList()
+
+ def get_communication_info_from_profile(self, dp_profile_time_info, hcom_info_tage_id):
+ dp_profile_time_info.total_comm_time += hcom_info_tage_id.total_time_ms
+ dp_profile_time_info.total_mlpzero_time += hcom_info_tage_id.mlp_zero_time_ms
+ dp_profile_time_info.total_otherzero_time += hcom_info_tage_id.total_time_ms - hcom_info_tage_id.mlp_zero_time_ms
+ dp_profile_time_info.mlp_ag_time += hcom_info_tage_id.mlp_ag_time_ms
+ dp_profile_time_info.mlp_rs_time += hcom_info_tage_id.mlp_rs_time_ms
+ dp_profile_time_info.other_ag_time += hcom_info_tage_id.other_ag_time_ms
+ dp_profile_time_info.other_rs_time += hcom_info_tage_id.other_rs_time_ms
+
+ def get_comm_info_list(self, dp_profile_time_info, config):
+ tp = config.tp
+ cp = config.cp
+ dp = config.dp
+ ep = config.ep
+ pp = config.pp
+ zero = config.zero1
+ experts = config.num_experts if config.num_experts else 1
+
+ # attention
+ if dp * cp > 1:
+ comm_x = (dp * cp - 1) / (tp * pp)
+ K = dp * cp * tp / self.hccs_dev_num
+ comm_y = (K) / (tp * pp)
+ comm_z = (K - 1) / (tp * pp)
+ iv_list = [comm_x, comm_y, comm_z]
+ comm_time = dp_profile_time_info.total_otherzero_time
+ reducescatter_time = dp_profile_time_info.other_rs_time
+ allgather_time = dp_profile_time_info.other_ag_time
+ dp_total_time = dp_profile_time_info.total_comm_time
+ self.main_domain.max_domain = dp * cp * tp
+ self.main_domain.min_domain = cp * tp
+ self.main_domain.append_time_in_domain(self.attention, iv_list, comm_time)
+ self.main_domain.append_time_in_domain(self.attention_reducescatter, iv_list, reducescatter_time)
+ self.main_domain.append_time_in_domain(self.attention_allgather, iv_list, allgather_time)
+ self.main_domain.append_time_in_domain(self.comm, iv_list, dp_total_time)
+ # MLP
+ mlp_x = experts * (dp * cp / ep - 1) / tp / pp
+ comm_time = dp_profile_time_info.total_mlpzero_time
+ reducescatter_time = dp_profile_time_info.mlp_rs_time
+ allgather_time = dp_profile_time_info.mlp_ag_time
+ mlp_x = experts * (dp * cp / ep - 1) / tp / pp
+ K = dp * cp * tp / ep / self.hccs_dev_num
+ mlp_y = experts * (K) / (tp * pp)
+ mlp_z = experts * (K - 1) / (tp * pp)
+ iv_list = [mlp_x, mlp_y, mlp_z]
+ self.mlp_domain.max_domain = dp * cp * tp
+ self.mlp_domain.min_domain = cp * tp * ep
+ self.mlp_domain.append_time_in_domain(self.zero, iv_list, comm_time)
+ self.mlp_domain.append_time_in_domain(self.zero_reducescatter, iv_list, reducescatter_time)
+ self.mlp_domain.append_time_in_domain(self.zero_allgather, iv_list, allgather_time)
+ self.mlp_domain.append_time_in_domain(self.zero_comm, iv_list, dp_total_time)
+
+ def modeling(self):
+ self.attention.modeling()
+ self.attention_reducescatter.modeling()
+ self.attention_allgather.modeling()
+ self.zero.modeling()
+ self.zero_reducescatter.modeling()
+ self.zero_allgather.modeling()
+
+ def print_modeling(self, config_list):
+ self.logger.debug(f"****************** dp(ms) ***********************")
+ attention = [
+ self.comm,
+ self.attention,
+ self.attention_reducescatter,
+ self.attention_allgather,
+ ]
+ self.logger.debug(f"attention time :")
+ self.print_modeling_unit(config_list, attention, self.main_domain)
+ self.logger.debug(f"\n\n")
+
+ mlp = [
+ self.zero_comm,
+ self.zero,
+ self.zero_reducescatter,
+ self.zero_allgather,
+ ]
+ self.logger.debug(f"mlp time :")
+ self.print_modeling_unit(config_list, mlp, self.mlp_domain)
+ self.logger.debug(f"\n\n\n")
+
+ def print_modeling_unit(self, config_list, info_list, domain):
+ if domain.roce_comm_exist:
+ self.logger.debug(f" roce")
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8}\t{7:<8}\t{8:<8}\t{9:<8}\t{10:<8}"
+ self.logger.debug(tplt.format('No', 'tp', 'dp', 'pp', 'cp', 'ep', 'dp_time',
+ 'x', 'time', 'ag_time', 'rs_time', chr(12288)))
+ index = 0
+ for i, _ in enumerate(config_list):
+ if config_list[i].dp * config_list[i].cp > 1:
+ if info_list[1].roce_x_list[index][0]:
+ self.logger.debug(tplt.format(i, config_list[i].tp, config_list[i].dp, config_list[i].pp,
+ config_list[i].cp, config_list[i].ep,
+ round(info_list[0].roce_time_list[index][0], 2),
+ round(info_list[1].roce_x_list[index][0], 3),
+ round(info_list[1].roce_time_list[index][0], 2),
+ round(info_list[2].roce_time_list[index][0], 3),
+ round(info_list[3].roce_time_list[index][0], 2),
+ chr(12288)))
+ index += 1
+ self.logger.debug(f"-----------")
+ tplt = "{0:<9}\t{1:<9}\t{2:<9}\t{3:<9}\t{4:<9}\t{5:<9}"
+ self.logger.debug(tplt.format('time_w', 'time_b', 'rs_w', 'rs_b', 'ag_w', 'ag_b', chr(12288)))
+ self.logger.debug(tplt.format(round(info_list[1].roce_w, 2), round(info_list[1].roce_b, 2),
+ round(info_list[2].roce_w, 2),
+ round(info_list[2].roce_b, 2),
+ round(info_list[3].roce_w, 2),
+ round(info_list[3].roce_b, 2), chr(12288)))
+ self.logger.debug(f"----------------------")
+ if domain.hccs_comm_exist:
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8}\t{7:<8}\t{8:<8}\t{9:<8}\t{10:<8}"
+ self.logger.debug(f" hccs")
+ self.logger.debug(tplt.format('No', 'tp', 'dp', 'pp', 'cp', 'ep', 'dp_time',
+ 'x', 'time', 'ag_time', 'rs_time', chr(12288)))
+ index = 0
+ for i, _ in enumerate(config_list):
+ if config_list[i].dp * config_list[i].cp > 1:
+ if info_list[1].hccs_x_list[index][0]:
+ self.logger.debug(tplt.format(i, config_list[i].tp, config_list[i].dp, config_list[i].pp,
+ config_list[i].cp, config_list[i].ep,
+ round(info_list[0].hccs_time_list[index][0], 2),
+ round(info_list[1].hccs_x_list[index][0], 3),
+ round(info_list[1].hccs_time_list[index][0], 2),
+ round(info_list[2].hccs_time_list[index][0], 3),
+ round(info_list[3].hccs_time_list[index][0], 2),
+ chr(12288)))
+ index += 1
+ self.logger.debug(f"-----------")
+ tplt = "{0:<9}\t{1:<9}\t{2:<9}\t{3:<9}\t{4:<9}\t{5:<9}"
+ self.logger.debug(tplt.format('dp_w', 'dp_b', 'rs_w', 'rs_b', 'ag_w', 'ag_b', chr(12288)))
+ self.logger.debug(tplt.format(round(info_list[1].hccs_w, 2), round(self.attention.hccs_b, 2),
+ round(info_list[2].hccs_w, 2),
+ round(info_list[2].hccs_b, 2),
+ round(info_list[3].hccs_w, 2),
+ round(info_list[3].hccs_b, 2), chr(12288)))
+ self.logger.debug(f"----------------------")
+ if domain.cross_comm_exist:
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8}\t{7:<8}\t{8:<8}\t{9:<8}\t{10:<8}\t{11:<8}"
+ self.logger.debug(f" cross")
+ self.logger.debug(tplt.format('No', 'tp', 'dp', 'pp', 'cp', 'ep', 'dp_time', 'dp_x', 'dp_y', 'total_time', 'ag_time',
+ 'rs_time', chr(12288)))
+ index = 0
+ for i, _ in enumerate(config_list):
+ if config_list[i].dp * config_list[i].cp > 1:
+ if info_list[1].cross_x_list[index][0]:
+ self.logger.debug(tplt.format(i, config_list[i].tp, config_list[i].dp, config_list[i].pp,
+ config_list[i].cp, config_list[i].ep,
+ round(info_list[0].cross_time_list[index][0], 2),
+ round(info_list[1].cross_x_list[index][0], 3),
+ round(info_list[1].cross_y_list[index][0], 3),
+ round(info_list[1].cross_time_list[index][0], 2),
+ round(info_list[2].cross_time_list[index][0], 3),
+ round(info_list[3].cross_time_list[index][0], 3),
+ chr(12288)))
+ index += 1
+ self.logger.debug(f"----------------------")
+
+ def performance(self, search_cfg):
+ tp = search_cfg.tensor_model_parallel_size
+ dp = search_cfg.data_parallel_size
+ pp = search_cfg.pipeline_model_parallel_size
+ cp = search_cfg.context_parallel_size
+ ep = search_cfg.expert_model_parallel_size if search_cfg.expert_model_parallel_size else 1
+ zero = search_cfg.use_distributed_optimizer
+ experts = search_cfg.num_experts if search_cfg.num_experts else 1
+
+ dp_time = 0.0
+ comm_time = 0.0
+ mlp_time = 0.0
+ overlap_time = 0.0
+ other_reducescatter = 0.0
+ other_allgather = 0.0
+ zero_reducescatter = 0.0
+ zero_allgather = 0.0
+ if dp * cp > 1:
+ # attention:
+ self.main_domain.max_domain = dp * cp * tp
+ self.main_domain.min_domain = cp * tp
+ comm_x = (dp * cp - 1) / tp / pp
+ K = dp * cp * tp / self.hccs_dev_num
+ comm_y = (K) / (tp * pp)
+ comm_z = (K - 1) / (tp * pp)
+ iv_list = [comm_x, comm_y, comm_z]
+ comm_time = self.main_domain.cal_time_in_domain(self.attention, iv_list)
+ other_reducescatter = self.main_domain.cal_time_in_domain(self.attention_reducescatter, iv_list)
+ other_allgather = self.main_domain.cal_time_in_domain(self.attention_allgather, iv_list)
+
+ # mlp
+ self.mlp_domain.max_domain = dp * cp * tp
+ self.mlp_domain.min_domain = cp * tp * ep
+ mlp_x = experts * (dp * cp / ep - 1) / tp / pp
+ K = dp * cp * tp / ep / self.hccs_dev_num
+ mlp_y = experts * (K) / (tp * pp)
+ mlp_z = experts * (K - 1) / (tp * pp)
+ mlp_iv_list = [mlp_x, mlp_y, mlp_z]
+ mlp_time = self.mlp_domain.cal_time_in_domain(self.zero, mlp_iv_list)
+ zero_reducescatter = self.mlp_domain.cal_time_in_domain(self.zero_reducescatter, mlp_iv_list)
+ zero_allgather = self.mlp_domain.cal_time_in_domain(self.zero_allgather, mlp_iv_list)
+ if zero:
+ if pp > 1:
+ overlap_time += (pp - 1) / pp * (other_reducescatter + zero_reducescatter)
+ if pp > 2:
+ overlap_time += (pp - 2) / pp * (other_allgather + zero_allgather)
+ dp_time = comm_time + mlp_time - overlap_time
+ # dp_time here is the total gbs time effect
+ return dp_time
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_ep.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_ep.py
new file mode 100644
index 0000000000000000000000000000000000000000..639ca6686a5cd9830cb189d82f72f1cb5a9a58f4
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_ep.py
@@ -0,0 +1,119 @@
+from mindspeed.auto_tuning.module.communication.communication_model import CommunicationModel
+
+
+class EpModel(CommunicationModel):
+ def __init__(self, hccs_dev_num):
+ super(EpModel, self).__init__(hccs_dev_num)
+
+ def get_communication_info_from_profile(self, ep_profile_time_info, hcom_info_tage_id):
+ ep_profile_time_info.total_comm_time += hcom_info_tage_id.total_time_ms
+ ep_profile_time_info.wait_comm_time += hcom_info_tage_id.wait_time_ms
+ ep_profile_time_info.min_time += hcom_info_tage_id.min_comm_time_ms
+
+ def get_comm_info_list(self, ep_profile_time_info, config):
+ tp = config.tp
+ cp = config.cp
+ ep = config.ep
+ pp = config.pp
+ s = config.seq_length / 1000
+ experts = config.num_experts if config.num_experts else 1
+
+ if ep and ep > 1:
+ comm_x = experts * s * (ep - 1) * pp / ep / tp / cp
+ K = ep * tp / self.hccs_dev_num
+ comm_y = experts * s * (K) * pp / ep / tp / cp
+ comm_z = experts * s * (K - 1) / K * pp / ep / tp / cp
+ iv_list = [comm_x, comm_y, comm_z]
+ comm_time = ep_profile_time_info.min_time
+ self.main_domain.max_domain = ep * tp
+ self.main_domain.min_domain = tp
+ self.main_domain.append_time_in_domain(self.comm, iv_list, comm_time)
+
+ def modeling(self):
+ self.comm.modeling()
+
+ def print_modeling(self, config_list):
+ self.logger.debug(f"****************** ep(ms) ***********************")
+ if self.main_domain.roce_comm_exist:
+ self.logger.debug(f"roce")
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8}\t{7:<8}"
+ self.logger.debug(tplt.format('No', 'tp', 'dp', 'pp', 'cp',
+ 'ep', 'ep_roce_time', 'ep_roce_x', chr(12288)))
+ index = 0
+ for i, _ in enumerate(config_list):
+ if config_list[i].ep > 1:
+ if self.comm.roce_x_list[index][0]:
+ self.logger.debug(tplt.format(i, config_list[i].tp, config_list[i].dp, config_list[i].pp,
+ config_list[i].cp, config_list[i].ep,
+ round(self.comm.roce_time_list[index][0], 2), round(
+ self.comm.roce_x_list[index][0], 3),
+ chr(12288)))
+ index += 1
+ self.logger.debug(f"--------------")
+ tplt = "{0:<9}\t{1:<9}"
+ self.logger.debug(tplt.format('ep_w', 'ep_b', chr(12288)))
+ self.logger.debug(tplt.format(round(self.comm.roce_w, 3),
+ round(self.comm.roce_b, 3), chr(12288)))
+ self.logger.debug(f"--------------")
+ if self.main_domain.hccs_comm_exist:
+ self.logger.debug(f"hccs")
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8}\t{7:<8}"
+ self.logger.debug(tplt.format('No', 'tp', 'dp', 'pp', 'cp',
+ 'ep', 'ep_hccs_time', 'ep_hccs_x', chr(12288)))
+ index = 0
+ for i, _ in enumerate(config_list):
+ if config_list[i].ep > 1:
+ if self.comm.hccs_x_list[index][0]:
+ self.logger.debug(tplt.format(i, config_list[i].tp, config_list[i].dp, config_list[i].pp,
+ config_list[i].cp, config_list[i].ep,
+ round(
+ self.comm.hccs_time_list[index][0], 2),
+ round(self.comm.hccs_x_list[index][0], 3), chr(12288)))
+ index += 1
+ self.logger.debug(f"-----------")
+ tplt = "{0:<9}\t{1:<9}"
+ self.logger.debug(tplt.format('ep_HCCS_w', 'ep_HCCS_b', chr(12288)))
+ self.logger.debug(tplt.format(round(self.comm.hccs_w, 3), round(self.comm.hccs_b, 3),
+ chr(12288)))
+ self.logger.debug(f"-----------")
+ if self.main_domain.cross_comm_exist:
+ self.logger.debug(f"cross")
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8}\t{7:<8}\t{8:<8}"
+ self.logger.debug(tplt.format('No', 'tp', 'dp', 'pp', 'cp',
+ 'ep', 'ep_cross_time', 'ep_cross_x', 'ep_cross_y', chr(12288)))
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8.2f}\t{7:<8.2f}\t{8:<8.2f}"
+ index = 0
+ for i, _ in enumerate(config_list):
+ if config_list[i].ep > 1:
+ if self.comm.cross_x_list[index][0]:
+ self.logger.debug(tplt.format(i, config_list[i].tp, config_list[i].dp, config_list[i].pp,
+ config_list[i].cp, config_list[i].ep,
+ self.comm.cross_time_list[index][0],
+ self.comm.cross_x_list[index][0], self.comm.cross_y_list[index][0],
+ chr(12288)))
+ index += 1
+ self.logger.debug(f"-----------")
+ tplt = "{0:<9}\t{1:<9}"
+ self.logger.debug(tplt.format(round(self.comm.hccs_w, 3), round(self.comm.roce_w, 3),
+ chr(12288)))
+ self.logger.debug(f"-----------")
+ self.logger.debug(f"\n\n\n")
+
+ def performance(self, search_cfg):
+ tp = search_cfg.tensor_model_parallel_size
+ pp = search_cfg.pipeline_model_parallel_size
+ cp = search_cfg.context_parallel_size
+ ep = search_cfg.expert_model_parallel_size
+ s = search_cfg.seq_length / 1000
+ ep_time = 0.0
+ experts = search_cfg.num_experts if search_cfg.num_experts else 1
+ comm_x = experts * s * (ep - 1) * pp / ep / tp / cp
+ K = ep * tp / self.hccs_dev_num
+ comm_y = experts * s * (K) * pp / ep / tp / cp
+ comm_z = experts * s * (K - 1) / K * pp / ep / tp / cp
+ iv_list = [comm_x, comm_y, comm_z]
+ self.main_domain.max_domain = ep * tp
+ self.main_domain.min_domain = tp
+ if ep and ep > 1:
+ ep_time = self.main_domain.cal_time_in_domain(self.comm, iv_list)
+ return ep_time
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_mc2.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_mc2.py
new file mode 100644
index 0000000000000000000000000000000000000000..64b6dc9ec5325baffa8dfbbce25c05395286bf30
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_mc2.py
@@ -0,0 +1,57 @@
+from mindspeed.auto_tuning.module.communication.communication_model import CommunicationModel
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_constant import NumberConstant
+
+
+class Mc2Model(CommunicationModel):
+ def __init__(self, hccs_dev_num):
+ super(Mc2Model, self).__init__(hccs_dev_num)
+
+ def get_communication_info_from_profile(self, mc2_profile_time_info, hcom_info_tage_id, index):
+ mc2_res = hcom_info_tage_id[index][1]
+ mat_res = hcom_info_tage_id[index - 1][1]
+ mc2_profile_time_info.matmul_compute_time = mat_res.matmul_total_time[0]
+ mc2_profile_time_info.total_comm_time = mc2_res.mc2_total_time[0]
+
+ def get_comm_info_list(self, mc2_profile_time_info, config):
+ tp = config.tp
+ cp = config.cp
+ s = config.seq_length / NumberConstant.CONVERSION_TIME
+ hccs_x = (s / (tp * cp))
+ hccs_time = mc2_profile_time_info.total_comm_time - mc2_profile_time_info.matmul_compute_time
+ self.comm.append_hccs([hccs_x], hccs_time)
+
+ def modeling(self):
+ sum_x = 0
+ sum_time = 0
+ for index, x in enumerate(self.comm.hccs_x_list):
+ sum_x += x[0]
+ sum_time += self.comm.hccs_time_list[index][0]
+ self.comm.hccs_w = sum_time / sum_x
+
+ def print_modeling(self, config_list):
+ mc2lt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8}\t{7:<8}\t{8:<8}"
+ self.logger.debug(f"****************** mc2(ms) ***********************")
+ self.logger.debug(mc2lt.format('No', 'tp', 'dp', 'pp', 'cp', 'ep', 'mc2_time', 'mc2_x', chr(12288)))
+ index = 0
+ for cfg in config_list:
+ if cfg.use_ascend_mc2:
+ self.logger.debug(mc2lt.format(index, cfg.tp, cfg.dp, cfg.pp,
+ cfg.cp,
+ cfg.ep,
+ round(self.comm.hccs_time_list[index][0], 2), round(
+ self.comm.hccs_x_list[index][0], 3), chr(12288)))
+ index += 1
+ self.logger.debug(f"-----------")
+ mc2lt = "{0:<9}\t{1:<9}"
+ self.logger.debug(mc2lt.format('tp_w', 'tp_b', chr(12288)))
+ self.logger.debug(mc2lt.format(round(self.comm.hccs_w, 3), round(self.comm.hccs_b, 3), chr(12288)))
+ self.logger.debug(f"\n\n\n")
+
+ def performance(self, search_cfg):
+ tp = search_cfg.tensor_model_parallel_size
+ cp = search_cfg.context_parallel_size
+ s = search_cfg.seq_length / 1000
+ mc2_time = 0
+ if tp > 1:
+ mc2_time = self.comm.hccs_w * (s / (tp * cp)) + self.comm.hccs_b
+ return mc2_time
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_pp.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_pp.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5655fe5651440538f0f5186330ce359b6e3ba1c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_pp.py
@@ -0,0 +1,99 @@
+from mindspeed.auto_tuning.module.communication.communication_model import CommunicationModel
+
+
+class PpModel(CommunicationModel):
+ def __init__(self, hccs_dev_num):
+ super(PpModel, self).__init__(hccs_dev_num)
+
+ def get_communication_info_from_profile(self, pp_profile_time_info, hcom_info_tage_id, pp):
+ last_pp_start_time = 0
+ total_pp_time = 0
+ for i in range(0, pp - 1):
+ key = list(hcom_info_tage_id.details[i].keys())[0]
+ total_pp_time += hcom_info_tage_id.details[i][key]['Elapse Time(ms)']
+ if last_pp_start_time == 0:
+ last_pp_start_time = hcom_info_tage_id.details[i][key]['Start Timestamp(us)']
+ pp_profile_time_info.each_pp_time = total_pp_time / (pp - 1)
+
+ def get_comm_info_list(self, pp_profile_time_info, config):
+ tp = config.tp
+ cp = config.cp
+ pp = config.pp
+ dp = config.dp
+ layers_per_vpp = config.layers_per_vpp if config.layers_per_vpp else 1
+ comm_x = 1 / (layers_per_vpp * tp * cp)
+ iv_list = [comm_x, 0, 0] # PP does not need to consider cross modeling.
+ comm_time = pp_profile_time_info.each_pp_time
+ self.main_domain.max_domain = pp * dp * cp * tp
+ self.main_domain.min_domain = pp * dp * cp * tp
+ if pp > 1:
+ self.main_domain.append_time_in_domain(self.comm, iv_list, comm_time)
+ # PPtime indicates the time consumed by each PP communication.
+
+ def modeling(self):
+ self.comm.modeling()
+ if self.comm.hccs_w == 0:
+ self.comm.hccs_w = self.comm.roce_w
+
+ def print_modeling(self, config_list):
+ self.logger.debug(f"****************** pp(ms) ***********************")
+ if self.main_domain.roce_comm_exist:
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<1}\t{7:<8}\t{8:<8}"
+ self.logger.debug(tplt.format('No', 'tp', 'dp', 'pp', 'vp',
+ 'cp', 'ep', 'pp_x', 'pp_time', chr(12288)))
+ index = 0
+ for i, _ in enumerate(config_list):
+ if config_list[i].pp > 1:
+ if self.comm.roce_x_list[index][0]:
+ self.logger.debug(tplt.format(i, config_list[i].tp, config_list[i].dp, config_list[i].pp,
+ str(config_list[i].layers_per_vpp), config_list[i].cp, config_list[i].ep,
+ round(self.comm.roce_x_list[index][0], 3), round(
+ self.comm.roce_time_list[index][0], 2),
+ chr(12288)))
+ index += 1
+ self.logger.debug(f"-----------")
+ tplt = "{0:<9}\t{1:<9}"
+ self.logger.debug(tplt.format('pp_w', 'pp_b', chr(12288)))
+ self.logger.debug(tplt.format(round(self.comm.roce_w, 3),
+ round(self.comm.roce_b, 3), chr(12288)))
+ self.logger.debug(f"-----------")
+ if self.main_domain.hccs_comm_exist:
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<1}\t{7:<8}\t{8:<8}"
+ self.logger.debug(tplt.format('No', 'tp', 'dp', 'pp', 'vp', 'cp',
+ 'ep', 'pp_HCCS_x', 'pp_HCCS_time', chr(12288)))
+ index = 0
+ for i, _ in enumerate(config_list):
+ if config_list[i].pp > 1:
+ if self.comm.hccs_x_list[index][0]:
+ self.logger.debug(tplt.format(i, config_list[i].tp, config_list[i].dp, config_list[i].pp,
+ str(config_list[i].layers_per_vpp), config_list[i].cp, config_list[i].ep,
+ round(
+ self.comm.hccs_x_list[index][0], 3),
+ round(self.comm.hccs_time_list[index][0], 2), chr(12288)))
+ index += 1
+ self.logger.debug(f"-----------")
+ tplt = "{0:<9}\t{1:<9}"
+ self.logger.debug(tplt.format('pp_HCCS_w', 'pp_HCCS_b', chr(12288)))
+ self.logger.debug(tplt.format(round(self.comm.hccs_w, 3), round(self.comm.hccs_b, 3),
+ chr(12288)))
+ self.logger.debug(f"-----------")
+ self.logger.debug(f"\n\n\n")
+
+ def performance(self, search_cfg):
+ tp = search_cfg.tensor_model_parallel_size
+ dp = search_cfg.data_parallel_size
+ pp = search_cfg.pipeline_model_parallel_size
+ vp = search_cfg.num_layers // (
+ pp * search_cfg.num_layers_per_virtual_pipeline_stage) if search_cfg.num_layers_per_virtual_pipeline_stage else 1
+ cp = search_cfg.context_parallel_size
+
+ pp_time = 0.0
+ comm_x = (1 / (vp * tp * cp))
+ iv_list = [comm_x, 0, 0] # PP does not need to consider cross modeling.
+ self.main_domain.max_domain = pp * dp * cp * tp
+ self.main_domain.min_domain = pp * dp * cp * tp
+ if pp > 1:
+ each_pp_time = self.main_domain.cal_time_in_domain(self.comm, iv_list)
+ each_pp_time = each_pp_time * 2 # Multiply send and receive by 2.
+ pp_time = each_pp_time * (pp * vp - 1) * 2
+ return pp_time
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_tp.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_tp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c701b995bbaba9f4df861380560d58bcbedd1a12
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_model_tp.py
@@ -0,0 +1,96 @@
+from mindspeed.auto_tuning.module.communication.communication_model import CommunicationModel
+
+
+class TpModel(CommunicationModel):
+ def __init__(self, hccs_dev_num):
+ super(TpModel, self).__init__(hccs_dev_num)
+ # Profile modeling data table
+ self.tp_comm_total_time_list = []
+ self.tp_comm_wait_time_list = []
+ self.tp_comm_overlap_time_list = []
+
+ self.tp_hccs_overlap_w = 0
+ self.tp_hccs_overlap_b = 0
+
+ def get_communication_info_from_profile(self, tp_profile_time_info, hcom_info_tage_id):
+ tp_profile_time_info.total_comm_time += hcom_info_tage_id.total_time_ms
+ tp_profile_time_info.wait_comm_time += hcom_info_tage_id.wait_time_ms
+ tp_profile_time_info.overlap_comm_time += hcom_info_tage_id.overlap_time_ms
+
+ def get_comm_info_list(self, tp_profile_time_info, config):
+ tp = config.tp
+ cp = config.cp
+ pp = config.pp
+ s = config.seq_length / 1000
+ total_time = tp_profile_time_info.total_comm_time
+ wait_time = tp_profile_time_info.wait_comm_time
+ overlap_time = tp_profile_time_info.overlap_comm_time
+
+ comm_x = (s / (tp * cp))
+ if pp == 1:
+ # The last forward allgather is not calculated. The first two reverse allgathers plus the last allgather
+ # are not calculated.
+ # When the PP function is disabled, there are 18 communications in the TP domain. Therefore, four loss
+ # communications need to be excluded.
+ comm_time = (total_time - wait_time) * 14 / 18 / pp
+ self.tp_comm_overlap_time_list.append([overlap_time * 2 / 3 / pp])
+ else:
+ # When PP is enabled, there are 15 communications in the TP domain, and one loss communication needs to
+ # be excluded.
+ comm_time = (total_time - wait_time) * 14 / 15 / pp
+ self.tp_comm_overlap_time_list.append([overlap_time / pp])
+ self.comm.append_hccs([comm_x], comm_time)
+ self.tp_comm_total_time_list.append([total_time])
+ self.tp_comm_wait_time_list.append([wait_time])
+
+ def modeling(self):
+ self.comm.hccs_w, self.comm.hccs_b = self.comm.linear_x_y(
+ self.comm.hccs_x_list, self.comm.hccs_time_list)
+ self.tp_hccs_overlap_w, self.tp_hccs_overlap_b = self.comm.linear_x_y(
+ self.comm.hccs_x_list, self.tp_comm_overlap_time_list)
+ return
+
+ def print_modeling(self, config_list):
+ self.logger.debug(f"******************profile info list***********************")
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8}\t{7:<8}\t{8:<8}\t{9:<8}\t{10:<8}\t{11:<8}"
+ self.logger.debug(f"****************** tp(ms) ***********************")
+ self.logger.debug(tplt.format('No', 'tp', 'dp', 'pp', 'cp', 'ep', 'tp_time', 'tp_x', 'overlap_time', 'total_time',
+ 'wait_time', chr(12288)))
+
+ index = 0
+ for i, _ in enumerate(config_list):
+ if config_list[i].use_ascend_mc2:
+ continue
+ self.logger.debug(tplt.format(i, config_list[i].tp, config_list[i].dp, config_list[i].pp, config_list[i].cp,
+ config_list[i].ep,
+ round(self.comm.hccs_time_list[index][0], 2),
+ round(self.comm.hccs_x_list[index][0], 3),
+ round(self.tp_comm_overlap_time_list[index][0], 2),
+ round(self.tp_comm_total_time_list[index][0], 2),
+ round(self.tp_comm_wait_time_list[index][0], 2),
+ chr(12288)))
+ index += 1
+ self.logger.debug(f"-----------")
+ tplt = "{0:<9}\t{1:<9}\t{2:<9}\t{3:<9}"
+ self.logger.debug(tplt.format('tp_w', 'tp_b', 'overlap_w', 'overlap_b', chr(12288)))
+ self.logger.debug(tplt.format(round(self.comm.hccs_w, 3), round(self.comm.hccs_b, 3),
+ round(self.tp_hccs_overlap_w, 3),
+ round(self.tp_hccs_overlap_b, 3),
+ chr(12288)))
+ self.logger.debug(f"\n\n\n")
+ return
+
+ def performance(self, search_cfg):
+ tp = search_cfg.tensor_model_parallel_size
+ cp = search_cfg.context_parallel_size
+ s = search_cfg.seq_length / 1000
+ tp_overlap_time = 0
+ tp_time = 0
+ if tp > 1:
+ tp_time = self.comm.hccs_w * (s / (tp * cp)) + self.comm.hccs_b
+ tp_overlap_time = self.tp_hccs_overlap_w * \
+ s / (tp * cp) + self.tp_hccs_overlap_b
+ tp_time = tp_time - tp_overlap_time
+ if tp_time < 0:
+ tp_time = 0
+ return tp_time
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_profile.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_profile.py
new file mode 100644
index 0000000000000000000000000000000000000000..52fc022d8816fd67e851aeef2999847de76b646f
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/communication/communication_profile.py
@@ -0,0 +1,70 @@
+class ProfileTimeInfo():
+ def __init__(self):
+ # Profile source information
+ self.total_comm_time = 0
+ self.wait_comm_time = 0
+ self.overlap_comm_time = 0
+
+
+class TpProfileTimeInfo(ProfileTimeInfo):
+ def __init__(self):
+ super(TpProfileTimeInfo, self).__init__()
+ # Total time when communication hiding is not performed
+ self.fixedtotal_tp_time = 0
+ self.fixedwait_tp_time = 0
+
+
+class Mc2ProfileTimeInfo(ProfileTimeInfo):
+ def __init__(self):
+ super(Mc2ProfileTimeInfo, self).__init__()
+ self.matmul_compute_time = 0
+
+
+class CpProfileTimeInfo(ProfileTimeInfo):
+ def __init__(self):
+ super(CpProfileTimeInfo, self).__init__()
+ # Total time when communication hiding is not performed
+ self.attn_cp_time = 0
+ self.attn_cpbw_time = 0
+ self.vector_cp_time = 0
+
+
+class DpProfileTimeInfo(ProfileTimeInfo):
+ def __init__(self):
+ super(DpProfileTimeInfo, self).__init__()
+ # Total time when communication hiding is not performed
+ self.overlap_grad_reduce = 0
+ self.overlap_param_gather = 0
+ self.overlap = 0
+ self.total_mlpzero_time = 0
+ self.total_otherzero_time = 0
+ self.mlp_ag_time = 0
+ self.mlp_rs_time = 0
+ self.other_ag_time = 0
+ self.other_rs_time = 0
+
+
+class EpProfileTimeInfo(ProfileTimeInfo):
+ def __init__(self):
+ super(EpProfileTimeInfo, self).__init__()
+ self.min_time = 0
+
+
+class PpProfileTimeInfo(ProfileTimeInfo):
+ def __init__(self):
+ super(PpProfileTimeInfo, self).__init__()
+ # Total time when communication hiding is not performed
+ self.each_pp_time = 0
+ self.bubble_end_time = 0
+ self.bubble_start_time = 0
+
+
+class TotalProfileTimeInfo():
+ def __init__(self):
+ # Profile source information
+ self.tp_profile_time_info = TpProfileTimeInfo()
+ self.cp_profile_time_info = CpProfileTimeInfo()
+ self.dp_profile_time_info = DpProfileTimeInfo()
+ self.ep_profile_time_info = EpProfileTimeInfo()
+ self.pp_profile_time_info = PpProfileTimeInfo()
+ self.mc2_profile_time_info = Mc2ProfileTimeInfo()
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/hardware.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/hardware.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6d8bb668a6784c42637b05371fe6ea97d77894f
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/hardware.py
@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+from mindspeed.auto_tuning.utils.singleton import Singleton
+
+
+class Hardware(metaclass=Singleton):
+ """
+ hardware modeling
+ """
+ HARDWARE_PARSE_FILENAME = "auto_tuning_hardware.json"
+
+ def __init__(self) -> None:
+ self.device_type: str = "910"
+ self.host_ip: str = "localhost"
+ self.user_name: str = "root"
+
+ self.cube_performance: float = 363.7248
+ self.vector_performance: float = 11.3664
+ self.cube_utilization_ratio: float = 0.742
+ self.cube_time_ratio: float = 0.62
+ self.memory_limit: float = 60.0 * 1024
+
+ # intra-node config
+ self.devices_per_node: int = 8
+ self.intra_node_bandwidth: int = 196
+ self.intra_node_bandwidth_utilization_ratio: float = 0.65
+
+ # inter-node config
+ self.num_nodes: int = 2
+ self.node_rank: int = 0
+ self.inter_node_bandwidth: int = 25
+ self.inter_node_bandwidth_utilization_ratio: float = 0.7
+
+ def __str__(self):
+ rt = []
+ rt.append(f"{'Device Type':<30}{str(self.device_type):<40}")
+ rt.append(f"{'Host IP':<30}{str(self.host_ip):<40}")
+ rt.append(f"{'Devices Per Node':<30}{str(self.devices_per_node):<40}")
+ rt.append(f"{'Number Nodes':<30}{str(self.num_nodes):<40}")
+ rt.append(f"{'Node rank':<30}{str(self.node_rank):<40}")
+ return '\n'.join(rt)
+
+ @property
+ def num_devices(self) -> int:
+ return self.devices_per_node * self.num_nodes
+
+ def load(self, hardware: Hardware) -> None:
+ for k in self.__dict__.keys():
+ if k in hardware.__dict__:
+ self.__dict__[k] = hardware.__dict__[k]
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/__init__.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/dynamic_mem_modeling.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/dynamic_mem_modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..366ae585724f29cd86f0fbd8ac120689aeb2a889
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/dynamic_mem_modeling.py
@@ -0,0 +1,349 @@
+from typing import no_type_check, List, Tuple
+from collections import namedtuple
+from dataclasses import replace
+import os.path
+
+from mindspeed.auto_tuning.utils.logger import get_logger
+from mindspeed.auto_tuning.config.model_config import ModelConfig
+from mindspeed.auto_tuning.config.search_config import SearchConfig
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_config import ProfilingModelInfo
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_node_parse import GatherNodeProfiling
+from mindspeed.auto_tuning.utils.utils import get_prof_dir
+
+ProfileResult = namedtuple("ProfileResult", ["cfg", "prof"])
+MemModule = namedtuple(
+ "MemModule",
+ [
+ "checkpoint_activation_layer",
+ "checkpoint_activation_embedding",
+ "checkpoint_activation_loss",
+ "forward_peak",
+ "loss_peak",
+ "backward_peak",
+ "optimizer_peak"
+ ]
+)
+
+
+class DynamicMemModeling:
+ BASELINE_SEQLEN = 4096
+
+ @no_type_check
+ def __init__(self, model_cfg: ModelConfig) -> None:
+ self.model_cfg = model_cfg
+ self._logger = get_logger("dynamic_mem")
+ self.ckpt_act_layer: float = None
+ self.ckpt_act_embedding: float = None
+ self.ckpt_act_tp_b_embedding: float = None
+ self.ckpt_act_loss: float = None
+ self.forward_peak: float = None
+ self.tp_b_forward_peak: float = None
+ self.backward_peak: float = None
+ self.tp_b_backward_peak: float = None
+ self.loss_peak: float = None
+ self.tp_b_loss_peak: float = None
+ self.optimizer_peak: float = None
+ self.tp_b_optimizer_peak: float = None
+ self.seq_b_optimizer_peak: float = None
+
+ @staticmethod
+ def _cal_peak_mem_per_stage(mem_module,
+ cfg: SearchConfig,
+ schedule: str,
+ nlayer: int,
+ stage_id: int
+ ) -> float:
+ checkpoint_activation_layer, \
+ checkpoint_activation_embedding, \
+ checkpoint_activation_loss, \
+ forward_peak, \
+ loss_peak, \
+ backward_peak, \
+ _ = mem_module
+
+ if schedule == "1f1b":
+ if not cfg.vpp:
+ num_warmup = cfg.pp - stage_id
+ num_embd = cfg.pp
+ else:
+ num_warmup = cfg.pp * (cfg.vpp + 1) - 1 - 2 * stage_id
+ num_embd = cfg.pp * 2 - 1
+
+ estimated_forward_peak = checkpoint_activation_layer * nlayer * (num_warmup - 1) + \
+ checkpoint_activation_layer * (nlayer - 1 + 1) + \
+ forward_peak
+
+ estimated_backward_peak = checkpoint_activation_layer * nlayer * num_warmup + \
+ backward_peak
+
+ if stage_id == 0:
+ estimated_forward_peak += checkpoint_activation_embedding * num_embd
+ estimated_backward_peak += checkpoint_activation_embedding * num_embd
+
+ if stage_id == cfg.pp - 1:
+ estimated_forward_peak += checkpoint_activation_loss
+ estimated_backward_peak += checkpoint_activation_loss
+
+ estimated_loss_peak = checkpoint_activation_layer * nlayer * num_warmup + \
+ checkpoint_activation_loss * (num_warmup - 1) + \
+ loss_peak
+ else:
+ estimated_loss_peak = 0
+
+ peak_mem = max(estimated_forward_peak,
+ estimated_backward_peak,
+ estimated_loss_peak)
+ else:
+ peak_mem = 0
+
+ return peak_mem
+
+ def generate_dynamic_mem_profiling_list(self) -> List[SearchConfig]:
+ result: List[SearchConfig] = list()
+
+ baseline_cfg = SearchConfig()
+ baseline_cfg.copy_from_config(self.model_cfg)
+ baseline_cfg.tensor_model_parallel_size = 4
+ baseline_cfg.context_parallel_size = 1
+ baseline_cfg.pipeline_model_parallel_size = 1
+ baseline_cfg.num_layers = 1
+ baseline_cfg.seq_length = self.BASELINE_SEQLEN
+ if self.model_cfg.is_moe():
+ baseline_cfg.num_experts = 4
+ baseline_cfg.expert_model_parallel_size = 1
+ result.append(baseline_cfg)
+
+ tp8_cfg = replace(baseline_cfg,
+ tensor_model_parallel_size=8)
+ result.append(tp8_cfg)
+
+ seq8k_cfg = replace(baseline_cfg,
+ seq_length=2 * self.BASELINE_SEQLEN)
+ result.append(seq8k_cfg)
+
+ for cfg in result:
+ cfg.prepare_for_profiling()
+
+ return result
+
+ def model_dynamic_mem(self, working_dir: str) -> None:
+ def _get_profiling(cfg: SearchConfig) -> ProfilingModelInfo:
+ profiling_path = os.path.join(working_dir, get_prof_dir(cfg))
+ profiling_node_parse = GatherNodeProfiling(profiling_path)
+ return profiling_node_parse.fuse_node_pkl()
+
+ baseline_cfg, tp8_cfg, seq8k_cfg = \
+ self.generate_dynamic_mem_profiling_list()
+
+ tp4seq4k_prof = _get_profiling(baseline_cfg)
+ tp8seq4k_prof = _get_profiling(tp8_cfg)
+ tp4seq8k_prof = _get_profiling(seq8k_cfg)
+
+ self._get_ckpt_act_layer_modeling(baseline_cfg, tp4seq4k_prof)
+ self._get_ckpt_act_embedding_modeling(baseline_cfg,
+ tp8_cfg,
+ tp4seq4k_prof,
+ tp8seq4k_prof)
+ self._get_ckpt_act_loss_modeling(baseline_cfg, tp4seq4k_prof)
+ self._get_forward_peak_modeling(baseline_cfg,
+ tp8_cfg,
+ tp4seq4k_prof,
+ tp8seq4k_prof)
+ self._get_backward_peak_modeling(baseline_cfg,
+ tp8_cfg,
+ tp4seq4k_prof,
+ tp8seq4k_prof)
+ self._get_loss_peak_modeling(baseline_cfg,
+ tp8_cfg,
+ tp4seq4k_prof,
+ tp8seq4k_prof)
+ self._get_optimizer_peak_modeling(
+ ProfileResult(cfg=baseline_cfg, prof=tp4seq4k_prof),
+ ProfileResult(cfg=seq8k_cfg, prof=tp4seq8k_prof),
+ ProfileResult(cfg=tp8_cfg, prof=tp8seq4k_prof)
+ )
+
+ self._logger.debug("== ckpt_act_layer:")
+ self._logger.debug(f"{self.ckpt_act_layer}")
+ self._logger.debug("== ckpt_act_embedding:")
+ self._logger.debug(f"{self.ckpt_act_embedding}, {self.ckpt_act_tp_b_embedding}")
+ self._logger.debug("== ckpt_act_loss:")
+ self._logger.debug(f"{self.ckpt_act_loss}")
+ self._logger.debug("== forward_peak:")
+ self._logger.debug(f"{self.forward_peak}, {self.tp_b_forward_peak}")
+ self._logger.debug("== backward_peak:")
+ self._logger.debug(f"{self.backward_peak}, {self.tp_b_backward_peak}")
+ self._logger.debug("== loss_peak:")
+ self._logger.debug(f"{self.loss_peak}, {self.tp_b_loss_peak}")
+ self._logger.debug("== optimizer_peak:")
+ self._logger.debug(f"{self.optimizer_peak}, {self.tp_b_optimizer_peak}, {self.seq_b_optimizer_peak}")
+
+ def cal_dynamic_mem(self,
+ cfg: SearchConfig
+ ) -> Tuple[List[float], float]:
+ mem_module = self._cal_mem_module(cfg)
+ optimizer_peak = mem_module[-1]
+
+ nlayer = self.model_cfg.num_layers // cfg.pp
+ if cfg.layers_per_vpp:
+ nlayer = cfg.layers_per_vpp
+
+ schedule = "1f1b"
+ dynamic_mem_stages: List[float] = list()
+ for stage_id in range(cfg.pp):
+ peak_mem = self._cal_peak_mem_per_stage(mem_module,
+ cfg,
+ schedule,
+ nlayer,
+ stage_id)
+ peak_mem *= (cfg.mbs / 1) # mbs in profiling cfg equals 1
+ dynamic_mem_stages.append(peak_mem)
+ return dynamic_mem_stages, optimizer_peak
+
+ def _get_ckpt_act_layer_modeling(self,
+ base_cfg: SearchConfig,
+ base_prof: ProfilingModelInfo
+ ) -> None:
+ self.ckpt_act_layer = base_cfg.tp * \
+ (base_prof.loss.start_memory[0][0] -
+ base_prof.forward.start_memory[0][0])
+
+ def _get_ckpt_act_embedding_modeling(self,
+ base_cfg: SearchConfig,
+ bi_tp_cfg: SearchConfig,
+ base_prof: ProfilingModelInfo,
+ bi_tp_prof: ProfilingModelInfo) -> None:
+ base_embd = base_prof.forward.start_memory[0][0] - \
+ base_prof.embedding.start_memory[0][0]
+ bi_tp_embd = bi_tp_prof.forward.start_memory[0][0] - \
+ bi_tp_prof.embedding.start_memory[0][0]
+ self.ckpt_act_tp_b_embedding = bi_tp_embd * \
+ (bi_tp_cfg.tp // base_cfg.tp) - \
+ base_embd
+ self.ckpt_act_embedding = base_embd * base_cfg.tp - \
+ self.ckpt_act_tp_b_embedding * (base_cfg.tp - 1)
+
+ def _get_ckpt_act_loss_modeling(self,
+ base_cfg: SearchConfig,
+ base_prof: ProfilingModelInfo) -> None:
+ self.ckpt_act_loss = base_cfg.tp * \
+ (base_prof.backward.start_memory[0][0] -
+ base_prof.loss.start_memory[0][0])
+
+ def _get_forward_peak_modeling(self,
+ base_cfg: SearchConfig,
+ bi_tp_cfg: SearchConfig,
+ base_prof: ProfilingModelInfo,
+ bi_tp_prof: ProfilingModelInfo) -> None:
+ base_forward_peak = base_prof.forward.peak_memory[0][0] - \
+ base_prof.loss.start_memory[0][0]
+ bi_tp_forward_peak = bi_tp_prof.forward.peak_memory[0][0] - \
+ bi_tp_prof.loss.start_memory[0][0]
+ self.tp_b_forward_peak = bi_tp_forward_peak * \
+ (bi_tp_cfg.tp // base_cfg.tp) - \
+ base_forward_peak
+ self.forward_peak = base_forward_peak * base_cfg.tp - \
+ self.tp_b_forward_peak * (base_cfg.tp - 1)
+
+ def _get_backward_peak_modeling(self,
+ base_cfg: SearchConfig,
+ bi_tp_cfg: SearchConfig,
+ base_prof: ProfilingModelInfo,
+ bi_tp_prof: ProfilingModelInfo) -> None:
+ base_backward_peak = base_prof.backward.peak_memory[0][0] - \
+ base_prof.backward.start_memory[0][0]
+ bi_tp_backward_peak = bi_tp_prof.backward.peak_memory[0][0] - \
+ bi_tp_prof.backward.start_memory[0][0]
+ self.tp_b_backward_peak = bi_tp_backward_peak * \
+ (bi_tp_cfg.tp // base_cfg.tp) - \
+ base_backward_peak
+ self.backward_peak = base_backward_peak * base_cfg.tp - \
+ self.tp_b_backward_peak * (base_cfg.tp - 1)
+
+ def _get_loss_peak_modeling(self,
+ base_cfg: SearchConfig,
+ bi_tp_cfg: SearchConfig,
+ base_prof: ProfilingModelInfo,
+ bi_tp_prof: ProfilingModelInfo) -> None:
+ base_loss_peak = base_prof.loss.peak_memory[0][0] - \
+ base_prof.loss.start_memory[0][0]
+ bi_tp_loss_peak = bi_tp_prof.loss.peak_memory[0][0] - \
+ bi_tp_prof.loss.start_memory[0][0]
+ self.tp_b_loss_peak = bi_tp_loss_peak * \
+ (bi_tp_cfg.tp // base_cfg.tp) - \
+ base_loss_peak
+ self.loss_peak = base_loss_peak * base_cfg.tp - \
+ self.tp_b_loss_peak * (base_cfg.tp - 1)
+
+ def _get_optimizer_peak_modeling(
+ self,
+ base_res: ProfileResult,
+ bi_seq_res: ProfileResult,
+ bi_tp_res: ProfileResult
+ ) -> None:
+ base_cfg, base_prof = base_res
+ bi_seq_cfg, bi_seq_prof = bi_seq_res
+ bi_tp_cfg, bi_tp_prof = bi_tp_res
+ base_optimizer_peak = base_prof.optimizer.peak_memory[0][0] - \
+ base_prof.optimizer.start_memory[0][0]
+ bi_seq_optimizer_peak = bi_seq_prof.optimizer.peak_memory[0][0] - \
+ bi_seq_prof.optimizer.start_memory[0][0]
+ bi_tp_optimizer_peak = bi_tp_prof.optimizer.peak_memory[0][0] - \
+ bi_tp_prof.optimizer.start_memory[0][0]
+ self.seq_b_optimizer_peak = (base_optimizer_peak *
+ (bi_seq_cfg.seq_length // base_cfg.seq_length) -
+ bi_seq_optimizer_peak) * base_cfg.tp
+ self.tp_b_optimizer_peak = bi_tp_optimizer_peak * \
+ (bi_tp_cfg.tp // base_cfg.tp) - \
+ base_optimizer_peak
+ self.optimizer_peak = base_optimizer_peak * base_cfg.tp - \
+ self.tp_b_optimizer_peak * (base_cfg.tp - 1)
+
+ def _cal_mem_module(self, cfg: SearchConfig) -> MemModule:
+ seq_length = self.model_cfg.seq_length
+ nseq = seq_length // cfg.cp // self.BASELINE_SEQLEN
+ tp = cfg.tp
+ tp_w = cfg.tp - 1
+
+ checkpoint_activation_layer = self.ckpt_act_layer * nseq / tp
+
+ checkpoint_activation_embedding = \
+ (self.ckpt_act_embedding +
+ tp_w * self.ckpt_act_tp_b_embedding) * nseq / tp
+
+ checkpoint_activation_loss = self.ckpt_act_loss * nseq / tp
+
+ forward_peak = \
+ (self.forward_peak +
+ tp_w * self.tp_b_forward_peak) * nseq / tp
+
+ loss_peak = \
+ (self.loss_peak +
+ tp_w * self.tp_b_loss_peak) * nseq / tp
+
+ backward_peak = \
+ (self.backward_peak +
+ tp_w * self.tp_b_backward_peak) * nseq / tp
+
+ optimizer_peak = \
+ ((self.optimizer_peak +
+ tp_w * self.tp_b_optimizer_peak) * nseq -
+ self.seq_b_optimizer_peak * (nseq - 1)) / tp
+
+ self._logger.debug(f"== checkpoint_activation_layer: {checkpoint_activation_layer}")
+ self._logger.debug(f"== checkpoint_activation_embedding: {checkpoint_activation_embedding}")
+ self._logger.debug(f"== checkpoint_activation_loss: {checkpoint_activation_loss}")
+ self._logger.debug(f"== forward_peak: {forward_peak}")
+ self._logger.debug(f"== loss_peak: {loss_peak}")
+ self._logger.debug(f"== backward_peak: {backward_peak}")
+ self._logger.debug(f"== optimizer_peak: {optimizer_peak}")
+
+ return MemModule(
+ checkpoint_activation_layer,
+ checkpoint_activation_embedding,
+ checkpoint_activation_loss,
+ forward_peak,
+ loss_peak,
+ backward_peak,
+ optimizer_peak
+ )
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/memory_modeling.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/memory_modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..c50d8cfde357176c9be0e6598a9ab118fd2a0551
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/memory_modeling.py
@@ -0,0 +1,55 @@
+from typing import List, Tuple
+from logging import Logger
+
+from mindspeed.auto_tuning.utils.logger import get_logger
+from mindspeed.auto_tuning.config.model_config import ModelConfig
+from mindspeed.auto_tuning.config.search_config import SearchConfig
+from mindspeed.auto_tuning.module.memory.static_mem_modeling import StaticMemModeling
+from mindspeed.auto_tuning.module.memory.dynamic_mem_modeling import DynamicMemModeling
+
+
+class MemoryModeling:
+ _static_modeling: StaticMemModeling = None # type: ignore
+ _dynamic_modeling: DynamicMemModeling = None # type: ignore
+ _logger: Logger = None # type: ignore
+
+ def __new__(cls):
+ raise NotImplementedError("MemoryModeling is a static class.")
+
+ @classmethod
+ def set_model_cfg(cls, model_cfg: ModelConfig) -> None:
+ if cls._static_modeling and cls._dynamic_modeling:
+ raise ValueError("ModelConfig has yet been set.")
+ cls._static_modeling = StaticMemModeling(model_cfg)
+ cls._dynamic_modeling = DynamicMemModeling(model_cfg)
+ cls._logger = get_logger("memory")
+
+ @classmethod
+ def generate_mem_modeling_profiling_list(cls) -> Tuple[List[Tuple[SearchConfig, str]], List[SearchConfig]]:
+ return cls._static_modeling.generate_static_mem_profiling_list(), \
+ cls._dynamic_modeling.generate_dynamic_mem_profiling_list()
+
+ @classmethod
+ def modeling(cls, working_dir: str) -> None:
+ cls._static_modeling.model_static_mem(working_dir)
+ cls._dynamic_modeling.model_dynamic_mem(working_dir)
+
+ @classmethod
+ def estimate(cls, cfg: SearchConfig) -> Tuple[float, float]:
+ cls._logger.debug("==========Memory Estimate Summary==========")
+ static_mem = cls._static_modeling.cal_static_mem(cfg)
+ dynamic_mem, optimizer_peak = \
+ cls._dynamic_modeling.cal_dynamic_mem(cfg)
+ peak_stage_mem = float(0)
+ for stage_id in range(cfg.pp):
+ stage_mem = static_mem[stage_id] + dynamic_mem[stage_id]
+ peak_stage_mem = max(peak_stage_mem, stage_mem)
+ cls._logger.debug(f"== stage_id: {stage_id} ==\n"
+ f"static memory: {static_mem[stage_id]} MB\n"
+ f"dynamic peak memory: {dynamic_mem[stage_id]} MB\n"
+ f"peak memory: {stage_mem} MB")
+ optimizer_peak = max([m + optimizer_peak for m in static_mem])
+ cls._logger.debug(f"optimizer peak memory: {optimizer_peak} MB")
+ cls._logger.debug("==========Memory Estimate Summary End==========")
+
+ return max(peak_stage_mem, optimizer_peak), optimizer_peak
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/model_param.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/model_param.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea5b9bb6349c781a170833e889780cd33f3eb0ca
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/model_param.py
@@ -0,0 +1,16 @@
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class ModelParam:
+ name: str
+ num_parameters: int
+
+ @staticmethod
+ def cmp(left: object, right: object) -> int:
+ if isinstance(left, ModelParam) and isinstance(right, ModelParam):
+ if left == right:
+ return 1
+ elif left.name == right.name:
+ return -1
+ return 0
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/static_mem_modeling.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/static_mem_modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..3598084e6c4536fdbc9274b7ac158e2551d77f28
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/memory/static_mem_modeling.py
@@ -0,0 +1,257 @@
+from typing import no_type_check, Any, List, Set, Tuple
+from dataclasses import replace
+from itertools import chain
+import os.path
+
+
+from mindspeed.auto_tuning.utils.logger import get_logger
+from mindspeed.auto_tuning.config.model_config import ModelConfig
+from mindspeed.auto_tuning.config.search_config import SearchConfig
+from mindspeed.auto_tuning.module.memory.model_param import ModelParam
+from mindspeed.auto_tuning.utils.dtype import DTYPE
+from mindspeed.auto_tuning.utils.mem_utils import mem_b_to_mb
+from mindspeed.auto_tuning.utils.restricted_unpickler import restricted_loads
+
+
+class StaticMemModeling:
+ LAYER1_FILENAME = "auto_tuning_static_model_layer1.json"
+ PP4_FILENAME = "auto_tuning_static_model_pp4.json"
+ EXPERT2_FILENAME = "auto_tuning_static_model_expert2.json"
+ TP2_FILENAME = "auto_tuning_static_model_tp2.json"
+
+ @no_type_check
+ def __init__(self, model_cfg: ModelConfig) -> None:
+ self.model_cfg = model_cfg
+ self._logger = get_logger("static_mem")
+ self.params_first_embedding: List[ModelParam] = None
+ self.params_per_layer_wo_experts: List[ModelParam] = None
+ self.params_per_experts: List[ModelParam] = None
+ self.params_last_layernorm_and_embedding: List[ModelParam] = None
+ self.params_pp_affected: List[ModelParam] = None
+ self.params_tp_unaffected: Set[str] = set()
+
+ @staticmethod
+ def _diff_params(left: List[ModelParam],
+ right: List[ModelParam]
+ ) -> List[ModelParam]:
+ """
+ Finds the difference between two lists of parameters.
+ The result follows these conditions:
+ 1. If a param exists in right but not in left,
+ it gets appended directly into the result
+
+ 2. If a param exists in both lists and sharing a same name,
+ however the shape is different, the shape difference is appended
+
+ 3. If a param (say A) exists only in left,
+ we assume there's another param B with the same name
+ but shape of 0 in the right list,
+ thus 0 (B's shape) subtracted by A's shape gets appended
+ """
+ diff: List[ModelParam] = list()
+
+ left_iter = iter(left)
+ left_p = next(left_iter, None)
+ for right_p in right:
+ cmp_result = ModelParam.cmp(left_p, right_p)
+ if cmp_result == 1:
+ left_p = next(left_iter, None)
+ elif cmp_result == -1 and left_p:
+ diff.append(ModelParam(left_p.name,
+ right_p.num_parameters -
+ left_p.num_parameters
+ ))
+ left_p = next(left_iter, None)
+ else:
+ diff.append(right_p)
+
+ while left_p:
+ diff.append(ModelParam(left_p.name, -left_p.num_parameters))
+ left_p = next(left_iter, None)
+
+ return diff
+
+ def generate_static_mem_profiling_list(self) -> List[Tuple[SearchConfig, str]]:
+ result: List[Tuple[SearchConfig, str]] = list()
+
+ layer1_cfg = SearchConfig()
+ layer1_cfg.copy_from_config(self.model_cfg)
+ layer1_cfg.tensor_model_parallel_size = 1
+ layer1_cfg.context_parallel_size = 1
+ layer1_cfg.pipeline_model_parallel_size = 1
+ layer1_cfg.num_layers = 1
+ if self.model_cfg.is_moe():
+ layer1_cfg.num_experts = 1
+ layer1_cfg.expert_model_parallel_size = 1
+ result.append((layer1_cfg, self.LAYER1_FILENAME))
+
+ pp4_cfg = replace(layer1_cfg,
+ pipeline_model_parallel_size=4,
+ num_layers=4)
+ result.append((pp4_cfg, self.PP4_FILENAME))
+
+ if self.model_cfg.is_moe():
+ expert2_cfg = replace(pp4_cfg, num_experts=2)
+ result.append((expert2_cfg, self.EXPERT2_FILENAME))
+
+ tp2_cfg = replace(pp4_cfg, tensor_model_parallel_size=2)
+ result.append((tp2_cfg, self.TP2_FILENAME))
+
+ for cfg, _ in result:
+ cfg.prepare_for_profiling()
+
+ return result
+
+ def model_static_mem(self, working_dir: str) -> None:
+ def _decode(filename: str) -> Any:
+ filepath = os.path.join(working_dir, filename)
+ with open(filepath, mode="rb") as file:
+ decode = restricted_loads(file)
+ return decode
+
+ def _get_pp_params(filename: str) -> List[List[ModelParam]]:
+ params = [None] * 4
+ for pp_rank, model_params in _decode(filename):
+ if not params[pp_rank]:
+ params[pp_rank] = model_params
+ return params # type: ignore
+
+ total_pp4_params = _get_pp_params(self.PP4_FILENAME)
+ per_layer_w_experts_params = total_pp4_params[1]
+ self.params_first_embedding = \
+ self._diff_params(per_layer_w_experts_params,
+ total_pp4_params[0])
+ self.params_last_layernorm_and_embedding = \
+ self._diff_params(per_layer_w_experts_params,
+ total_pp4_params[-1])
+
+ if self.model_cfg.is_moe():
+ total_expert2_params = _get_pp_params(self.EXPERT2_FILENAME)
+ self.params_per_experts = \
+ self._diff_params(per_layer_w_experts_params,
+ total_expert2_params[1])
+ else:
+ self.params_per_experts = list()
+ self.params_per_layer_wo_experts = \
+ self._diff_params(self.params_per_experts,
+ per_layer_w_experts_params)
+
+ total_layer1_params: List[List[ModelParam]] = \
+ [p for _, p in _decode(self.LAYER1_FILENAME)]
+ layer1_params = total_layer1_params[0]
+ self.params_pp_affected = \
+ self._diff_params(self.params_first_embedding +
+ self.params_per_layer_wo_experts +
+ self.params_per_experts +
+ self.params_last_layernorm_and_embedding,
+ layer1_params)
+
+ total_tp2_params = _get_pp_params(self.TP2_FILENAME)
+ total_pp4_params_concat = list(chain.from_iterable(total_pp4_params))
+ total_tp2_params_concat = list(chain.from_iterable(total_tp2_params))
+ for i, param in enumerate(total_pp4_params_concat):
+ if param == total_tp2_params_concat[i]:
+ self.params_tp_unaffected.add(param.name)
+
+ self._logger.debug("\n== first embedding params:\n" +
+ "\n".join(
+ [str(p) for p in self.params_first_embedding]) +
+ "\n== layer_wo_experts params:\n" +
+ "\n".join(
+ [str(p) for p in self.params_per_layer_wo_experts]) +
+ "\n== experts params:\n" +
+ "\n".join(
+ [str(p) for p in self.params_per_experts]) +
+ "\n== last layer norm and embedding params:\n" +
+ "\n".join(
+ [str(p) for p in self.params_last_layernorm_and_embedding]) +
+ "\n== pp affected params:\n" +
+ "\n".join(
+ [str(p) for p in self.params_pp_affected]) +
+ "\n== not tp affected params:\n" +
+ "\n".join(
+ [str(p) for p in self.params_tp_unaffected]))
+
+ def cal_static_mem(self, cfg: SearchConfig) -> List[float]:
+ dtype = self.model_cfg.dtype
+ non_expert_zero1 = cfg.dp * cfg.cp
+ expert_zero1 = cfg.dp * cfg.cp / (cfg.ep if cfg.ep else 1)
+
+ def _cal_static_mem_per_stage(non_expert_params: int,
+ expert_params: int,
+ not_zero1_div_bytes: int,
+ zero1_div_bytes: int
+ ) -> float:
+ result = float(0)
+ if cfg.zero1:
+ result += non_expert_params * \
+ (not_zero1_div_bytes + zero1_div_bytes / non_expert_zero1)
+ result += expert_params * \
+ (not_zero1_div_bytes + zero1_div_bytes / expert_zero1)
+ else:
+ result += (non_expert_params + expert_params) * \
+ (not_zero1_div_bytes + zero1_div_bytes)
+ result = mem_b_to_mb(result * dtype.value[1])
+ result += 5000 # roughly estimated cann+hccl+driver+os memory
+ return result
+
+ static_mem_stages: List[float] = list()
+ for stage_id in range(cfg.pp):
+ non_expert_params_per_stage, expert_params_per_stage = \
+ self._cal_num_params_per_stage(stage_id, cfg)
+ if dtype == DTYPE.fp16:
+ static_mem_per_stage = \
+ _cal_static_mem_per_stage(non_expert_params_per_stage,
+ expert_params_per_stage,
+ 1 + 1,
+ 8)
+ elif dtype == DTYPE.bf16:
+ static_mem_per_stage = \
+ _cal_static_mem_per_stage(non_expert_params_per_stage,
+ expert_params_per_stage,
+ 1 + 2,
+ 6)
+ else:
+ static_mem_per_stage = \
+ _cal_static_mem_per_stage(non_expert_params_per_stage,
+ expert_params_per_stage,
+ 1 + 1,
+ 2)
+ static_mem_stages.append(static_mem_per_stage)
+ return static_mem_stages
+
+ def _cal_num_params_per_stage(self,
+ stage_id: int,
+ cfg: SearchConfig
+ ) -> Tuple[int, int]:
+ def _cal_num_params(param: ModelParam, ep: int = 1):
+ if param.name in self.params_tp_unaffected:
+ return param.num_parameters
+ else:
+ return param.num_parameters // ep // cfg.tp
+
+ num_layers = self.model_cfg.num_layers
+
+ non_expert_params = 0
+ for param in self.params_per_layer_wo_experts:
+ non_expert_params += _cal_num_params(param)
+ non_expert_params *= num_layers // cfg.pp
+
+ expert_params = 0
+ if cfg.num_experts and cfg.ep:
+ for param in self.params_per_experts:
+ expert_params += _cal_num_params(param, ep=cfg.ep)
+ expert_params *= (num_layers * cfg.num_experts) // cfg.pp
+
+ if stage_id == 0:
+ for param in self.params_first_embedding:
+ non_expert_params += _cal_num_params(param)
+ if stage_id == cfg.pp - 1:
+ for param in self.params_last_layernorm_and_embedding:
+ non_expert_params += _cal_num_params(param)
+
+ if cfg.pp == 1:
+ for param in self.params_pp_affected:
+ non_expert_params += _cal_num_params(param)
+
+ return non_expert_params, expert_params
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/model_performance.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/model_performance.py
new file mode 100644
index 0000000000000000000000000000000000000000..728afb74f565330637ebca866783aaeb7f7ec1eb
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/model_performance.py
@@ -0,0 +1,159 @@
+import math
+from mindspeed.auto_tuning.module.hardware import Hardware
+from mindspeed.auto_tuning.config.model_config import ModelConfig
+from mindspeed.auto_tuning.module.communication.communication import Communication
+from mindspeed.auto_tuning.module.operator.operator import OperatorPerformance
+from mindspeed.auto_tuning.module.operator.operator_re_profile import search_operator
+from mindspeed.auto_tuning.utils.logger import get_logger
+
+
+class ModelPerformance(object):
+ """
+ Model Performance modeling
+ """
+
+ def __init__(self, hardware=None, model_cfg: ModelConfig = None, working_dir: str = None):
+ self.communication = Communication(hardware, model_cfg)
+ self.operator = OperatorPerformance(model_cfg, working_dir=working_dir)
+ self.hardware = hardware
+ self.logger = get_logger("ModelPerformance")
+
+ def get_profiling_info(self, profiling_results):
+ self.communication.communication_modeling(profiling_results)
+ profiling_wo_mc2 = []
+ for item in profiling_results:
+ if item[0].use_ascend_mc2:
+ pass
+ else:
+ profiling_wo_mc2.append(item)
+ self.operator.model_operator_timer(profiling_wo_mc2)
+
+ def performance(self, search_cfg, working_dir, profile_count, re_profile_flag=False):
+ tp = search_cfg.tensor_model_parallel_size
+ dp = search_cfg.data_parallel_size
+ pp = search_cfg.pipeline_model_parallel_size
+ vp = search_cfg.num_layers // (pp * search_cfg.num_layers_per_virtual_pipeline_stage) \
+ if search_cfg.num_layers_per_virtual_pipeline_stage else 1
+ cp = search_cfg.context_parallel_size
+ ep = search_cfg.expert_model_parallel_size if search_cfg.expert_model_parallel_size else 1
+ num_layers = self.communication.model_cfg.num_layers
+ global_batch_size = self.communication.model_cfg.global_batch_size
+ model_micro_batch_size = self.communication.model_cfg.micro_batch_size
+ search_micro_batch_size = search_cfg.micro_batch_size
+ zero = search_cfg.use_distributed_optimizer
+ operator_time, unsampled_profiling = self.operator_performance(
+ search_cfg, working_dir, profile_count, re_profile_flag
+ )
+ comm_gap = 8
+
+ # Time for each micro-batch in each layer.
+ mc2_time = self.communication.mc2_model.performance(search_cfg)
+ tp_time = self.communication.tp_model.performance(search_cfg)
+
+ self.logger.debug(f"mc2_time:{mc2_time} tp_time:{tp_time}")
+ use_mc2 = mc2_time < tp_time
+ tp_time = min(mc2_time, tp_time)
+
+ cp_time = self.communication.cp_model.performance(search_cfg)
+ dp_time = self.communication.dp_model.performance(search_cfg)
+ pp_time = self.communication.pp_model.performance(search_cfg)
+ ep_time = self.communication.ep_model.performance(search_cfg)
+
+ micro_batch_num = global_batch_size / (dp * search_micro_batch_size)
+ # total layer number,total global_batch_size
+ layer_num = math.ceil(micro_batch_num * (num_layers / pp))
+ search_model_mbs_ratio = search_micro_batch_size / model_micro_batch_size
+ communication_time = (tp_time + cp_time + ep_time) * search_model_mbs_ratio * layer_num
+ total_operator_time = operator_time * layer_num
+ total_time = total_operator_time + communication_time
+
+ total_communication_time = communication_time + pp_time * search_model_mbs_ratio + dp_time
+ self.logger.debug('global_batch_size : {}, num_layers : {}, search_micro_batch_size : {}, operator_time : {}, '
+ 'layer_num : {}'.format(global_batch_size, num_layers, search_micro_batch_size,
+ operator_time, layer_num))
+ bubble_ratio = (pp - 1) / (micro_batch_num * vp + pp - 1)
+ total_time = total_time / (1 - bubble_ratio)
+ bubble_time = total_time * bubble_ratio
+ total_time = total_time + pp_time * search_model_mbs_ratio + dp_time
+
+ self.logger.debug(f"****************** total_time(ms) ***********************")
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<8}\t{7:<10}\t{8:<8}\t{9:<8}"
+ self.logger.debug(tplt.format('tp', 'dp', 'pp', 'vp', 'cp', 'ep', 'operator_time',
+ 'comm_time', 'bubble_time', 'total_time', chr(12288)))
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:8.2f}\t{7:8.2f}\t{8:8.2f}\t{9:8.2f}"
+ self.logger.debug(tplt.format(tp, dp, pp, vp, cp, ep, total_operator_time,
+ total_communication_time, bubble_time, total_time, chr(12288)))
+ tplt = "{0:<4}\t{1:<4}\t{2:<4}\t{3:<4}\t{4:<4}\t{5:<4}"
+ self.logger.debug(f"******* each layer mbs communication time(ms) ********")
+ self.logger.debug(tplt.format('tp_time', 'dp_time', 'pp_time',
+ 'bubble', 'cp_time', 'ep_time', chr(12288)))
+ tplt = "{0:4.2f}\t{1:4.2f}\t{2:4.2f}\t{3:4.2f}\t{4:4.2f}\t{5:4.2f}"
+ self.logger.debug(tplt.format(tp_time, dp_time, pp_time,
+ bubble_time, cp_time, ep_time, chr(12288)))
+ self.logger.debug(f"end-to-end, each*(global_batch_size / (dp *pp))* num_layers")
+ tplt = "{0:<4}\t{1:<4}\t{2:<4}\t{3:<4}\t{4:<4}\t{5:<4}"
+ self.logger.debug(tplt.format('tp_time', 'dp_time', 'pp_time',
+ 'bubble', 'cp_time', 'ep_time', chr(12288)))
+ tplt = "{0:4.0f}\t{1:4.2f}\t{2:4.2f}\t{3:4.2f}\t{4:4.2f}\t{5:4.2f}"
+ self.logger.debug(tplt.format(tp_time * layer_num * search_model_mbs_ratio, dp_time,
+ pp_time, bubble_time, cp_time * layer_num * search_model_mbs_ratio,
+ ep_time * layer_num * search_model_mbs_ratio, chr(12288)))
+ return total_time, unsampled_profiling, use_mc2
+
+ def operator_performance(self, search_cfg, working_dir, profile_count,
+ re_profile_flag=False):
+ tp = search_cfg.tensor_model_parallel_size
+ cp = search_cfg.context_parallel_size
+ pp = search_cfg.pipeline_model_parallel_size
+ ep = search_cfg.expert_model_parallel_size
+ dp = search_cfg.data_parallel_size
+ mbs = search_cfg.micro_batch_size
+ num_experts = search_cfg.num_experts if search_cfg.num_experts else 1
+ communication = self.communication
+ model_config = communication.model_cfg
+ unsampled_profiling_info = []
+ operators, cp_exist_list, cp_diff_list, ep_exist_list, ep_diff_list, operator_not_found_list = \
+ self.operator.cal_operator_timer(search_cfg)
+
+ scal_flag = True if model_config.global_world_size > Hardware().num_devices else False
+ self.logger.debug("Total number of operators have been found is {0}".format((len(operators)
+ + len(cp_exist_list)
+ + len(cp_diff_list)
+ + len(ep_exist_list)
+ + len(ep_diff_list))))
+ if (re_profile_flag and profile_count[0] < 6 and
+ len(operator_not_found_list) / (len(operators) + len(cp_exist_list) + len(cp_diff_list) +
+ len(ep_exist_list) + len(ep_diff_list)) > 1):
+ unsampled_profiling_info = search_operator(working_dir, search_cfg, communication, profile_count, scal_flag)
+ operators, cp_exist_list, cp_diff_list, ep_exist_list, ep_diff_list, operator_not_found_list = \
+ self.operator.cal_operator_timer(search_cfg)
+ operator_time = 0.0
+ for operator in operators:
+ operator_time += operator.duration
+
+ cp_exist_time = 0.0
+ cp_diff_time = 0.0
+ if cp > 1:
+ for operator in cp_exist_list:
+ cp_exist_time = cp_exist_time + operator.duration
+ operator_time += cp_exist_time
+ if cp > 2:
+ for operator in cp_diff_list:
+ cp_diff_time = cp_diff_time + operator.duration
+ operator_time += cp_diff_time * (cp - 2)
+
+ ep_each_exist_time, ep_each_diff_time = 0.0, 0.0
+ num_experts = self.communication.model_cfg.num_experts
+ if num_experts and num_experts > 0:
+ for operator in ep_exist_list:
+ ep_each_exist_time = ep_each_exist_time + operator.duration
+ ep_each_exist_time = ep_each_exist_time / 2
+ for operator in ep_diff_list:
+ ep_each_diff_time = ep_each_diff_time + operator.duration
+ ep_each_diff_time = ep_each_diff_time / 2
+ if num_experts:
+ operator_time = operator_time + (num_experts / ep - 1) * ep_each_exist_time
+
+ # Convert to the total operator time for one micro_batch on a single node.
+ operator_time = (operator_time * 0.001)
+ return operator_time, unsampled_profiling_info
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/__init__.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator.py
new file mode 100644
index 0000000000000000000000000000000000000000..67f2ca6e34dfe6957520756cddc5727a08eb6106
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator.py
@@ -0,0 +1,314 @@
+import json
+import time
+
+from mindspeed.auto_tuning.utils.logger import get_logger
+from mindspeed.auto_tuning.module.hardware import Hardware
+from mindspeed.auto_tuning.config.model_config import ModelConfig
+from mindspeed.auto_tuning.config.search_config import SearchConfig
+from mindspeed.auto_tuning.module.operator.operator_profile_get import OriginalProfileDataList
+from mindspeed.auto_tuning.module.operator.operator_note_cal import OperatorNoteList
+from mindspeed.auto_tuning.module.operator.operator_base_block import BaseBlock
+from mindspeed.auto_tuning.module.operator.operator_change_block_cp import CpBlock
+from mindspeed.auto_tuning.module.operator.operator_change_block_ep import EpBlock
+from mindspeed.auto_tuning.module.operator.operator_elemental import DictCalShape
+from mindspeed.auto_tuning.module.operator.operator_database import DataBase, Operator, OperatorHistory
+from mindspeed.auto_tuning.module.operator.operator_shape_analysis import separate_ep, separate_cp_tp
+from mindspeed.auto_tuning.module.operator.operator_shape_cal import (model_operator_with_tp,
+ model_operator_with_shape,
+ cal_new_shape_tce,
+ cal_operator_flops,
+ cal_operator_duration_with_shape)
+
+
+class OperatorPerformance(object):
+ """
+ Operator Performance modeling
+ 1. Test Run
+ 2. Profiling Parser
+ 3. Modeling [taking the results from the test run and placing them into all modules within
+ modeling for mathematical modeling estimation, then dynamically adjusting the test run configuration and
+ performing mathematical modeling estimation again [loop]]
+ 4. Return recommended configuration
+ """
+
+ def __init__(self, model_config: ModelConfig, working_dir: str):
+ self.db = DataBase(working_dir=working_dir)
+ self.origin_profile_data_list = OriginalProfileDataList()
+ self.model_config = model_config
+ self._logger = get_logger('operator')
+
+ self.base_block = BaseBlock()
+ self.cp_block = CpBlock()
+ self.ep_block = EpBlock()
+
+ self.dict_model = dict()
+
+ def model_operator_timer(self, profiling_results):
+ """
+ Model shape and duration based on the profiling result. Currently, all operator only takes one micro_batch,
+ no matter whether pp is enabled.
+ """
+ self.dict_model = dict()
+ # 1. get original data
+ self.origin_profile_data_list.get_origin_profile_data(profiling_results)
+ # 2. get base_block
+ self.base_block.get_block(self.origin_profile_data_list.data_list)
+ # 3. get change block
+ self.cp_block.get_block(self.origin_profile_data_list, self.base_block)
+ if self.origin_profile_data_list.data_list[0].config_info.num_experts:
+ self.ep_block.get_block(self.origin_profile_data_list, self.base_block)
+
+ st_time = time.time()
+ # 第 3 轮, Note数据表重新排序,按照新生成的index_name分类
+ operator_note_list = OperatorNoteList()
+ operator_note_list.get_operator_note(self)
+
+ self.get_history_db(operator_note_list.operator_note_list)
+ self._logger.info(f'-----------------------------------')
+ # 第 4 轮,基于operator_note_model建shape计算operator_model_dao
+ self.get_operator_model(operator_note_list.operator_note_dict)
+
+ self._logger.info("get operator_base_dao successful")
+ self._logger.info("total number of operator_note_dict: {}, dict_model {}, base_block {}, cp_block {}, "
+ "ep_block {}".format(len(operator_note_list.operator_note_dict), len(self.dict_model),
+ len(self.base_block.fw) + len(self.base_block.bw),
+ len(self.cp_block.fw) + len(self.cp_block.bw) + len(self.cp_block.re),
+ len(self.ep_block.fw) + len(self.ep_block.bw) + len(self.ep_block.re)))
+ self._logger.info(f'total time: {time.time() - st_time}')
+ self._logger.info(f'---------------------------【Add operator to db】---------------------------')
+
+ def get_history_db(self, operator_note_list):
+ self._logger.info("****************** duration_sum(ms) ***********************")
+ tplt = "{0:<1}\t{1:<1}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<8}\t{6:<8}\t{7:<8}"
+ self._logger.info(tplt.format('tp', 'dp', 'pp', 'cp', 'ep', 'duration_sum', 'operator_num', chr(12288)))
+ self._logger.info(f'--------------------------------------------------------------------------')
+ for (index, operator_note) in enumerate(operator_note_list):
+ operator_history_list = []
+ duration_sum = 0
+ operator_list = operator_note.fw + operator_note.bw
+ for operator in operator_list:
+ duration_sum += float(operator.duration)
+ operator_history = OperatorHistory(types=operator.type,
+ accelerator_core=operator.accelerator_core,
+ input_shape=operator.input_shape,
+ output_shape=operator.output_shape,
+ duration=operator.duration,
+ device=Hardware().device_type,
+ jit=operator.jit,
+ cann="8.0.RC2.alpha002",
+ driver="24.1.rc2.b030",
+ dtype=self.model_config.dtype.value[0])
+ operator_history_list.append(operator_history.convert_to_dict())
+ # 历史数据
+ self.db.operator_history_dao.insert_history(operator_history_list)
+ self._logger.info(tplt.format(
+ self.origin_profile_data_list.data_list[index].config_info.tp,
+ self.origin_profile_data_list.data_list[index].config_info.dp,
+ self.origin_profile_data_list.data_list[index].config_info.pp,
+ self.origin_profile_data_list.data_list[index].config_info.cp,
+ self.origin_profile_data_list.data_list[index].config_info.ep,
+ int(duration_sum), len(operator_note.fw), len(operator_note.bw), chr(12288)))
+
+ def get_operator_model(self, operator_note_dict):
+ operator_list = self.base_block.fw + self.base_block.bw
+ self.get_operator_model_dao(operator_list, operator_note_dict)
+ self.base_block.exist_cal_list = self.get_dict_base_shape(operator_list, operator_note_dict)
+
+ operator_list = self.cp_block.fw + self.cp_block.bw + self.cp_block.re
+ self.get_operator_model_dao(operator_list, operator_note_dict)
+ self.cp_block.exist_cal_list = self.get_dict_base_shape(operator_list, operator_note_dict)
+
+ operator_list = self.cp_block.diff_list.fw + self.cp_block.diff_list.bw + self.cp_block.diff_list.re
+ self.get_operator_model_dao(operator_list, operator_note_dict)
+ self.cp_block.diff_cal_list = self.get_dict_base_shape(operator_list, operator_note_dict)
+
+ operator_list = self.ep_block.fw + self.ep_block.bw + self.ep_block.re
+ self.get_operator_model_dao(operator_list, operator_note_dict)
+ self.ep_block.exist_cal_list = self.get_dict_base_shape(operator_list, operator_note_dict)
+
+ operator_list = self.ep_block.diff_list.fw + self.ep_block.diff_list.bw + self.ep_block.diff_list.re
+ self.get_operator_model_dao(operator_list, operator_note_dict)
+ self.ep_block.diff_cal_list = self.get_dict_base_shape(operator_list, operator_note_dict)
+
+
+ def get_dict_base_shape(self, operator_list, operator_note_dict):
+ re_list = []
+ for operator in operator_list:
+ index_name = operator.index_name
+ # cp 1 tp 1 2 4 8 -> shape_tp
+ # cp 2 tp 1 2 4 8 -> shape_tp
+ # shape_cp
+ # model the shape, according to the change between profiling result with different tp value, calculate the
+ # change formula for each position in the operator's shape
+ results = operator_note_dict[index_name]
+ # take ep first
+ result = separate_ep(results)
+ input_shape_cal, output_shape_cal = separate_cp_tp(result)
+ dict_shape = DictCalShape()
+ dict_shape.name = operator.name
+ dict_shape.index_name = index_name
+ dict_shape.accelerator_core = operator.accelerator_core
+ dict_shape.types = operator.type
+ dict_shape.input_cal = json.dumps(input_shape_cal)
+ dict_shape.output_cal = json.dumps(output_shape_cal)
+ re_list.append(dict_shape)
+ return re_list
+
+ def get_operator_model_dao(self, operator_list, operator_note_dict):
+ for operator in operator_list:
+ index_name = operator.index_name
+ # cp 1 tp 1 2 4 8 -> shape_tp
+ # cp 2 tp 1 2 4 8 -> shape_tp
+ # shape_cp
+ # model the shape, according to the change between profiling result with different tp value, calculate the
+ # change formula for each position in the operator's shape
+ results = operator_note_dict[index_name]
+ # input_shape_cal, has the same format as the shape array, with positive numbers representing unchanged
+ # positions, and negative numbers representing varying positions. Assuming the number is num, the variation
+ # rule is -num/tp.
+
+ # duration is modeled based on the same position operators and TPs. For operators with shape changes,
+ # it is initially observed that as TP increases [2, 4, 8], the duration decreases approximately by a
+ # factor of 2.
+ # tp_model_w is the number calculated when the duration decreases. Theoretically, it is the duration of the
+ # operator when tp=1. Therefore, when tp = 2, duration(2) = tp_model_w/2; tp_model_b is the redundancy
+ # coefficient.
+ tp_model_w, tp_model_b = model_operator_with_tp(results)
+
+ # duration is modeled based on the Flops calculated from the shape. For all operators,
+ # F(duration) = shape_model_w * Flops + shape_model_b.
+ history_results = self.db.operator_history_dao.get_by_types_and_accelerator_core(
+ operator.accelerator_core, operator.type)
+ shape_model_w, shape_model_b = model_operator_with_shape(history_results)
+ dict_shape = {
+ 'index_name': index_name,
+ 'accelerator_core': operator.accelerator_core,
+ 'model_w': float(tp_model_w),
+ 'model_b': float(tp_model_b),
+ 'shape_model_w': shape_model_w,
+ 'shape_model_b': shape_model_b,
+ }
+ accelerator_core_exist = False
+ if dict_shape["index_name"] in self.dict_model.keys():
+ for dict_temp in self.dict_model[dict_shape["index_name"]]:
+ if dict_temp['accelerator_core'] == dict_shape['accelerator_core']:
+ accelerator_core_exist = True
+ break
+ if not accelerator_core_exist:
+ self.dict_model[dict_shape["index_name"]].append(dict_shape)
+ else:
+ self.dict_model[dict_shape["index_name"]] = [dict_shape]
+
+ def getmodel_by_accelerator_core_and_index_name(self, accelerator_core, index_name):
+ for dict_shape in self.dict_model.get(index_name):
+ if dict_shape['accelerator_core'] == accelerator_core:
+ return dict_shape
+ self._logger.info("can not find the accelerator_core!")
+ return self.dict_model.get(index_name)[0]
+
+ def cal_operator_timer_bymodel(self, operator_list, search_cfg: SearchConfig, ratio=0.3,
+ re_profiling_flag=False):
+ operator_list_re = []
+
+ operator_total_num = len(operator_list)
+ operator_not_found = []
+ for operator_base in operator_list:
+ # Calculate input_shape and output_shape based on tp, cp, and ep.
+ input_shape = cal_new_shape_tce(operator_base.input_cal, search_cfg)
+ output_shape = cal_new_shape_tce(operator_base.output_cal, search_cfg)
+ # 1. search duration through operator_history based on input_shape and types
+ operators = self.db.operator_history_dao.get_by_types_and_input_shape(operator_base.types, input_shape)
+ if len(operators) > 0:
+ operator_list_re.append(Operator(name=operator_base.index_name, types=operator_base.types,
+ accelerator_core=operator_base.accelerator_core,
+ input_shape=input_shape,
+ output_shape=output_shape,
+ duration=operators[0].duration))
+
+ # 2. Predict the results based on the tp --- duration modeling results.
+ else:
+ operator_not_found.append([OperatorHistory(types=operator_base.types,
+ accelerator_core=operator_base.accelerator_core,
+ input_shape=input_shape,
+ output_shape=output_shape,
+ duration=0,
+ device=Hardware().device_type,
+ jit=int(self.model_config.jit_compile),
+ cann="8.0.RC2.alpha002",
+ driver="24.1.rc2.b030",
+ dtype=self.model_config.dtype.value[0]),
+ operator_base.index_name])
+
+ operator_not_found_total_num = len(operator_not_found)
+ if operator_not_found_total_num / operator_total_num > ratio and re_profiling_flag:
+ return operator_list_re, operator_not_found
+
+ else:
+ # If the proportion of missing operators is relatively low, by default, supplement the operators using
+ # linear interpolation.
+ if re_profiling_flag:
+ self._logger.info(
+ f'The total operator not found proportion is {operator_not_found_total_num / operator_total_num},'
+ f' there is no need for re profiling.')
+ for operator_cal_base in operator_not_found:
+ operator_base, operator_index_name = operator_cal_base
+ operator_model = self.getmodel_by_accelerator_core_and_index_name(
+ operator_base.accelerator_core, operator_index_name
+ )
+ flops = cal_operator_flops(operator_base.input_shape, operator_base.output_shape,
+ operator_base.types)
+
+ duration = cal_operator_duration_with_shape(operator_model["shape_model_w"],
+ operator_model["shape_model_b"],
+ flops)
+ operator_list_re.append(Operator(name=operator_index_name, types=operator_base.types,
+ accelerator_core=operator_base.accelerator_core,
+ input_shape=operator_base.input_shape,
+ output_shape=operator_base.output_shape,
+ duration=duration))
+ return operator_list_re, operator_not_found
+
+ def cal_operator_timer(self, search_cfg: SearchConfig) -> tuple:
+ """
+ External interface, returns the duration based on changes in tp.
+ """
+ # Obtain all operators of a model layer.
+ operator_not_found = []
+ if len(self.base_block.fw) == 0:
+ return [], [], [], 1, 1, 1
+ operator_base_list = self.base_block.exist_cal_list
+ operator_list, operator_not_found_list = self.cal_operator_timer_bymodel(operator_base_list,
+ search_cfg)
+ operator_not_found.extend(operator_not_found_list)
+ cp_operator_exist_list = self.cp_block.exist_cal_list
+ cp_operator_diff_list = self.cp_block.diff_cal_list
+ ep_operator_exist_list = self.ep_block.exist_cal_list
+ ep_operator_diff_list = self.ep_block.diff_cal_list
+ cp_exist_list, cp_exist_not_found_list = [], []
+ if len(cp_operator_exist_list) > 0:
+ cp_exist_list, cp_exist_not_found_list = self.cal_operator_timer_bymodel(
+ cp_operator_exist_list,
+ search_cfg)
+ if search_cfg.cp > 1:
+ operator_not_found.extend(cp_exist_not_found_list)
+ cp_diff_list, cp_diff_not_found_list = [], []
+ if len(cp_operator_diff_list) > 0:
+ cp_diff_list, cp_diff_not_found_list = self.cal_operator_timer_bymodel(cp_operator_diff_list,
+ search_cfg)
+ if search_cfg.cp > 1:
+ operator_not_found.extend(cp_diff_not_found_list)
+ ep_exist_list, ep_exist_not_found_list = [], []
+ if len(ep_operator_exist_list) > 0:
+ ep_exist_list, ep_exist_not_found_list = self.cal_operator_timer_bymodel(
+ ep_operator_exist_list, search_cfg
+ )
+ if search_cfg.ep and search_cfg.ep > 1:
+ operator_not_found.extend(ep_exist_not_found_list)
+ ep_diff_list, ep_diff_not_found_list = [], []
+ if len(ep_operator_diff_list) > 0:
+ ep_diff_list, ep_diff_not_found_list = self.cal_operator_timer_bymodel(ep_operator_exist_list,
+ search_cfg)
+ if search_cfg.ep and search_cfg.ep > 1:
+ operator_not_found.extend(ep_diff_not_found_list)
+ self.db.insert_not_found_list(operator_not_found)
+ return operator_list, cp_exist_list, cp_diff_list, ep_exist_list, ep_diff_list, operator_not_found
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_base_block.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_base_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..741224e4f39e8225c7ee1bc19c783291d2feec8c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_base_block.py
@@ -0,0 +1,91 @@
+import copy
+from mindspeed.auto_tuning.module.operator.operator_elemental import OperatorList, ProfileList, DictShape
+
+
+class Block(OperatorList):
+ def __init__(self):
+ super(Block, self).__init__()
+ self.exist_cal_list = []
+
+ @staticmethod
+ def longest_common_subsequence(list1, list2):
+ m, n = len(list1), len(list2)
+ dp = [[] for _ in range(m + 1)]
+ for index in range(m + 1):
+ dp[index] = [[] for _ in range(n + 1)]
+ for i in range(1, m + 1):
+ for j in range(1, n + 1):
+ if list1[i - 1].type == list2[j - 1].type:
+ dp[i][j] = dp[i - 1][j - 1].copy()
+ dp[i][j].append(list1[i - 1])
+ else:
+ if len(dp[i - 1][j]) > len(dp[i][j - 1]):
+ dp[i][j] = dp[i - 1][j].copy()
+ else:
+ dp[i][j] = dp[i][j - 1].copy()
+ return dp[m][n]
+
+ @staticmethod
+ def change_profilelist_into_dictshapelist_withindex(change_profile_list, change_operator_list):
+ for (index, item) in enumerate(change_profile_list.fw):
+ dict_shape_fw = DictShape()
+ dict_shape_fw.change_profile_into_dictshape(item, index)
+ change_operator_list.fw.append(dict_shape_fw)
+ for (index, item) in enumerate(change_profile_list.bw):
+ dict_shape_bw = DictShape()
+ dict_shape_bw.change_profile_into_dictshape(item, index)
+ change_operator_list.bw.append(dict_shape_bw)
+
+ @staticmethod
+ def change_profilelist_into_dictshapelist(change_profile_list, change_operator_list):
+ for (index, item) in enumerate(change_profile_list.fw):
+ dict_shape_fw = DictShape()
+ dict_shape_fw.change_profile_into_dictshape(item, -1)
+ change_operator_list.fw.append(dict_shape_fw)
+ for (index, item) in enumerate(change_profile_list.bw):
+ dict_shape_bw = DictShape()
+ dict_shape_bw.change_profile_into_dictshape(item, -1)
+ change_operator_list.bw.append(dict_shape_bw)
+
+
+class BaseBlock(Block):
+ def __init__(self):
+ super(BaseBlock, self).__init__()
+
+ def get_block(self, data_list):
+ profile_list = self.get_profile(data_list)
+ self.change_profilelist_into_dictshapelist_withindex(profile_list, self)
+
+ def get_profile(self, data_list):
+ profile_list = ProfileList()
+ for origin_profile_data in data_list:
+ fw = origin_profile_data.profile_list.fw
+ bw = origin_profile_data.profile_list.bw
+ if len(profile_list.fw) == 0:
+ profile_list.fw = copy.deepcopy(fw)
+ profile_list.bw = copy.deepcopy(bw)
+ else:
+ profile_list.fw = self.longest_common_subsequence(profile_list.fw, fw)
+ profile_list.bw = self.longest_common_subsequence(profile_list.bw, bw)
+ return profile_list
+ #
+
+ #
+ def reset_index_name(self, list1, list2):
+ m, n = len(list1), len(list2)
+ i, j = 0, 0
+ index = 0
+ last_mat = (0, 0)
+ first_mat = 0
+ while 1:
+ list1, i, j, last_mat, first_mat = self.reset_index_name_single(list1, list2, i, j, last_mat)
+ if j < n - 1 and index < 3:
+ # Skip a base operator.
+ index += 1
+ i = last_mat[0] + 1
+ j += 1
+ else:
+ break
+ if first_mat == 0:
+ first_mat = last_mat[0] + 1
+ return list1, first_mat
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_change_block.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_change_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..2da677bacecf4a3baeba8833f5826489727e7c29
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_change_block.py
@@ -0,0 +1,163 @@
+import copy
+from mindspeed.auto_tuning.module.operator.operator_base_block import Block
+from mindspeed.auto_tuning.module.operator.operator_elemental import (OperatorList, ChangeList,
+ ChangeOperatorList)
+
+
+class ChangeBlock(Block):
+ def __init__(self):
+ super(ChangeBlock, self).__init__()
+ self.diff_list = OperatorList()
+ self.diff_cal_list = []
+
+ @staticmethod
+ def get_operator_longest_common_subsequence(list1, list2):
+ m, n = len(list1), len(list2)
+ dp = [[] for _ in range(m + 1)]
+ for index in range(m + 1):
+ dp[index] = [[] for _ in range(n + 1)]
+ for i in range(1, m + 1):
+ for j in range(1, n + 1):
+ if list1[i - 1].type == list2[j - 1].type:
+ dp[i][j] = dp[i - 1][j - 1].copy()
+ dp[i][j].append(list1[i - 1])
+ else:
+ if len(dp[i - 1][j]) > len(dp[i][j - 1]):
+ dp[i][j] = dp[i - 1][j].copy()
+ else:
+ dp[i][j] = dp[i][j - 1].copy()
+ return dp[m][n]
+
+ def get_profile(self, origin_profile_data_list):
+ change_profile_list = ChangeList()
+ change_operator_list = ChangeOperatorList()
+ for origin_profile_data in origin_profile_data_list:
+ fw = origin_profile_data.operator_list.fw
+ bw = origin_profile_data.operator_list.bw
+ cp = origin_profile_data.config_info.cp
+ dp = origin_profile_data.config_info.dp
+ pp = origin_profile_data.config_info.pp
+ ep = origin_profile_data.config_info.ep
+ num_experts = origin_profile_data.config_info.num_experts
+ self.get_profile_info(cp, change_profile_list, fw, bw)
+ self.get_change_operator(change_profile_list, change_operator_list)
+ return change_operator_list
+
+ def get_profile_info(self, change_num, change_profile_list, fw, bw):
+ if change_num == 2:
+ if len(change_profile_list.list_2.fw) == 0:
+ change_profile_list.list_2.fw = copy.deepcopy(fw)
+ change_profile_list.list_2.bw = copy.deepcopy(bw)
+ else:
+ change_profile_list.list_2.fw = self.longest_common_subsequence(change_profile_list.list_2.fw, fw)
+ change_profile_list.list_2.bw = self.longest_common_subsequence(change_profile_list.list_2.bw, bw)
+ if change_num == 4:
+ if len(change_profile_list.list_4.fw) == 0:
+ change_profile_list.list_4.fw = copy.deepcopy(fw)
+ change_profile_list.list_4.bw = copy.deepcopy(bw)
+ else:
+ change_profile_list.list_4.fw = self.longest_common_subsequence(change_profile_list.list_4.fw, fw)
+ change_profile_list.list_4.bw = self.longest_common_subsequence(change_profile_list.list_4.bw, bw)
+ if len(change_profile_list.list_2.fw) * len(change_profile_list.list_4.fw) > 0:
+ change_profile_list.list_2.fw = self.longest_common_subsequence(change_profile_list.list_2.fw,
+ change_profile_list.list_4.fw)
+ change_profile_list.list_2.bw = self.longest_common_subsequence(change_profile_list.list_2.bw,
+ change_profile_list.list_4.bw)
+ return
+
+ def get_change_operator(self, change_profile_list, change_operator_list):
+ self.change_profilelist_into_dictshapelist(change_profile_list.list_2, change_operator_list.list_2)
+ self.change_profilelist_into_dictshapelist(change_profile_list.list_4, change_operator_list.list_4)
+
+ def get_exist_block(self, change_operator_list, base_block, index_id):
+ return
+
+ # calculate the recompute list, 1 for forward, 2 for backward
+ def get_re_block(self, list1, list2):
+ m, n = len(list1), len(list2)
+ list_re = []
+ list_bw = []
+ i, j = 0, 0
+ while i < m:
+ if j < n and list1[i].type == list2[j].type:
+ list_re.append(list1[i])
+ i += 1
+ j += 1
+ else:
+ list_bw.append(list1[i])
+ i += 1
+ return list_re, list_bw
+
+ def comp_with_get_diff_list(self, list1, list2, index_id):
+ return
+
+ #
+
+ def reset_index_name(self, list1, list2):
+ m, n = len(list1), len(list2)
+ i, j = 0, 0
+ index = 0
+ last_mat = (0, 0)
+ first_mat = 0
+ while 1:
+ list1, i, j, last_mat, first_mat = self.reset_index_name_single(list1, list2, i, j, last_mat)
+ if j < n - 1 and index < 3:
+ # Skip a base operator
+ index += 1
+ i = last_mat[0] + 1
+ j += 1
+ else:
+ break
+ if first_mat == 0:
+ first_mat = last_mat[0] + 1
+ return list1, first_mat
+
+ def reset_index_name_single(self, list1, list2, i, j, last_mat):
+ m, n = len(list1), len(list2)
+ dp_flag = False
+ mat_flag = False
+ disperses_list = []
+ first_mat = 0
+ continue_num = 0
+ while i < m:
+ if j < n and list1[i].index_name == '':
+ if list1[i].type == list2[j].type:
+ mat_flag = True
+ if dp_flag:
+ disperses_list.append(i)
+ continue_num += 1
+ if continue_num > 5 or i >= m - 1:
+ dp_flag = False
+ continue_num = 0
+ list1 = self.attract_list(disperses_list, list1, i)
+ disperses_list = []
+ list1[i].index_name = list2[j].index_name
+ last_mat = (i, j)
+ j += 1
+ else:
+ if mat_flag and first_mat == 0:
+ first_mat = i
+ disperses_list.append(i)
+ continue_num = 0
+ dp_flag = True
+ elif dp_flag and len(disperses_list) > 0:
+ while i < m and list1[i].index_name == '':
+ i += 1
+ i = i - 1
+ dp_flag = False
+ continue_num = 0
+ list1 = self.attract_list(disperses_list, list1, i)
+ disperses_list = []
+ i += 1
+ return list1, i, j, last_mat, first_mat
+
+ def attract_list(self, disperses_list, list1, i):
+ index = 0
+ len_dp = len(disperses_list)
+ while (i - index >= 0 and len_dp - index - 1 >= 0 and
+ list1[i - index].type == list1[disperses_list[len_dp - index - 1]].type):
+ temp = list1[disperses_list[len_dp - index - 1]].index_name
+ list1[disperses_list[len_dp - index - 1]].index_name = ''
+ list1[i - index].index_name = temp
+ index += 1
+ return list1
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_change_block_cp.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_change_block_cp.py
new file mode 100644
index 0000000000000000000000000000000000000000..89ffdf87ed46845c2c968be6305345db0ca9e73f
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_change_block_cp.py
@@ -0,0 +1,126 @@
+from mindspeed.auto_tuning.module.operator.operator_change_block import ChangeBlock
+from mindspeed.auto_tuning.module.operator.operator_elemental import (DictShape, ChangeList,
+ ChangeOperatorList)
+
+
+class CpBlock(ChangeBlock):
+ def __init__(self):
+ super(CpBlock, self).__init__()
+
+ def get_block(self, origin_profile_data_list, base_block):
+ change_operator_list = self.get_profile(origin_profile_data_list)
+ index_id = 1000
+ self.get_exist_block(change_operator_list, base_block, index_id)
+ self.get_diff_block(change_operator_list, -1)
+
+ def get_profile(self, origin_profile_data_list):
+ change_profile_list = ChangeList()
+ change_operator_list = ChangeOperatorList()
+ for origin_profile_data in origin_profile_data_list.data_list:
+ fw = origin_profile_data.profile_list.fw
+ bw = origin_profile_data.profile_list.bw
+ cp = origin_profile_data.config_info.cp
+ self.get_profile_info(cp, change_profile_list, fw, bw)
+ self.get_change_operator(change_profile_list, change_operator_list)
+ return change_operator_list
+
+ def get_change_operator(self, change_profile_list, change_operator_list):
+ self.change_profilelist_into_dictshapelist(change_profile_list.list_2, change_operator_list.list_2)
+ self.change_profilelist_into_dictshapelist(change_profile_list.list_4, change_operator_list.list_4)
+
+ def get_exist_block(self, change_operator_list, base_block, index_id):
+ self.fw = self.comp_with_get_diff_list(change_operator_list.list_2.fw, base_block.fw, index_id)
+ self.bw = self.comp_with_get_diff_list(change_operator_list.list_2.bw, base_block.bw, index_id + 500)
+ # recompute
+ if len(self.bw) > len(self.fw):
+ self.re, self.bw = self.get_re_block(self.bw, self.fw)
+
+ def get_diff_block(self, change_operator_list, index_id):
+ self.diff_list.fw = self.comp_with_get_diff_list(change_operator_list.list_4.fw, change_operator_list.list_2.fw,
+ -1)
+ self.diff_list.bw = self.comp_with_get_diff_list(change_operator_list.list_4.bw, change_operator_list.list_2.bw,
+ index_id)
+ self.diff_list.fw = self.get_operator_longest_common_subsequence(self.fw, self.diff_list.fw)
+ self.diff_list.re = self.get_operator_longest_common_subsequence(self.re, self.diff_list.bw)
+ self.diff_list.bw = self.get_operator_longest_common_subsequence(self.bw, self.diff_list.bw)
+
+ def get_re_block(self, list1, list2):
+ m, n = len(list1), len(list2)
+ list_re = []
+ list_bw = []
+ i, j = 0, 0
+ while i < m:
+ if j < n and list1[i].type == list2[j].type:
+ list_re.append(list1[i])
+ i += 1
+ j += 1
+ else:
+ list_bw.append(list1[i])
+ i += 1
+ return list_re, list_bw
+
+ def comp_with_get_diff_list(self, list1, list2, index_id):
+ # Align first.
+ list1, first_mat = self.reset_index_name(list1, list2)
+ diff_info = []
+ diff_index = index_id
+ for item in list1:
+ if item.index_name == '':
+ dict_shape = DictShape()
+ if diff_index != -1:
+ item.index_name = str(diff_index) + item.type
+ diff_index += 1
+ else:
+ item.index_name = ''
+ dict_shape.name = item.name
+ dict_shape.type = item.type
+ dict_shape.accelerator_core = item.accelerator_core
+ dict_shape.index_name = item.index_name
+ diff_info.append(dict_shape)
+ return diff_info
+
+ def reset_index_diff_cp(self, list1, list2, diff_flag, cp_num):
+ m, n = len(list1), len(list2)
+ if n < 2 or m < 2:
+ return list1
+ i, j = diff_flag - 1, n
+ index = 0
+ last_mat = (diff_flag - 1, n)
+ temp = -1, -1
+ while j >= n - 2 and last_mat[0] + n < m and last_mat != temp:
+ cp_num -= 1
+ if cp_num <= 0:
+ break
+ # Ensure that the entire process has been gone through once.
+ # Ensure that there is enough remaining space for one round of re-matching.
+ j = 0
+ i = last_mat[0] + 1
+ index = 0
+ temp = last_mat
+ # Restart a round of matching.
+ list1, list2, i, j, last_mat = self.restart_mat(list1, list2, i, j, last_mat)
+ return list1
+
+ @staticmethod
+ def restart_mat(list1, list2, i, j, last_mat):
+ m, n = len(list1), len(list2)
+ flag = 0
+ max_miss = 3
+ while i < m and j < n:
+ if j < n and list1[i].index_name == '' and list1[i].type == list2[j].type:
+ list1[i].index_name = list2[j].index_name
+ last_mat = (i, j)
+ i += 1
+ j += 1
+ else:
+ if i + 1 < m and list1[i + 1].index_name == '' and list1[i + 1].type == list2[j].type:
+ i += 1
+ elif j + 1 < n and list1[i].index_name == '' and list1[i].type == list2[j + 1].type:
+ j += 1
+ else:
+ i += 1
+ j += 1
+ max_miss = max_miss - 1
+ if max_miss <= 0:
+ return list1, list2, i, j, (0, 0)
+ return list1, list2, i, j, last_mat
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_change_block_ep.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_change_block_ep.py
new file mode 100644
index 0000000000000000000000000000000000000000..92fa4d572ad2509885fd3e6837b399774da622e1
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_change_block_ep.py
@@ -0,0 +1,195 @@
+import copy
+from mindspeed.auto_tuning.utils.logger import get_logger
+from mindspeed.auto_tuning.module.operator.operator_change_block import ChangeBlock
+from mindspeed.auto_tuning.module.operator.operator_elemental import (DictShape, ChangeList,
+ ChangeOperatorList)
+
+
+class EpBlock(ChangeBlock):
+ def __init__(self):
+ super(EpBlock, self).__init__()
+ self._logger = get_logger('ep_block')
+
+ def get_block(self, origin_profile_data_list, base_block):
+ change_operator_list = self.get_profile(origin_profile_data_list)
+ index_id = 2000
+ self.get_exist_block(change_operator_list, base_block, index_id)
+ self.get_diff_block(change_operator_list, index_id)
+ self.diff_list.bw.pop()
+
+ def get_profile(self, origin_profile_data_list):
+ change_profile_list = ChangeList()
+ change_operator_list = ChangeOperatorList()
+ for origin_profile_data in origin_profile_data_list.data_list:
+ fw = origin_profile_data.profile_list.fw
+ bw = origin_profile_data.profile_list.bw
+ ep = origin_profile_data.config_info.ep
+ num_experts = origin_profile_data.config_info.num_experts
+ self.get_profile_info(num_experts / ep, change_profile_list, fw, bw)
+ self.get_change_operator(change_profile_list, change_operator_list)
+ return change_operator_list
+
+ def get_profile_info(self, change_num, change_profile_list, fw, bw):
+ if change_num == 2:
+ if len(change_profile_list.list_2.fw) == 0:
+ change_profile_list.list_2.fw = copy.deepcopy(fw)
+ change_profile_list.list_2.bw = copy.deepcopy(bw)
+ else:
+ change_profile_list.list_2.fw = self.longest_common_subsequence(change_profile_list.list_2.fw, fw)
+ change_profile_list.list_2.bw = self.longest_common_subsequence(change_profile_list.list_2.bw, bw)
+ if change_num == 4:
+ if len(change_profile_list.list_4.fw) == 0:
+ change_profile_list.list_4.fw = copy.deepcopy(fw)
+ change_profile_list.list_4.bw = copy.deepcopy(bw)
+ else:
+ change_profile_list.list_4.fw = self.longest_common_subsequence(change_profile_list.list_4.fw, fw)
+ change_profile_list.list_4.bw = self.longest_common_subsequence(change_profile_list.list_4.bw, bw)
+ if len(change_profile_list.list_2.fw) * len(change_profile_list.list_4.fw) > 0:
+ change_profile_list.list_2.fw = self.longest_common_subsequence(change_profile_list.list_2.fw,
+ change_profile_list.list_4.fw)
+ change_profile_list.list_2.bw = self.longest_common_subsequence(change_profile_list.list_2.bw,
+ change_profile_list.list_4.bw)
+ return
+
+ def get_change_operator(self, change_profile_list, change_operator_list):
+ self.change_profilelist_into_dictshapelist(change_profile_list.list_2, change_operator_list.list_2)
+ self.change_profilelist_into_dictshapelist(change_profile_list.list_4, change_operator_list.list_4)
+
+ # Compare the longest subsequence of 1 and 2, return the values of 1.
+
+ def get_exist_block(self, change_operator_list, base_block, index_id):
+ self.fw = self.comp_with_get_diff_list(change_operator_list.list_2.fw, base_block.fw, index_id)
+ self.bw = self.comp_with_get_diff_list(change_operator_list.list_2.bw, base_block.bw, index_id + 500)
+ # recompute
+ if len(self.bw) > len(self.fw):
+ self.re, self.bw = self.get_re_block(self.bw, self.fw)
+ return
+
+ def get_diff_block(self, change_operator_list, index_id):
+ if not change_operator_list.list_2.fw:
+ self._logger.warning("warning:缺少了并行配置为 EP=2 的数据,从而无法得到EPdiff")
+ return
+ self.diff_list.fw = self.comp_with_get_diff_list(change_operator_list.list_4.fw, change_operator_list.list_2.fw,
+ -1)
+ self.diff_list.bw = self.comp_with_get_diff_list(change_operator_list.list_4.bw, change_operator_list.list_2.bw,
+ -1)
+ # recompute
+ if len(self.diff_list.bw) > len(self.diff_list.fw):
+ self.diff_list.re, self.diff_list.bw = self.get_re_block(self.diff_list.bw, self.diff_list.fw)
+ self.diff_list.re = self.comp_with_get_diff_list(self.diff_list.re, self.re, -1)
+ self.diff_list.fw = self.comp_with_get_diff_list(self.diff_list.fw, self.fw, -1)
+ self.diff_list.bw = self.comp_with_get_diff_list(self.diff_list.bw, self.bw, -1)
+ return
+
+ # calculate the recompute list, 1 for forward, 2 for backward
+ def get_re_block(self, list1, list2):
+ m, n = len(list1), len(list2)
+ list_re = []
+ list_bw = []
+ i, j = 0, 0
+ while i < m:
+ if j < n and list1[i].type == list2[j].type:
+ list_re.append(list1[i])
+ i += 1
+ j += 1
+ else:
+ list_bw.append(list1[i])
+ i += 1
+ return list_re, list_bw
+
+ # Align list1 with list2
+ def comp_with_get_diff_list(self, list1, list2, index_id):
+ # Align first
+ list1, first_mat = self.reset_index_name(list1, list2)
+ diff_info = []
+ diff_index = index_id
+ for item in list1:
+ if item.index_name == '':
+ dict_shape = DictShape()
+ if diff_index != -1:
+ item.index_name = str(diff_index) + item.type
+ diff_index += 1
+ else:
+ item.index_name = ''
+ dict_shape.name = item.name
+ dict_shape.type = item.type
+ dict_shape.accelerator_core = item.accelerator_core
+ dict_shape.index_name = item.index_name
+ diff_info.append(dict_shape)
+ return diff_info
+
+ def get_exist_base_ep(self):
+ self.fw = self.get_diff_list_without_index(self.fw, self.diff_list.fw)
+ self.re = self.get_diff_list_without_index(self.re, self.diff_list.re)
+ self.bw = self.get_diff_list_without_index(self.bw, self.diff_list.bw)
+
+ # Subtract the subsequence of 2 from 1.
+
+ def get_diff_list_without_index(self, list1, list2):
+ list_comm = self.get_operator_longest_common_subsequence(list1, list2)
+ m, n = len(list1), len(list_comm)
+ flag = 0
+ max_miss = 3
+ diff_list = []
+ i, j = 0, 0
+ while i < m and j < n:
+ if list1[i].type == list_comm[j].type:
+ i += 1
+ j += 1
+ else:
+ diff_list.append(list1[i])
+ i += 1
+ if i < m:
+ diff_list.append(list1[i])
+ i += 1
+ return diff_list
+
+ def reset_index_diff_ep(self, list1, list2, ep_diff_num):
+ m, n = len(list1), len(list2)
+ i, j = 0, 0
+ index = 0
+ last_mat, this_mat = (0, 0), (-1, 0)
+ while 1:
+ # Restart a round
+ if this_mat[0] + n > m or this_mat == last_mat or ep_diff_num <= 0:
+ break
+ last_mat = this_mat
+ list1, i, j, this_mat = self.reset_index_name_single_ep(list1, list2, i, j, last_mat)
+ ep_diff_num -= 1
+ if j < n - 1 and index < 3:
+ # skip one base operator
+ index += 1
+ i = this_mat[0] + 1
+ j += 1
+ else:
+ j = 0
+ i = this_mat[0] + 1
+ return list1
+
+ def reset_index_name_single_ep(self, list1, list2, i, j, start_mat):
+ m, n = len(list1), len(list2)
+ dp_flag = True
+ disperses_list = []
+ continue_num = 0
+ last_mat = start_mat
+ while i < m:
+ if j < n and list1[i].index_name == '':
+ if list1[i].type == list2[j].type:
+ if j == 0 and start_mat[0] > 0 and i - start_mat[0] > 3:
+ break
+ if dp_flag:
+ disperses_list.append(i)
+ continue_num += 1
+ if continue_num > 5 or j + 1 == n:
+ dp_flag = False
+ continue_num = 0
+ list1 = self.attract_list(disperses_list, list1, i)
+ disperses_list = []
+ list1[i].index_name = list2[j].index_name
+ last_mat = (i, j)
+ j += 1
+ else:
+ continue_num = 0
+ dp_flag = True
+ i += 1
+ return list1, i, j, last_mat
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_database.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_database.py
new file mode 100644
index 0000000000000000000000000000000000000000..18899eea935a2ec27ebed423e97fd66f89149145
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_database.py
@@ -0,0 +1,195 @@
+import os
+
+from sqlalchemy import Column, Integer, String, UniqueConstraint, text, desc
+from sqlalchemy import create_engine, Float
+from sqlalchemy.orm import sessionmaker, declarative_base
+
+Base = declarative_base()
+BaseHistory = declarative_base()
+
+
+class DataBase:
+ def __init__(self, working_dir: str):
+ db_uri_history = f'sqlite:///{os.path.join(working_dir, "operator_history.db")}'
+ db_connection_history = DBConnection(db_uri_history)
+ self.operator_history_dao = OperatorHistoryDAO(db_connection_history)
+ BaseHistory.metadata.create_all(db_connection_history.engine)
+ db_uri_different = f'sqlite:///{os.path.join(working_dir, "operator_different.db")}'
+ db_connection_different = DBConnection(db_uri_different)
+ self.operator_different_dao = OperatorHistoryDAO(db_connection_different)
+ BaseHistory.metadata.create_all(db_connection_different.engine)
+ db_uri_profiling = f'sqlite:///{os.path.join(working_dir, "operator_profiling.db")}'
+ db_connection_profiling = DBConnection(db_uri_profiling)
+ self.operator_profiling_dao = OperatorHistoryDAO(db_connection_profiling)
+ BaseHistory.metadata.create_all(db_connection_profiling.engine)
+
+ def insert_not_found_list(self, operator_list):
+ operator_different_list = []
+ for operator in operator_list:
+ operator_different_list.append(operator[0].convert_to_dict())
+ self.operator_different_dao.insert_history(operator_different_list)
+
+
+class OperatorHistory(BaseHistory):
+ __tablename__ = 'operator_history'
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ types = Column(String)
+ accelerator_core = Column(String)
+ input_shape = Column(String)
+ output_shape = Column(String)
+ duration = Column(Float)
+ device = Column(String)
+ jit = Column(Integer)
+ cann = Column(String)
+ driver = Column(String)
+ dtype = Column(String)
+ reverse1 = Column(String)
+ __table_args__ = (
+ UniqueConstraint('types', 'accelerator_core', 'input_shape', 'output_shape', 'device', 'jit',
+ 'cann', 'driver', 'dtype', name='unique_operator'),)
+
+ def __init__(self, types, accelerator_core, input_shape, output_shape, duration, device, jit, cann, driver, dtype):
+ self.types = types
+ self.accelerator_core = accelerator_core
+ self.input_shape = input_shape
+ self.output_shape = output_shape
+ self.duration = duration
+ self.device = device
+ self.jit = jit
+ self.cann = cann
+ self.driver = driver
+ self.dtype = dtype
+ self.reverse1 = "None"
+
+ def __str__(self):
+ rt = []
+ rt.append(f"{'Operator Types':<30}{str(self.types):<40}")
+ rt.append(f"{'accelerator_core':<30}{str(self.accelerator_core):<40}")
+ rt.append(f"{'input_shape':<30}{str(self.input_shape):<40}")
+ rt.append(f"{'output_shape':<30}{str(self.output_shape):<40}")
+ rt.append(f"{'duration':<30}{str(self.duration):<40}")
+ rt.append(f"{'device':<30}{str(self.device):<40}")
+ rt.append(f"{'jit':<30}{str(self.jit):<40}")
+ rt.append(f"{'cann':<30}{str(self.cann):<40}")
+ rt.append(f"{'driver':<30}{str(self.driver):<40}")
+ rt.append(f"{'dtype':<30}{str(self.dtype):<40}")
+ return "\n".join(rt)
+
+ def convert_to_dict(self):
+ return {
+ 'types': self.types,
+ 'accelerator_core': self.accelerator_core,
+ 'input_shape': self.input_shape,
+ 'output_shape': self.output_shape,
+ 'duration': self.duration,
+ 'device': self.device,
+ 'jit': self.jit,
+ 'cann': self.cann,
+ 'driver': self.driver,
+ 'dtype': self.dtype,
+ 'reverse1': self.reverse1
+ }
+
+
+class OperatorHistoryDAO(object):
+ def __init__(self, db_connection):
+ self.db_connection = db_connection
+
+ def insert_history(self, data_list):
+ def insert_data(session, dict_list):
+ for data in dict_list:
+ update_query = text('''
+ UPDATE operator_history
+ SET duration = (duration + :duration) / 2
+ WHERE types = :types AND accelerator_core = :accelerator_core AND input_shape = :input_shape AND
+ output_shape = :output_shape AND device = :device AND jit = :jit AND cann = :cann AND
+ driver = :driver AND dtype = :dtype
+ ''')
+ result = session.execute(update_query, data)
+ if result.rowcount == 0:
+ query = text('''
+ INSERT INTO operator_history
+ (types, accelerator_core, input_shape, output_shape, duration, device, jit, cann, driver, dtype, reverse1)
+ SELECT :types, :accelerator_core, :input_shape, :output_shape, :duration, :device, :jit, :cann, :driver, :dtype, :reverse1
+ WHERE NOT EXISTS(
+ SELECT 1 FROM operator_history WHERE
+ types = :types AND accelerator_core = :accelerator_core AND input_shape = :input_shape AND
+ output_shape = :output_shape AND device = :device AND jit = :jit AND cann = :cann AND
+ driver = :driver AND dtype = :dtype
+ )
+ ''')
+ session.execute(query, data)
+ session.commit()
+
+ self.db_connection.execute(insert_data, data_list)
+
+ def get_by_types_and_input_shape(self, types, input_shape):
+ def get(session, key1, key2):
+ results = session.query(OperatorHistory).filter_by(types=key1, input_shape=key2).all()
+ objects = [OperatorHistory(types=result.types,
+ accelerator_core=result.accelerator_core,
+ input_shape=result.input_shape,
+ output_shape=result.output_shape,
+ duration=result.duration,
+ device=result.device,
+ jit=result.jit,
+ cann=result.cann,
+ driver=result.driver,
+ dtype=result.dtype) for result in results]
+ return objects
+
+ return self.db_connection.execute(get, types, input_shape)
+
+ def get_by_types_and_accelerator_core(self, accelerator_core, types):
+ def get(session, key1, key2):
+ results = session.query(OperatorHistory).filter_by(accelerator_core=key1, types=key2).all()
+ objects = [OperatorHistory(types=result.types,
+ accelerator_core=result.accelerator_core,
+ input_shape=result.input_shape,
+ output_shape=result.output_shape,
+ duration=result.duration,
+ device=result.device,
+ jit=result.jit,
+ cann=result.cann,
+ driver=result.driver,
+ dtype=result.dtype) for result in results]
+ return objects
+
+ return self.db_connection.execute(get, accelerator_core, types)
+
+
+class Operator(object):
+
+ def __init__(self, name, types, accelerator_core, input_shape, output_shape, duration):
+ self.name = name
+ self.types = types
+ self.accelerator_core = accelerator_core
+ self.input_shape = input_shape
+ self.output_shape = output_shape
+ self.duration = duration
+
+ def print_to_string(self):
+ print("name: {}, input_shape: {}, output_shape: {}, duration: {}".format(
+ self.name,
+ self.input_shape,
+ self.output_shape,
+ self.duration
+ ))
+
+
+class DBConnection:
+ def __init__(self, db_url):
+ self.engine = create_engine(db_url)
+ self.Session = sessionmaker(bind=self.engine)
+
+ def execute(self, func, *args, **kwargs):
+ session = self.Session()
+ try:
+ result = func(session, *args, **kwargs)
+ session.commit()
+ return result
+ except Exception as e:
+ session.rollback()
+ raise e
+ finally:
+ session.close()
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_elemental.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_elemental.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cd6b8def9c11a761584bd8999d660833697aab6
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_elemental.py
@@ -0,0 +1,178 @@
+from mindspeed.auto_tuning.utils.logger import get_logger
+
+
+class ProfileList(object):
+ def __init__(self):
+ self.fw = []
+ self.bw = []
+ self._logger = get_logger('ProfileList')
+
+ # test 1
+
+ def reset_index_name(self, list1, list2):
+ M, N = len(list1), len(list2)
+ i, j = 0, 0
+ index = 0
+ last_mat = (0, 0)
+ first_mat = 0
+ while 1:
+ list1, i, j, last_mat, first_mat = self.reset_index_name_single(list1, list2, i, j, last_mat)
+ if j < N - 1 and index < 3:
+ # skip a base operator
+ index += 1
+ i = last_mat[0] + 1
+ j += 1
+ else:
+ break
+ if first_mat == 0:
+ first_mat = last_mat[0] + 1
+ return list1, first_mat
+
+ def reset_index_name_single(self, list1, list2, i, j, last_mat):
+ M, N = len(list1), len(list2)
+ dp_flag = False
+ mat_flag = False
+ disperses_list = []
+ first_mat = 0
+ continue_num = 0
+ while i < M:
+ if j < N and list1[i].index_name == '':
+ if list1[i].type == list2[j].type:
+ mat_flag = True
+ if dp_flag:
+ disperses_list.append(i)
+ continue_num += 1
+ if continue_num > 5 or i >= M - 1:
+ dp_flag = False
+ continue_num = 0
+ list1 = self.attract_list(disperses_list, list1, i)
+ disperses_list = []
+ list1[i].index_name = list2[j].index_name
+ last_mat = (i, j)
+ j += 1
+ else:
+ if mat_flag and first_mat == 0:
+ first_mat = i
+ disperses_list.append(i)
+ continue_num = 0
+ dp_flag = True
+ elif dp_flag and len(disperses_list) > 0:
+ while i < M and list1[i].index_name == '':
+ i += 1
+ i = i - 1
+ dp_flag = False
+ continue_num = 0
+ list1 = self.attract_list(disperses_list, list1, i)
+ disperses_list = []
+ i += 1
+ return list1, i, j, last_mat, first_mat
+
+ def attract_list(self, disperses_list, list1, i):
+ index = 0
+ len_dp = len(disperses_list)
+ while i - index >= 0 and len_dp - index - 1 >= 0 and list1[i - index].type == list1[
+ disperses_list[len_dp - index - 1]].type:
+ temp = list1[disperses_list[len_dp - index - 1]].index_name
+ list1[disperses_list[len_dp - index - 1]].index_name = ''
+ list1[i - index].index_name = temp
+ index += 1
+ return list1
+
+ def print_list(self):
+ self.print_list_fw()
+ self.print_list_bw()
+
+ def print_list_fw(self):
+ self._logger.debug("fw")
+ for item in self.fw:
+ self._logger.debug("name", item.name, "type", item.type, "index_name", item.index_name)
+
+ def print_list_bw(self):
+ self._logger.debug("bw")
+ for item in self.bw:
+ self._logger.debug("name", item.name, "type", item.type, "index_name", item.index_name)
+
+
+class ChangeList:
+ def __init__(self):
+ super(ChangeList, self).__init__()
+ self.list_2 = ProfileList()
+ self.list_4 = ProfileList()
+
+
+class ChangeOperatorList:
+ def __init__(self):
+ super(ChangeOperatorList, self).__init__()
+ self.list_2 = ProfileList()
+ self.list_4 = ProfileList()
+
+
+class DictShape(object):
+ def __init__(self):
+ self.name = ""
+ self.type = ""
+ self.accelerator_core = ""
+ self.index_name = ""
+
+ def change_profile_into_dictshape(self, item, index):
+ self.name = item.name
+ self.type = item.type
+ self.accelerator_core = item.accelerator_core
+ if index == -1:
+ self.index_name = ""
+ else:
+ self.index_name = str(index) + str(item.type)
+
+
+class DictModelShape(DictShape):
+ def __init__(self):
+ super(DictModelShape, self).__init__()
+ self.model_w = 0.0
+ self.model_b = 0.0
+ self.shape_model_w = 0.0
+ self.shape_model_b = 0.0
+
+
+class DictCalShape(DictShape):
+ def __init__(self):
+ super(DictCalShape, self).__init__()
+ self.input_cal = 0.0
+ self.output_cal = 0.0
+
+
+class OperatorList(ProfileList):
+ def __init__(self):
+ super(OperatorList, self).__init__()
+ self.fw = []
+ self.bw = []
+ self.re = []
+ self._logger = get_logger('operator_list')
+
+ def print_list(self):
+ self.print_list_fw()
+ self.print_list_bw()
+ self.print_list_re()
+
+ def print_list_fw(self):
+ self._logger.debug("fw")
+ for item in self.fw:
+ self._logger.debug("name", item.name, "type", item.type, "index_name", item.index_name)
+
+ def print_list_bw(self):
+ self._logger.debug("bw")
+ for item in self.bw:
+ self._logger.debug("name", item.name, "type", item.type, "index_name", item.index_name)
+
+ def print_list_re(self):
+ self._logger.debug("re")
+ for item in self.re:
+ self._logger.debug("name", item.name, "type", item.type, "index_name", item.index_name)
+
+
+class OperatorDetailList(OperatorList):
+ def __init__(self):
+ super(OperatorDetailList, self).__init__()
+ self.fw = []
+ self.bw = []
+ self.re = []
+ self.all = []
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_note_cal.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_note_cal.py
new file mode 100644
index 0000000000000000000000000000000000000000..392baf5623815ba2ff6aeb7d8526a62ce4fc4c06
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_note_cal.py
@@ -0,0 +1,115 @@
+from mindspeed.auto_tuning.utils.logger import get_logger
+from mindspeed.auto_tuning.module.operator.operator_elemental import DictShape, ProfileList
+from mindspeed.auto_tuning.module.operator.operator_shape_cal import cal_operator_flops
+
+
+class DictNoteShape(DictShape):
+ def __init__(self):
+ super(DictNoteShape, self).__init__()
+ self.tp = 0
+ self.cp = 0
+ self.ep = 0
+ self.type = ""
+ self.input_shape = ""
+ self.output_shape = ""
+ self.duration = 0.0
+ self.num_experts = 0
+ self.seq_length = 0
+ self.flops = 0.0
+ self.jit = 0
+
+ def change_profile_into_dictshape(self, item, config_info):
+ flops = cal_operator_flops(item.input_shapes.replace('"', ''),
+ item.output_shapes.replace('"', ''),
+ item.type)
+ self.name = item.name
+ self.type = item.type
+ self.accelerator_core = item.accelerator_core
+ self.index_name = ''
+ self.tp = config_info.tp
+ self.cp = config_info.cp
+ self.ep = config_info.ep
+ self.jit = config_info.jit
+ self.num_experts = config_info.num_experts or 1
+ self.seq_length = config_info.seq_length
+ self.input_shape = item.input_shapes.replace('"', '')
+ self.output_shape = item.output_shapes.replace('"', '')
+ self.duration = float(item.duration_us)
+ self.flops = flops
+
+
+class OperatorNoteList:
+ def __init__(self):
+ self.operator_note_list = []
+ self.operator_note_dict = {}
+ self.seq_length = 0
+ self._logger = get_logger('operator_note_list')
+
+ def get_operator_note(self, block):
+ self.get_operator_list(block.origin_profile_data_list)
+ self.get_note_in_list(block)
+ self.get_note_dict()
+
+ def get_note_in_list(self, block):
+ for (index, operator_note) in enumerate(self.operator_note_list):
+ tp = block.origin_profile_data_list.data_list[index].config_info.tp
+ cp = block.origin_profile_data_list.data_list[index].config_info.cp
+ ep = block.origin_profile_data_list.data_list[index].config_info.ep
+ num_experts = block.origin_profile_data_list.data_list[index].config_info.num_experts
+ # Align the base block
+ operator_note.reset_index_name(operator_note.fw, block.base_block.fw)
+ operator_note.reset_index_name(operator_note.bw, block.base_block.bw)
+ # Align the cp base block
+ if cp > 1:
+ _, cp_fw_index = operator_note.reset_index_name(operator_note.fw, block.cp_block.fw)
+ _, cp_re_index = operator_note.reset_index_name(operator_note.bw, block.cp_block.re)
+ _, cp_bw_index = operator_note.reset_index_name(operator_note.bw, block.cp_block.bw)
+ if cp > 2:
+ operator_note.fw = block.cp_block.reset_index_diff_cp(operator_note.fw, block.cp_block.diff_list.fw,
+ cp_fw_index, cp / 2)
+ operator_note.bw = block.cp_block.reset_index_diff_cp(operator_note.bw, block.cp_block.diff_list.re,
+ cp_re_index, cp / 2)
+ operator_note.bw = block.cp_block.reset_index_diff_cp(operator_note.bw, block.cp_block.diff_list.bw,
+ cp_bw_index, cp / 2)
+ # Align the ep block
+ if num_experts:
+ if num_experts // ep >= 2:
+ operator_note.fw = block.ep_block.reset_index_diff_ep(operator_note.fw, block.ep_block.fw,
+ (num_experts / ep) - 1)
+ operator_note.bw = block.ep_block.reset_index_diff_ep(operator_note.bw, block.ep_block.re,
+ (num_experts / ep) - 1)
+ operator_note.bw = block.ep_block.reset_index_diff_ep(operator_note.bw, block.ep_block.bw,
+ (num_experts / ep) - 1)
+
+ def get_note_dict(self):
+ for operator_note in self.operator_note_list:
+ operator_list = operator_note.fw + operator_note.bw
+ for operator in operator_list:
+ dict_exist = False
+ if operator.index_name in self.operator_note_dict.keys():
+ for dict_temp in self.operator_note_dict[operator.index_name]:
+ if dict_temp == operator:
+ dict_exist = True
+ if not dict_exist:
+ self.operator_note_dict[operator.index_name].append(operator)
+ else:
+ self.operator_note_dict[operator.index_name] = [operator]
+
+ def get_operator_list(self, origin_profile_data_list):
+ self.seq_length = origin_profile_data_list.data_list[0].config_info.seq_length
+ for (index, origin_profile_data) in enumerate(origin_profile_data_list.data_list):
+ operator_note = ProfileList()
+ self.change_profile_list_into_dict_shape_list(origin_profile_data.profile_list, operator_note,
+ origin_profile_data.config_info)
+ self.operator_note_list.append(operator_note)
+
+ @staticmethod
+ def change_profile_list_into_dict_shape_list(change_profile_list, change_operator_list, config_info):
+ for (index, item) in enumerate(change_profile_list.fw):
+ dict_shape_fw = DictNoteShape()
+ dict_shape_fw.change_profile_into_dictshape(item, config_info)
+ change_operator_list.fw.append(dict_shape_fw)
+ for (index, item) in enumerate(change_profile_list.bw):
+ dict_shape_bw = DictNoteShape()
+ dict_shape_bw.change_profile_into_dictshape(item, config_info)
+ change_operator_list.bw.append(dict_shape_bw)
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_profile_get.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_profile_get.py
new file mode 100644
index 0000000000000000000000000000000000000000..66e6613e67c0f5918f519349fa365dc9b091b83a
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_profile_get.py
@@ -0,0 +1,96 @@
+import copy
+from mindspeed.auto_tuning.module.operator.operator_elemental import ProfileList
+
+
+class ConfigInfo(object):
+ def __init__(self, config):
+ self.tp = config.tensor_model_parallel_size
+ self.dp = config.data_parallel_size
+ self.pp = config.pipeline_model_parallel_size
+ self.vp = config.num_layers_per_virtual_pipeline_stage if config.num_layers_per_virtual_pipeline_stage else 1
+ self.cp = config.context_parallel_size
+ self.ep = config.expert_model_parallel_size or 1
+ self.jit = 1 if config.jit_compile else 0
+ self.seq_length = config.seq_length
+ self.num_experts = config.num_experts
+
+ def __str__(self):
+ return (f"tp:{self.tp}, dp:{self.dp}, pp:{self.pp}, vp:{self.vp}, cp:{self.cp}, ep:{self.ep}, jit:{self.jit}, "
+ f"seq_length:{self.seq_length}, num_experts:{self.num_experts}")
+
+
+class OriginalProfileData(object):
+ def __init__(self, config):
+ self.config_info = ConfigInfo(config)
+ self.profile_list = ProfileList()
+
+
+class OriginalProfileDataList(object):
+ def __init__(self):
+ self.data_list = []
+
+ def get_origin_profile_data(self, profiling_results):
+ for config, model in profiling_results:
+ origin_profile_data = OriginalProfileData(config)
+
+ profile_list_fw = self.get_profinfo_list_from_profiling(model.forward.operator_info[-1],
+ forwardflag=1)
+ profile_list_bw = self.get_profinfo_list_from_profiling(model.backward.operator_info[-1],
+ forwardflag=0)
+ origin_profile_data.profile_list.fw = copy.deepcopy(profile_list_fw)
+ origin_profile_data.profile_list.bw = copy.deepcopy(profile_list_bw)
+
+ self.data_list.append(origin_profile_data)
+
+ @staticmethod
+ def get_profinfo_list_from_profiling(items, forwardflag):
+ operator_info_list = []
+ alltoall_flag = 0
+ cp_flag1 = 0
+ cp_flag = 0
+ for (index, item) in enumerate(items):
+ # Mark forward network part for CP
+ if forwardflag == 1:
+ if "ConcatD" in item.name and index < (len(items) - 2):
+ if "hcom_send" in items[index + 1].name or "hcom_send" in items[index + 2].name:
+ cp_flag1 = 1
+ if cp_flag1 == 1:
+ if "MatMul" in item.name:
+ cp_flag1 = 0
+ continue
+ item.name = "cp_for_flag_" + item.name
+ # Mark the backward part for CP
+ if forwardflag == 0:
+ # Mark froward network part for CP re-computation
+ if cp_flag == 0 and "ConcatD" in item.name and index < (len(items) - 2):
+ if "hcom_send" in items[index + 1].name or "hcom_send" in items[index + 2].name:
+ cp_flag1 = 2
+ if cp_flag1 == 2:
+ if "MatMul" in item.name:
+ cp_flag1 = 0
+ continue
+ item.name = "cp_re_flag_" + item.name
+ # Mark backward network part for CP
+ if cp_flag == 0 and "Concat" in item.name and index < (len(items) - 2):
+ if "ZerosLike" in items[index + 1].name:
+ cp_flag = 1
+ if cp_flag == 1:
+ if "Mul" in item.name:
+ cp_flag = 0
+ if cp_flag == 1:
+ item.name = "cp_back_flag_" + item.name
+
+ # Mark EP part
+ if "alltoall" in item.name:
+ alltoall_flag = alltoall_flag + 1
+ if alltoall_flag % 2 == 1:
+ item.name = "ep_flag_" + item.name
+
+ if (
+ not ("hcom" in item.name) and item.input_shapes != 'N/A'
+ and item.input_shapes.replace('"', '').replace(';', '') != ''
+ ):
+ operator_info_list.append(item)
+ setattr(item, "index_name", '')
+
+ return operator_info_list
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_re_profile.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_re_profile.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a00d76762fffbdc5a2f7cf34d8f80c56bd1bb1d
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_re_profile.py
@@ -0,0 +1,202 @@
+import os
+import stat
+import random
+import pickle
+
+from mindspeed.auto_tuning.utils.logger import get_logger
+from mindspeed.auto_tuning.utils.runner.model_executor import ExecutorFlag, ModelExecutor
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_node_parse import GatherNodeProfiling
+from mindspeed.auto_tuning.utils.runner.torchrun_runner import TorchRunRunner
+from mindspeed.auto_tuning.config.search_config import SearchConfig
+from mindspeed.auto_tuning.utils.utils import get_prof_dir
+from mindspeed.auto_tuning.module.operator.operator import OperatorPerformance
+from mindspeed.auto_tuning.module.operator.operator_database import OperatorHistory
+
+
+logger = get_logger('operator_re_profile')
+
+
+def search_operator(working_dir, search_cfg, communication, profile_count,
+ scale_flag=False):
+ # After a certain amount of profiling, the rest operators have not been found will be predicted using
+ # regression method.
+ executor = ModelExecutor(TorchRunRunner())
+ profiling_results = []
+ search_cfg_list = [search_cfg]
+ model_config = communication.model_cfg
+ seed = 1234
+ random.seed(seed)
+ unsampled_profiling_info = []
+ if len(search_cfg_list) > 9:
+ sampled_profiling_info = random.sample(search_cfg_list, min(9, len(search_cfg_list)))
+ unsampled_profiling_info = list(set(search_cfg_list) - set(sampled_profiling_info))
+ else:
+ sampled_profiling_info = [search_cfg]
+ for profiling_config in sampled_profiling_info:
+ if scale_flag:
+ profiling_config = scale_para(model_config, communication, profiling_config)
+ re_profiling_config = SearchConfig()
+ re_profiling_config.copy_from_config(model_config)
+ re_profiling_config.num_layers = profiling_config.pipeline_model_parallel_size
+ re_profiling_config.seq_length = profiling_config.seq_length
+ re_profiling_config.tensor_model_parallel_size = profiling_config.tensor_model_parallel_size
+ re_profiling_config.pipeline_model_parallel_size = profiling_config.pipeline_model_parallel_size
+ re_profiling_config.data_parallel_size = profiling_config.data_parallel_size
+ re_profiling_config.context_parallel_size = profiling_config.context_parallel_size
+ re_profiling_config.expert_model_parallel_size = profiling_config.expert_model_parallel_size
+ re_profiling_config.prepare_for_profiling()
+
+ from mindspeed.auto_tuning.module.hardware import Hardware
+ res_dir = os.path.join(working_dir, get_prof_dir(re_profiling_config, re_profile=True))
+ if not os.path.exists(res_dir):
+ profile_count[0] += 1
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ pkl_filename = os.path.join(working_dir, f'ootb_{Hardware().node_rank}.pkl')
+ with os.fdopen(os.open(pkl_filename, flags, mode=mode), 'wb') as f:
+ pickle.dump(re_profiling_config, f)
+ executor.execute(working_dir=working_dir, output_filename=res_dir, cfg=re_profiling_config,
+ flag=ExecutorFlag.PROFILE)
+ profiling_node_parse = GatherNodeProfiling(res_dir)
+ profiling_res = profiling_node_parse.fuse_node_pkl()
+
+ re_profiling_config.jit_compile = search_cfg.jit_compile
+ profiling_results.append([re_profiling_config, profiling_res])
+
+ operator_list = OperatorPerformance(model_config, working_dir=working_dir)
+ operator_not_found = operator_list.origin_profile_data_list.get_profinfo_list_from_profiling(
+ profiling_res.forward.operator_info[-1],
+ forwardflag=1)
+ operator_not_found_part2 = operator_list.origin_profile_data_list.get_profinfo_list_from_profiling(
+ profiling_res.backward.operator_info[-1],
+ forwardflag=0)
+ operator_not_found.extend(operator_not_found_part2)
+ logger.debug(f'Total number of operator re profiling is {len(operator_not_found)}')
+ operator_history_list = []
+ for operator in operator_not_found:
+ operator_history = OperatorHistory(types=operator.type,
+ accelerator_core=operator.accelerator_core,
+ input_shape=operator.input_shapes.replace('"', ''),
+ output_shape=operator.output_shapes.replace('"', ''),
+ duration=operator.duration_us,
+ device=Hardware().device_type,
+ jit=int(model_config.jit_compile),
+ cann="8.0.RC2.alpha002",
+ driver="24.1.rc2.b030",
+ dtype=model_config.dtype.value[0])
+ operator_history_list.append(operator_history.convert_to_dict())
+ operator_list.db.operator_history_dao.insert_history(operator_history_list)
+ operator_list.db.operator_profiling_dao.insert_history(operator_history_list)
+ return unsampled_profiling_info
+
+
+def generate_scale_config(model_config):
+ scale_config = model_config.copy()
+ scale_config.num_layers = 256
+
+ # parameter need to be adjusted
+ scale_config.tensor_model_parallel_size = 64
+ scale_config.num_attention_heads = 512
+ scale_config.hidden_size = 65536
+ scale_config.ffn_hidden_size = 229376
+ scale_config.context_parallel_size = 32
+ scale_config.seq_length = 131072
+ scale_config.max_position_embeddings = 131072
+ scale_config.expert_model_parallel_size = 32
+ scale_config.num_experts = 32
+ scale_config.pipeline_model_parallel_size = 16
+ scale_config.normalize()
+ return scale_config
+
+
+def scale_para(model_config, communication, search_cfg, test=False):
+ # load base parallel model config
+ tp = search_cfg.tensor_model_parallel_size
+ cp = search_cfg.context_parallel_size
+ pp = search_cfg.pipeline_model_parallel_size
+ ep = search_cfg.expert_model_parallel_size
+ dp = search_cfg.data_parallel_size
+
+ if pp % 2 != 0 and pp != 1:
+ logger.warning('warning: pp value set is not even.')
+
+ # load hardware config
+ # use test because in a mock situation, we do not have the real device number
+ if not test:
+ num_nodes = communication.hardware.num_nodes
+ num_devices = communication.hardware.num_devices
+ else:
+ num_nodes = 8
+ num_devices = 2 * 8
+ num_devices_ootb = 16
+
+ if not test:
+ # load model config
+ num_layers = communication.model_cfg.num_layers
+ num_attention_heads = communication.model_cfg.num_attention_heads
+ hidden_size = communication.model_cfg.hidden_size
+ ffn_hidden_size = communication.model_cfg.ffn_hidden_size
+ num_experts = communication.model_cfg.num_experts
+ sequence_length = communication.model_cfg.seq_length
+ else:
+ # for test only test whether the function works fine
+ num_layers = model_config.num_layers
+ num_attention_heads = model_config.num_attention_heads
+ hidden_size = model_config.hidden_size
+ ffn_hidden_size = model_config.ffn_hidden_size
+ num_experts = model_config.num_experts
+ sequence_length = model_config.seq_length
+
+ scale_factor = 2 # here use default tp value 8 or 4
+ # directly scale pp down to 1
+ pp_scale_factor = pp
+ scale_tp, scale_cp, scale_pp, scale_ep, scale_dp = tp, cp, pp, ep, dp
+ scale_num_layers = num_layers
+ scale_num_attention_heads = num_attention_heads
+ scale_hidden_size = hidden_size
+ scale_ffn_hidden_size = ffn_hidden_size
+ scale_num_experts = num_experts
+ scale_sequence_length = sequence_length
+ scale_space = scale_tp * scale_cp * scale_pp
+ if pp >= 2:
+ scale_pp //= pp_scale_factor
+ scale_num_layers //= num_layers
+ scale_space = scale_tp * scale_cp * scale_pp
+ logger.debug(f"Search configs is\n{search_cfg}")
+
+ while scale_space > num_devices_ootb:
+ logger.debug(f'the scale space is {scale_space}, the scale_tp is {scale_tp}, the scale_cp is {scale_cp}, '
+ f'the scale_pp is {scale_pp}, the scale_ep is {scale_ep}')
+ if scale_cp >= 4:
+ scale_cp //= scale_factor
+ scale_sequence_length //= scale_factor
+ scale_space = scale_tp * scale_cp * scale_pp
+ continue
+ if scale_tp >= 4:
+ scale_tp //= scale_factor
+ scale_num_attention_heads //= scale_factor
+ scale_hidden_size //= scale_factor
+ scale_ffn_hidden_size //= scale_factor
+ scale_space = scale_tp * scale_cp * scale_pp
+ continue
+
+ scale_dp = num_devices_ootb // (scale_tp * scale_cp * scale_pp)
+ while scale_dp * scale_cp < scale_ep:
+ scale_ep //= scale_factor
+ scale_num_experts //= scale_factor
+
+ # set up config group
+ before_scale = SearchConfig()
+ before_scale.copy_from_config(model_config)
+ before_scale.tensor_model_parallel_size = scale_tp
+ before_scale.context_parallel_size = scale_cp
+ before_scale.pipeline_model_parallel_size = scale_pp
+ before_scale.num_layers = scale_num_layers
+ before_scale.num_attention_heads = scale_num_attention_heads
+ before_scale.expert_model_parallel_size = scale_ep
+ before_scale.hidden_size = scale_hidden_size
+ before_scale.ffn_hidden_size = scale_ffn_hidden_size
+ before_scale.num_experts = scale_num_experts
+ before_scale.seq_length = scale_sequence_length
+ before_scale.data_parallel_size = scale_dp
+ return before_scale
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_shape_analysis.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_shape_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..a283610381cf5e3f2cd228e36aa5ae9c71531658
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_shape_analysis.py
@@ -0,0 +1,357 @@
+from mindspeed.auto_tuning.utils.logger import get_logger
+
+logger = get_logger('operator_shape_analysis')
+
+
+class DataEp:
+ def __init__(self):
+ self.tp = 0
+ self.cp = 0
+ self.ep = 0
+ self.input_shape = ""
+ self.output_shape = ""
+
+
+def separate_ep(results):
+ diff_idx_input = []
+ diff_idx_output = []
+ index_visit = [False] * len(results)
+ flag = 0
+ result = []
+ for i, _ in enumerate(results):
+ input_list = {}
+ output_list = {}
+ if index_visit[i]:
+ continue
+ index_visit[i] = True
+ result1 = results[i]
+ tp1 = str(result1.tp)
+ cp1 = str(result1.cp)
+ ep1 = str(result1.ep)
+ seq_length1 = str(result1.seq_length)
+ input_list[ep1] = get_default_shape_change(result1.input_shape)
+ output_list[ep1] = get_default_shape_change(result1.output_shape)
+
+ for j in range(i + 1, len(results)):
+ if index_visit[j]:
+ continue
+ result2 = results[j]
+ cp2 = str(result2.cp)
+ tp2 = str(result2.tp)
+ ep2 = str(result2.ep)
+ seq_length2 = str(result2.seq_length)
+ if tp1 != tp2 or cp1 != cp2 or seq_length1 != seq_length2:
+ continue
+ index_visit[j] = True
+ input_list[ep2] = get_default_shape_change(result2.input_shape)
+ output_list[ep2] = get_default_shape_change(result2.output_shape)
+ # calculate linear relationship
+ ep_arr = list(input_list.keys())
+ # The first occurrence of ep is recorded, other ep operator shapes directly modify the relevant dimension and
+ # insert into the dictionary.
+ if flag == 0:
+ diff_idx_input = [0] * count_num(input_list.get(str(ep1)))
+ diff_idx_output = [0] * count_num(output_list.get(str(ep1)))
+ input_cal_tmp, diff_idx_input = analyze_shape_arr_new(input_list, ep_arr, diff_idx_input, 2)
+ output_cal_tmp, diff_idx_output = analyze_shape_arr_new(output_list, ep_arr, diff_idx_output, 2)
+ if len(input_list) != 1:
+ flag = 1
+ else:
+ input_cal_tmp = modify_by_index(input_list, diff_idx_input, ep_arr, mode=1)
+ output_cal_tmp = modify_by_index(output_list, diff_idx_output, ep_arr, mode=1)
+ tmp = DataEp()
+ tmp.tp = tp1
+ tmp.cp = cp1
+ tmp.ep = ep1
+ tmp.seq_length = seq_length1
+ tmp.input_shape = input_cal_tmp
+ tmp.output_shape = output_cal_tmp
+ result.append(tmp)
+ return result
+
+
+def separate_cp_tp(results):
+ input_shape_dic = {}
+ output_shape_dic = {}
+ index_visit = [False] * len(results)
+ diff_idx_input = []
+ diff_idx_output = []
+ flag = 0
+ for i, _ in enumerate(results):
+ input_list = {}
+ output_list = {}
+ if index_visit[i]:
+ continue
+ index_visit[i] = True
+ result1 = results[i]
+ cp1 = str(result1.cp)
+ tp1 = str(result1.tp)
+ seq_length1 = str(result1.seq_length)
+ input_list[tp1] = result1.input_shape
+ output_list[tp1] = result1.output_shape
+ for j in range(i + 1, len(results)):
+ if index_visit[j]:
+ continue
+ result2 = results[j]
+ cp2 = str(result2.cp)
+ tp2 = str(result2.tp)
+ seq_length2 = str(result2.seq_length)
+ if cp1 != cp2 or seq_length1 != seq_length2:
+ continue
+ index_visit[j] = True
+ input_list[tp2] = result2.input_shape
+ output_list[tp2] = result2.output_shape
+ # calculate linear relationship
+ tp_arr = list(input_list.keys())
+ if set(input_list.keys()) == {'8', '4'}:
+ for index_i, sublist in enumerate(input_list.get('4')):
+ for j, value in enumerate(sublist):
+ check_value = isinstance(value, float) and '.1' in str(value)
+ if (check_value and index_i < len(input_list.get('8'))
+ and j < len(input_list.get('4')[index_i])):
+ input_list.get('8')[index_i][j] = value
+ # The first occurrence of cp is recorded, other cp operator shapes directly modify the relevant dimension
+ if flag == 0:
+ arr_in = input_list.get(str(tp1))
+ arr_out = output_list.get(str(tp1))
+ diff_idx_input = [0] * count_num(arr_in)
+ diff_idx_output = [0] * count_num(arr_out)
+ input_cal_tmp, diff_idx_input = analyze_shape_arr_new(input_list, tp_arr, diff_idx_input, 0)
+ output_cal_tmp, diff_idx_output = analyze_shape_arr_new(output_list, tp_arr, diff_idx_output, 0)
+ if len(input_list) != 1:
+ flag = 1
+ else:
+ input_cal_tmp = modify_by_index(input_list, diff_idx_input, tp_arr, mode=2)
+ output_cal_tmp = modify_by_index(output_list, diff_idx_output, tp_arr, mode=2)
+ input_shape_dic[cp1] = input_cal_tmp
+ output_shape_dic[cp1] = output_cal_tmp
+ if set(input_shape_dic.keys()) == {'4', '2'}:
+ for i, sublist in enumerate(input_shape_dic.get('2')):
+ for j, value in enumerate(sublist):
+ check_value = isinstance(value, float) and '.4' in str(value)
+ if (check_value and
+ i < len(input_shape_dic.get('4')) and j < len(input_shape_dic.get('4')[i])):
+ input_shape_dic.get('4')[i][j] = value
+ # calculate linear relationship
+ cp_arr = list(input_shape_dic.keys())
+ input_cal_arr, diff_idx_input = analyze_shape_arr_new(input_shape_dic, cp_arr, diff_idx_input, 1)
+ output_cal_arr, diff_idx_output = analyze_shape_arr_new(output_shape_dic, cp_arr, diff_idx_output, 1)
+
+ return input_cal_arr, output_cal_arr
+
+
+def analyze_shape_arr_new(input_shape_list, tp_arr, diff, mode=0):
+ # Data cleaning, removing some invalid data.
+ input_shape_list, tp_arr = normal_list(input_shape_list, tp_arr)
+
+ # Initialize the result array, initializing values for each position in the shape, defaulting value means unchanged.
+ result_arr = input_shape_list.get(str(tp_arr[0]))
+
+ # Compare the differences in shape between different TPs, and find the index of the differing columns
+ diff_idx, diff_arr = analyze_shape_list(input_shape_list, str(tp_arr[0]))
+ w_arr = []
+ num = count_num(result_arr)
+ if len(diff_idx) != 0 and len(diff) < num:
+ diff = [0] * num
+ for i in diff_idx:
+ if mode == 0:
+ diff[i] |= 1
+ elif mode == 1:
+ diff[i] += 1
+ elif mode == 2:
+ diff[i] = 1
+ """
+ tp cp ep
+ 1 1 1
+ Only cut by TP with a suffix of 0.4, only CP is 0.2, only EP is 0.1.
+ CP + EP binary corresponds to 0.3.
+ """
+ for index, _ in enumerate(diff_idx):
+ # Calculate and record the pattern of changes based on the different data, with the default tp * shape_x
+ i = diff_idx[index]
+ if mode == 2:
+ w = cal_shape_change_with_ep(diff_arr[index], tp_arr)
+ else:
+ w = cal_shape_change_with_tp_cp(diff_arr[index], tp_arr)
+ flag = 0
+ dis = float(float(w) - int(w))
+ w = modify_special(w)
+ if abs(dis - 0.1) < 0.001:
+ flag = 1
+ if diff[i] == 1:
+ if mode == 0:
+ if flag == 0:
+ # Only cut by TP 0.4
+ w_arr.append(float(w) + 0.4)
+ elif flag == 1:
+ # tp + ep 0.5
+ w_arr.append(float(int(w)) + 0.5)
+ elif mode == 1:
+ if flag == 0:
+ # Only cut by CP 0.2
+ w_arr.append(float(w) + 0.2)
+ elif flag == 1:
+ # cp + ep 0.3
+ w_arr.append(float(int(w)) + 0.3)
+ elif mode == 2:
+ # ep with suffix 0.1
+ w_arr.append(float(w) + 0.1)
+ elif diff[i] == 2:
+ if flag == 0:
+ # tp + cp 0.6
+ w_arr.append(float(int(w)) + 0.6)
+ elif flag == 1:
+ # tp + cp + ep 0.7
+ w_arr.append(float(int(w)) + 0.7)
+ else:
+ logger.warning("error")
+ result_arr = convert_w_to_result_arr(result_arr, diff_idx, w_arr)
+ return result_arr, diff
+
+
+def get_default_shape_change(param):
+ rows = param.split(';')
+ arr = []
+ for row in rows:
+ nums = []
+ for num in row.split(','):
+ if num != '':
+ nums.append(int(num))
+ arr.append(nums)
+ return arr
+
+
+def analyze_shape_list(input_shape_list, row1_value):
+ diff_index = [] # Save different column indices
+ diff_arr = [] # Save different data
+ # Compare the sublist within each list.
+ column_index = 0
+
+ for i in range(len(input_shape_list[row1_value])):
+ for index_n in range(len(input_shape_list[row1_value][i])):
+ tmp_list = []
+ tmp_list_float = []
+ for value in input_shape_list.values():
+ tmp_list.append(int(value[i][index_n]))
+ tmp_list_float.append(value[i][index_n])
+ if len(set(tmp_list)) != 1:
+ diff_arr.append(tmp_list_float)
+ diff_index.append(column_index)
+ column_index += 1
+
+ return diff_index, diff_arr
+
+
+def cal_shape_change_with_tp_cp(y_arr, x_arr):
+ w_arr = []
+ size = len(x_arr)
+ h = float(y_arr[0] - int(y_arr[0]))
+ for index in range(0, size):
+ if abs(h) < 0.001:
+ h = float(y_arr[index] - int(y_arr[index]))
+ w_arr.append(int(y_arr[index]) * int(x_arr[index]))
+
+ return w_arr[0] + h
+
+
+def cal_shape_change_with_ep(y_arr, x_arr):
+ w_arr = []
+ size = len(x_arr)
+ h = float(y_arr[0] - int(y_arr[0]))
+ for index in range(0, size):
+ if abs(h) < 0.001:
+ h = float(y_arr[index] - int(y_arr[index]))
+ w_arr.append(int(y_arr[index]) / float(x_arr[index]))
+
+ return w_arr[0] + h
+
+
+def convert_w_to_result_arr(result_arr, index_arr, w_arr):
+ result_list = []
+ column_index = 0
+ index_index = 0
+ for inner_arr in result_arr:
+ result = []
+ for item in inner_arr:
+ if index_index < len(index_arr) and column_index == index_arr[index_index]:
+ result.append(float(w_arr[index_index]))
+ index_index = index_index + 1
+ else:
+ result.append(float(item))
+ column_index = column_index + 1
+ result_list.append(result)
+ if len(inner_arr) == 0:
+ column_index = column_index + 1
+ return result_list
+
+
+def check_array_format(arr1, arr2):
+ if len(arr1) != len(arr2):
+ return False
+ for i, _ in enumerate(arr1):
+ if isinstance(arr1[i], list) and isinstance(arr2[i], list):
+ if not check_array_format(arr1[i], arr2[i]):
+ return False
+ return True
+
+
+def normal_list(input_shape_list, tp_arr):
+ new_input_shape_list = {}
+ new_tp_arr = []
+ if len(input_shape_list) > 0 and len(tp_arr) > 0:
+ new_input_shape_list[str(tp_arr[0])] = input_shape_list[str(tp_arr[0])]
+ new_tp_arr.append(tp_arr[0])
+ for index in range(1, len(tp_arr)):
+ if check_array_format(input_shape_list[str(tp_arr[0])], input_shape_list[str(tp_arr[index])]):
+ new_input_shape_list[str(tp_arr[index])] = input_shape_list[str(tp_arr[index])]
+ new_tp_arr.append(tp_arr[index])
+ else:
+ logger.warning(f'Incorrect input_shape_list or tp_arr: {input_shape_list}, {tp_arr}')
+
+ return new_input_shape_list, new_tp_arr
+
+
+def modify_special(w):
+ result = int(w)
+ if result == 9016:
+ result = 9024
+ elif result == 1127:
+ result = 1128
+
+ return result
+
+
+def count_num(arr):
+ cnt = 0
+ for i in arr:
+ for _ in i:
+ cnt += 1
+ return cnt
+
+
+def modify_by_index(shape_list, index_diff, tp_arr, mode=0):
+ # Data cleaning, to remove invalid data elements, such as data that doesn't match the shape
+ input_shape_list, tp_arr = normal_list(shape_list, tp_arr)
+
+ input_list = shape_list[str(tp_arr[0])]
+ result_list = []
+ i_diff = 0
+ column_index = 0
+ for arr in input_list:
+ result = []
+ for item in arr:
+ ans = 0.0
+ if column_index < len(index_diff) and index_diff[column_index] == 1:
+ # 修改
+ if mode == 1:
+ ans = float(int(item) / float(tp_arr[0])) + 0.1
+ elif mode == 2:
+ ans = float(int(item) * float(tp_arr[0])) + 0.4
+ i_diff += 1
+ else:
+ ans = float(item)
+ result.append(float(ans))
+ column_index += 1
+ result_list.append(result)
+
+ return result_list
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_shape_cal.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_shape_cal.py
new file mode 100644
index 0000000000000000000000000000000000000000..3488b6dee93b194390739092adaa4e0092b4bd00
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/operator/operator_shape_cal.py
@@ -0,0 +1,178 @@
+import ast
+import math
+import numpy as np
+from sklearn.linear_model import LinearRegression
+
+from mindspeed.auto_tuning.config.search_config import SearchConfig
+
+
+def cal_new_shape_new(cal_arr, search_cfg: SearchConfig):
+ tp = search_cfg.tp
+ cp = search_cfg.cp
+ ep = search_cfg.ep or 1
+ mbs = search_cfg.mbs
+ num_experts = search_cfg.num_experts or 1
+ cal_arr = ast.literal_eval(cal_arr)
+ result_arr = []
+ base = 0.0001
+ mbs_flag = False
+ if mbs > 1:
+ mbs_flag = True
+ for inner_arr in cal_arr:
+ result = []
+ for item in inner_arr:
+ dis = item - float(int(item))
+ if abs(dis - 0) <= base:
+ result.append(int(item))
+ elif abs(dis - 0.1) <= base:
+ result.append(math.ceil(int(item) * ep / num_experts))
+ elif abs(dis - 0.2) <= base and mbs_flag:
+ result.append(math.ceil(int(item) * mbs / cp))
+ elif abs(dis - 0.2) <= base:
+ result.append(math.ceil(int(item) / cp))
+ elif abs(dis - 0.3) <= base and mbs_flag:
+ result.append(math.ceil(int(item) * mbs / cp * ep / num_experts))
+ elif abs(dis - 0.3) <= base:
+ result.append(math.ceil(int(item) / cp * ep / num_experts))
+ elif abs(dis - 0.4) <= base:
+ result.append(math.ceil(int(item) / tp))
+ elif abs(dis - 0.5) <= base:
+ result.append(math.ceil(int(item) / tp * ep / num_experts))
+ elif abs(dis - 0.6) <= base and mbs_flag:
+ result.append(math.ceil(int(item) * mbs / tp / cp))
+ elif abs(dis - 0.6) <= base:
+ result.append(math.ceil(int(item) / tp / cp))
+ elif abs(dis - 0.7) <= base and mbs_flag:
+ result.append(math.ceil(int(item) * mbs / tp / cp * ep / num_experts))
+ elif abs(dis - 0.7) <= base:
+ result.append(math.ceil(int(item) / tp / cp * ep / num_experts))
+ result_arr.append(result)
+ return result_arr
+
+
+def cal_new_shape_tce(cal_arr, search_cfg: SearchConfig):
+ result_cal_arr = cal_new_shape_new(cal_arr, search_cfg)
+ result_str = ';'.join([','.join(map(str, arr)) if arr else '' for arr in result_cal_arr])
+ return result_str
+
+
+def mul_shape(shape):
+ result = 1
+ for item in shape:
+ if item != 0:
+ result *= item
+ return result
+
+
+def model_operator_with_shape(history_result_list):
+ if len(history_result_list) <= 0:
+ return 0, 0
+ x_arr = []
+ y_arr = []
+ for history in history_result_list:
+ x_arr.append([cal_operator_flops(history.input_shape, history.output_shape, history.types)])
+ y_arr.append([history.duration])
+ shape_model_w, shape_model_b = linear_regression(x_arr, y_arr)
+ return shape_model_w, shape_model_b
+
+
+def cal_operator_flops(input_shape, output_shape, types):
+ input_shape_arr_before = []
+ output_shape_arr = []
+ if len(input_shape) < 1 or input_shape == ';':
+ return 1
+ for str_num in input_shape.split(';')[0].split(','):
+ if str_num == '':
+ return 1
+ else:
+ input_shape_arr_before.append(int(str_num))
+ if len(output_shape) < 1 or output_shape == ';':
+ return 1
+ for str_num in output_shape.split(';')[0].split(','):
+ if str_num == '':
+ return 1
+ else:
+ output_shape_arr.append(int(str_num))
+ # other operator flops
+ x_item = mul_shape(input_shape_arr_before)
+
+ # FLOPs(BatchMatMul) = b*x*y*n; [b, x, n] * [b, n, y] == [b, x, y]
+ if types in ['BatchMatMul']:
+ x_item = mul_shape(output_shape_arr)
+ if input_shape_arr_before[1] in output_shape_arr:
+ x_item *= input_shape_arr_before[2]
+ else:
+ x_item *= input_shape_arr_before[1]
+
+ # FLOPs(MatMul) = x*y*n; [x, n] * [n, y] == [x, y]
+ if types in ['MatMul', 'MatMulCommon']:
+ input_shape_arr_after = [int(str_num) for str_num in input_shape.split(';')[1].split(',')]
+ x_item = 2 * mul_shape(output_shape_arr)
+ if input_shape_arr_before[0] in output_shape_arr:
+ x_item *= input_shape_arr_before[1]
+ else:
+ x_item *= input_shape_arr_before[0]
+ # The input matrix A needs to be transposed, resulting in additional FLOPs.
+ if output_shape_arr[0] != input_shape_arr_before[0]:
+ x_item += 2 * mul_shape(input_shape_arr_before)
+ # The input matrix B needs to be transposed, resulting in additional FLOPs.
+ if output_shape_arr[1] != input_shape_arr_after[1]:
+ x_item += 2 * mul_shape(input_shape_arr_after)
+
+ if types in ['Mul', 'MulAiCore', 'ConcatD']:
+ x_item = 0
+ str_arr = input_shape.split(';')
+ for arr in str_arr:
+ if len(arr) > 0:
+ int_arr = [int(str_num) for str_num in arr.split(',')]
+ x_item += mul_shape(int_arr)
+
+ if types in ['Slice', 'SliceAiCore']:
+ x_item = 0
+ str_arr = output_shape.split(';')
+ for arr in str_arr:
+ if len(arr) > 0:
+ int_arr = [int(str_num) for str_num in arr.split(',')]
+ x_item += mul_shape(int_arr)
+
+ if types in ['FlashAttentionScore', 'FlashAttentionScoreGrad']:
+ x_item = mul_shape(input_shape_arr_before)
+ input_shape_arr_after_flash = []
+ for str_num in input_shape.split(';')[1].split(','):
+ if str_num != '':
+ input_shape_arr_after_flash.append(int(str_num))
+ x_tmp = input_shape_arr_after_flash[0] * x_item
+ x_item += x_tmp
+
+ return x_item
+
+
+def cal_operator_duration_with_shape(shape_model_w, shape_model_b, flops):
+ result_duration = float(shape_model_w) * flops + float(shape_model_b)
+ if result_duration < 0:
+ return 0
+ return result_duration
+
+
+def model_operator_with_tp(operator_notes_index_list):
+ """
+ For operators with the same TP and index-name, the duration decreases linearly with TP, duration ~ w / tp.
+ Calculate the proportion of TP as a1 and the proportion of CP as a2.
+ The final result is d = model_w_tp / TP + model_w_cp / CP.
+ """
+ result_tp = 0
+ for operator_notes_index in operator_notes_index_list:
+ result_tp = result_tp + operator_notes_index.tp * operator_notes_index.duration
+ model_w_tp = result_tp / len(operator_notes_index_list)
+
+ return model_w_tp, 0
+
+
+def linear_regression(x, y):
+ x = np.array(x)
+ y = np.array(y)
+ model = LinearRegression()
+ model.fit(x, y)
+ w = model.coef_[0]
+ b = model.intercept_
+ return w[0], b[0]
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/__init__.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/__init__.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_communication_parse.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_communication_parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfc42ce06eede90705086abf597d4888e53ecfe2
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_communication_parse.py
@@ -0,0 +1,284 @@
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_config import (
+ TensorParallelCommunication,
+ DataParallelCommunication,
+ PipelineParallelCommunication,
+ ContextParallelCommunication,
+ ExpertParallelCommunication,
+ ProfilingConfig
+)
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_constant import NumberConstant, SpecialKeyName
+
+
+class AnalyseCommunicationMsg(ProfilingConfig):
+ """ Analyse communication massage. """
+
+ def __init__(self, search_cfg, communication_details, kernel_details):
+ super(AnalyseCommunicationMsg, self).__init__(search_cfg)
+ self.collective_hcom = communication_details.get('collective', {})
+ self.p2p_hcom = communication_details.get('p2p', {})
+ self.kernel_details = kernel_details
+ self.tensor_parallel_comm = TensorParallelCommunication()
+ self.pipeline_parallel_comm = PipelineParallelCommunication()
+ self.data_parallel_comm = DataParallelCommunication()
+ self.context_parallel_comm = ContextParallelCommunication()
+ self.expert_parallel_comm = ExpertParallelCommunication()
+ self.pp_stream_id = None
+ self.tp_stream_id = None
+ self.overlap_record = {}
+ self.overlap_list = []
+
+ @classmethod
+ def is_send_or_recv_op(cls, op_name: str) -> bool:
+ return 'send' in op_name or 'receive' in op_name
+
+ def get_hcom_and_hcom_overlap(self, index, info):
+ current_name = self.kernel_details[index][SpecialKeyName.NAME]
+ next_name = self.kernel_details[index + 1][SpecialKeyName.NAME]
+ if current_name in self.overlap_list or next_name in self.overlap_list:
+ return
+
+ if index + 1 >= len(self.kernel_details):
+ return
+
+ hcom_time1 = float(info[SpecialKeyName.DURATION_US])
+ hcom_time2 = float(self.kernel_details[index + 1][SpecialKeyName.DURATION_US])
+ shorter_hcom = current_name if hcom_time1 <= hcom_time2 else next_name
+ self.overlap_list.append(shorter_hcom)
+
+ def get_compute_and_hcom_overlap(self, index, info):
+ overlap_record = {}
+ overlap_list = []
+ overlap_time = float(info[SpecialKeyName.DURATION_US])
+ op1 = self.kernel_details[index + 1]
+ op2 = self.kernel_details[index + 2] if index + 2 < len(self.kernel_details) else None
+ op1_name = op1[SpecialKeyName.NAME]
+ hcom1_duration = float(op1[SpecialKeyName.DURATION_US])
+
+ if op2 and op2[SpecialKeyName.ACCELERATOR_CORE] == 'HCCL':
+ op2_name = op2[SpecialKeyName.NAME]
+ hcom2_duration = float(op2[SpecialKeyName.DURATION_US])
+
+ if hcom2_duration <= hcom1_duration:
+ overlap_list.append(op2_name)
+ overlap_record[op1_name] = min(overlap_time, hcom1_duration)
+ else:
+ overlap_list.append(op1_name)
+ overlap_record[op1_name] = min(overlap_time, hcom2_duration)
+ else:
+ overlap_record[op1_name] = min(overlap_time, hcom1_duration)
+
+ return overlap_record, overlap_list
+
+ def is_compute_and_hcom_overlap(self, index, row):
+ if index + 1 >= len(self.kernel_details):
+ return False
+ op1 = self.kernel_details[index + 1]
+ if op1[SpecialKeyName.ACCELERATOR_CORE] != 'HCCL' or row[SpecialKeyName.ACCELERATOR_CORE] == 'HCCL':
+ return False
+ start_time = float(row[SpecialKeyName.START_TIME_US])
+ duration = float(row[SpecialKeyName.DURATION_US])
+ op1_start_time = float(op1[SpecialKeyName.START_TIME_US])
+ return op1_start_time < start_time + duration
+
+ def is_hcom_hcom_overlap(self, index, row):
+ if index + 1 >= len(self.kernel_details):
+ return False
+ op1 = self.kernel_details[index + 1]
+ if row[SpecialKeyName.ACCELERATOR_CORE] != 'HCCL' or op1[SpecialKeyName.ACCELERATOR_CORE] != 'HCCL':
+ return False
+ start_time = float(row[SpecialKeyName.START_TIME_US])
+ duration = float(row[SpecialKeyName.DURATION_US])
+ op1_start_time = float(op1[SpecialKeyName.START_TIME_US])
+ return op1_start_time < start_time + duration
+
+ def analyse_parallel_comm(self):
+ self._analyse_communication_overlap()
+ min_expert_time = None
+ for name, info in self.collective_hcom.items():
+ if 'hcom' not in name:
+ continue
+ if self.is_send_or_recv_op(name):
+ self._analyse_pp_comm(name, info)
+ continue
+ if 'alltoall' in name:
+ min_expert_time = self._analyse_ep_comm(name, info, min_expert_time)
+ continue
+ if self.search_cfg.tp > 1:
+ self._analyse_tp_comm(name, info)
+ self._analyse_dp_comm(name, info)
+ if self.search_cfg.pp > 1 and self.search_cfg.cp > 1:
+ self.pp_stream_id = self._analyse_pp_cp_process_id()
+ else:
+ self.pp_stream_id = None
+ for name, info in self.p2p_hcom.items():
+ if 'hcom' not in name:
+ continue
+ hcom_name = name.split('@')[0]
+ stream_id = hcom_name.split('_')[3]
+ if (self.pp_stream_id and self.pp_stream_id == stream_id) or self.search_cfg.cp == 1:
+ self._analyse_pp_comm(name, info)
+ else:
+ self._analyse_cp_comm(name, info)
+
+ self._get_zero1_hcom()
+ if min_expert_time:
+ self.expert_parallel_comm.min_comm_time_ms = len(self.expert_parallel_comm.details) * min_expert_time
+ self.expert_parallel_comm.wait_time_ms = self.expert_parallel_comm.total_time_ms - \
+ self.expert_parallel_comm.min_comm_time_ms
+
+ def get_tp_comm(self):
+ return self.tensor_parallel_comm
+
+ def get_pp_comm(self):
+ return self.pipeline_parallel_comm
+
+ def get_dp_comm(self):
+ return self.data_parallel_comm
+
+ def get_cp_comm(self):
+ return self.context_parallel_comm
+
+ def get_ep_comm(self):
+ return self.expert_parallel_comm
+
+ def is_tp_communication(self, name):
+ return "reduceScatter" in name or "allGather" in name
+
+ def _accumulate_communication_stats(self, comm_obj, name, info):
+ if isinstance(comm_obj, TensorParallelCommunication) and not self.is_tp_communication(name):
+ comm_obj.details.append({name: info})
+ return
+ comm_obj.total_time_ms += info[SpecialKeyName.ELAPSE_TIME_MS]
+ comm_obj.wait_time_ms += (info[SpecialKeyName.WAIT_TIME_MS] + info[SpecialKeyName.IDLE_TIME_MS])
+ hcom_name = name.split('@')[0]
+ if isinstance(comm_obj, TensorParallelCommunication):
+ if hcom_name in self.overlap_record:
+ comm_obj.overlap_time_ms += self.overlap_record[hcom_name] / NumberConstant.CONVERSION_TIME
+ comm_obj.fixed_wait_time_ms += (info[SpecialKeyName.WAIT_TIME_MS] + info[SpecialKeyName.IDLE_TIME_MS])
+ else:
+ comm_obj.fixed_time_ms += info[SpecialKeyName.ELAPSE_TIME_MS]
+ elif hcom_name in self.overlap_record:
+ comm_obj.overlap_time_ms += self.overlap_record[hcom_name] / NumberConstant.CONVERSION_TIME
+ comm_obj.details.append({name: info})
+
+ def _analyse_pp_cp_process_id(self):
+ pp_and_cp_send_id = []
+ pp_and_cp_receive_id = []
+ pp_stream_id = None
+ for name, _ in self.p2p_hcom.items():
+ if 'hcom' not in name:
+ continue
+ hcom_name = name.split('@')[0]
+ stream_id = hcom_name.split('_')[3]
+ if 'send' in name:
+ if len(pp_and_cp_receive_id) > 1 and stream_id in pp_and_cp_receive_id:
+ pp_stream_id = stream_id
+ if stream_id not in pp_and_cp_send_id:
+ pp_and_cp_send_id.append(stream_id)
+ elif 'receive' in name:
+ if len(pp_and_cp_send_id) > 1 and stream_id in pp_and_cp_send_id:
+ pp_stream_id = stream_id
+ if stream_id not in pp_and_cp_receive_id:
+ pp_and_cp_receive_id.append(stream_id)
+ if pp_stream_id is not None:
+ break
+ return pp_stream_id
+
+ def _dp_comm_with_mlp_and_attention(self, mlp_process_id, process_id, name, info):
+ if mlp_process_id and process_id == mlp_process_id:
+ self.data_parallel_comm.mlp_zero_time_ms += info[SpecialKeyName.ELAPSE_TIME_MS]
+ if 'allGather' in name:
+ self.data_parallel_comm.mlp_ag_time_ms += info[SpecialKeyName.ELAPSE_TIME_MS]
+ if 'reduceScatter' in name:
+ self.data_parallel_comm.mlp_rs_time_ms += info[SpecialKeyName.ELAPSE_TIME_MS]
+ else:
+ self.data_parallel_comm.other_zero_time_ms += info[SpecialKeyName.ELAPSE_TIME_MS]
+ if 'allGather' in name:
+ self.data_parallel_comm.other_ag_time_ms += info[SpecialKeyName.ELAPSE_TIME_MS]
+ if 'reduceScatter' in name:
+ self.data_parallel_comm.other_rs_time_ms += info[SpecialKeyName.ELAPSE_TIME_MS]
+
+ def _get_zero1_hcom(self):
+ mlp_process_id = None
+ if not self.data_parallel_comm.details:
+ return
+ if 'allGather' in list(self.data_parallel_comm.details[-1].keys())[0] \
+ and (self.search_cfg.cp * self.search_cfg.dp / self.search_cfg.ep != 1):
+ mlp_process_id = list(self.data_parallel_comm.details[-1].keys())[0].split('_')[3]
+ for hcom in self.data_parallel_comm.details:
+ for name, info in hcom.items():
+ process_id = name.split('_')[3]
+ if 'allReduce' in name and self.search_cfg.zero1:
+ continue
+ self._dp_comm_with_mlp_and_attention(mlp_process_id, process_id, name, info)
+
+ def _analyse_tp_comm(self, name, info):
+ hcom_name = name.split('@')[0]
+ if hcom_name in self.overlap_list:
+ return
+ if ('reduceScatter' in hcom_name or 'broadcast' in hcom_name) and not self.tp_stream_id:
+ self.tp_stream_id = name.split('_')[3]
+ if self.search_cfg.tp > 1 and self.tp_stream_id and name.split('_')[3] == self.tp_stream_id:
+ self._accumulate_communication_stats(self.tensor_parallel_comm, name, info)
+
+ def _analyse_pp_comm(self, name, info):
+ self._accumulate_communication_stats(self.pipeline_parallel_comm, name, info)
+
+ def _analyse_dp_comm(self, name, info):
+ hcom_name = name.split('@')[0]
+ stream_id = hcom_name.split('_')[3]
+ if stream_id != self.tp_stream_id and hcom_name.split('_')[1] in ["reduceScatter", "allGather"]:
+ self._accumulate_communication_stats(self.data_parallel_comm, name, info)
+
+ def _analyse_cp_comm(self, name, info):
+ self._accumulate_communication_stats(self.context_parallel_comm, name, info)
+
+ cp_vector_time = self._analyse_cp_vector_time()
+ self.context_parallel_comm.vector_time_ms = cp_vector_time
+
+ def _analyse_ep_comm(self, name, info, min_expert_time):
+ if not min_expert_time:
+ min_expert_time = info[SpecialKeyName.ELAPSE_TIME_MS]
+ else:
+ min_expert_time = min(min_expert_time, info[SpecialKeyName.ELAPSE_TIME_MS])
+ self.expert_parallel_comm.total_time_ms += info[SpecialKeyName.ELAPSE_TIME_MS]
+ self.expert_parallel_comm.details.append({name: info})
+ return min_expert_time
+
+ def _analyse_communication_overlap(self):
+ for index, row in enumerate(self.kernel_details):
+ if "Name" not in row or "Type" not in row:
+ continue
+ if self.is_compute_and_hcom_overlap(index, row):
+ per_overlap_record, per_overlap_list = self.get_compute_and_hcom_overlap(index, row)
+ self.overlap_record = {**self.overlap_record, **per_overlap_record}
+ self.overlap_list.extend(per_overlap_list)
+ elif self.is_hcom_hcom_overlap(index, row):
+ self.get_hcom_and_hcom_overlap(index, row)
+
+ def _cp_vector_operator_overlap(self, index, row):
+ if index >= len(self.kernel_details) - 1:
+ return False
+ is_hccl = row[SpecialKeyName.ACCELERATOR_CORE] == 'HCCL'
+ is_ai_vector_core = self.kernel_details[index + 1][SpecialKeyName.ACCELERATOR_CORE] == 'AI_VECTOR_CORE'
+ is_time_overlap = float(self.kernel_details[index + 1][SpecialKeyName.START_TIME_US]) < float(
+ row[SpecialKeyName.START_TIME_US]) + float(row[SpecialKeyName.DURATION_US])
+ is_overlap = is_hccl and is_ai_vector_core and is_time_overlap
+ if is_overlap and self.is_send_or_recv_op(row[SpecialKeyName.NAME]):
+ return True
+ return False
+
+ def _analyse_cp_vector_time(self):
+ is_cp_vector = False
+ total_cp_vector = 0
+ for index, row in enumerate(self.kernel_details):
+ if "Name" not in row or "Type" not in row:
+ continue
+ is_ai_vector_core = row[SpecialKeyName.ACCELERATOR_CORE] == 'AI_VECTOR_CORE'
+ if is_cp_vector and is_ai_vector_core and 'Grad' not in row[SpecialKeyName.NAME]:
+ total_cp_vector += float(row[SpecialKeyName.DURATION_US]) / NumberConstant.CONVERSION_TIME
+ elif is_cp_vector and row[SpecialKeyName.ACCELERATOR_CORE] != 'AI_VECTOR_CORE':
+ is_cp_vector = False
+ if self._cp_vector_operator_overlap(index, row):
+ is_cp_vector = True
+ return total_cp_vector
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_config.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5850c8db93f5769b6706d8b32b628e9b0d7c3cbf
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_config.py
@@ -0,0 +1,259 @@
+from copy import deepcopy
+from typing import List
+from mindspeed.auto_tuning.module.hardware import Hardware
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_constant import NumberConstant
+
+
+class ProfilingConfig:
+ """
+ Basic parameters of profiling
+ """
+
+ def __init__(self, search_cfg=None, args=None):
+ self.search_cfg = deepcopy(search_cfg)
+ self.per_micro_layer = search_cfg.num_layers // search_cfg.pp
+ self.vpp = search_cfg.vpp if search_cfg.vpp else 1
+ self.micro_num = search_cfg.gbs // (search_cfg.mbs * search_cfg.dp) * self.vpp
+ self.stage_id = 0
+
+ # hardware config
+ if args:
+ if isinstance(args, Hardware):
+ self.nodes = args.num_nodes
+ self.devices_per_node = args.devices_per_node
+ self.node_rank = args.node_rank
+ else:
+ self.nodes = args.nnodes
+ self.devices_per_node = args.nproc_per_node
+ self.node_rank = args.node_rank
+ else:
+ self.nodes = 1
+ self.devices_per_node = 8
+ self.node_rank = 0
+
+ def search_first_operator_idx_for_per_layer_enable_pp_last_stage(self, fw_norm_index, bw_norm_index):
+ fw_layer_start = []
+ bw_layer_end = []
+ recompute_fw = []
+ warm_micro_num = self._calculate_warm_micro_num()
+ bw_idx = 0
+ fw_idx = 0
+ for micro in range(self.micro_num):
+ i = micro // (self.vpp * self.search_cfg.pp)
+ fw_layer_start.append([fw_norm_index[fw_idx]])
+ fw_idx = self._calculate_fw_idx(fw_idx, i, micro)
+ bw_idx = self._calculate_bw_idx(bw_idx, i, micro)
+ bw_layer_end.append([bw_norm_index[bw_idx - 1]])
+ if self.search_cfg.is_full_recompute:
+ if warm_micro_num <= micro + 1:
+ recompute_fw.append([fw_norm_index[fw_idx]])
+ fw_idx += NumberConstant.FW_NORM_OP_NUM_ENABLE_PP_OTHER_STAGE
+ if micro == self.micro_num - 1:
+ for i in range(warm_micro_num - 1):
+ fw_idx += i * NumberConstant.FW_NORM_OP_NUM_ENABLE_PP_OTHER_STAGE
+ recompute_fw.append([fw_norm_index[fw_idx]])
+ if self.vpp > 1:
+ fw_per_micro_opt_num = fw_layer_start[1][0] - fw_layer_start[0][0]
+ else:
+ fw_per_micro_opt_num = fw_norm_index[2] - fw_norm_index[0]
+ bw_per_micro_opt_num = bw_norm_index[2] - bw_norm_index[0]
+ return fw_layer_start, bw_layer_end, recompute_fw, fw_per_micro_opt_num, bw_per_micro_opt_num
+
+ def search_first_operator_idx_for_per_layer_enable_pp_other_stage(self, fw_norm_index, bw_norm_index):
+ fw_layer_start = []
+ bw_layer_end = []
+ recompute_fw = []
+ fw_norm_index = [fw_norm_index[i * 2: (i + 1) * 2] for i in range(len(fw_norm_index) // 2)]
+ bw_norm_index = [bw_norm_index[i * 2: (i + 1) * 2] for i in range(len(bw_norm_index) // 2)]
+ warm_micro_num = self._calculate_warm_micro_num()
+
+ for micro in range(self.micro_num):
+ if micro < warm_micro_num:
+ fw_layer_start.append([fw_norm_index[micro][0]])
+ else:
+ fw_layer_start.append([fw_norm_index[micro + micro - warm_micro_num + 1][0]])
+ recompute_fw.append([fw_norm_index[micro + micro - warm_micro_num][0]])
+ if micro == self.micro_num - 1:
+ recompute_fw.extend(
+ [[index[0]] for index in fw_norm_index[len(fw_norm_index) - warm_micro_num:]])
+ bw_layer_end.append([bw_norm_index[micro][-1]])
+ if self.search_cfg.is_full_recompute:
+ if len(recompute_fw) != self.micro_num:
+ for i in range(len(recompute_fw), self.micro_num):
+ recompute_fw.append([fw_norm_index[i + self.micro_num][0]])
+ bw_per_micro_opt_num = bw_norm_index[0][-1] - recompute_fw[0][0]
+ else:
+ bw_per_micro_opt_num = bw_norm_index[1][0] - bw_norm_index[0][0]
+ fw_per_micro_opt_num = fw_layer_start[1][0] - fw_layer_start[0][0]
+ return fw_layer_start, bw_layer_end, recompute_fw, fw_per_micro_opt_num, bw_per_micro_opt_num
+
+ def search_first_operator_idx_for_per_layer_enable_pp(self, fw_norm_index, bw_norm_index):
+ if self.stage_id == self.search_cfg.pp - 1:
+ return self.search_first_operator_idx_for_per_layer_enable_pp_last_stage(fw_norm_index, bw_norm_index)
+ else:
+ return self.search_first_operator_idx_for_per_layer_enable_pp_other_stage(fw_norm_index, bw_norm_index)
+
+ def search_first_operator_idx_for_per_layer_disable_pp(self, fw_norm_index, bw_norm_index):
+ fw_layer_start = []
+ bw_layer_end = []
+ recompute_fw = []
+ if self.search_cfg.is_full_recompute:
+ fw_micro_rms_num = len(fw_norm_index) // self.micro_num
+
+ fw_norm_index = [fw_norm_index[fw_micro_rms_num * i:fw_micro_rms_num * (i + 1)]
+ for i in range(self.micro_num)]
+ bw_micro_rms_num = len(bw_norm_index) // self.micro_num
+
+ bw_norm_index = [bw_norm_index[bw_micro_rms_num * i:bw_micro_rms_num * (i + 1)]
+ for i in range(self.micro_num)]
+ fw_per_micro_opt_num = fw_norm_index[0][2] - fw_norm_index[0][0]
+ bw_per_micro_opt_num = bw_norm_index[0][2] - bw_norm_index[0][0]
+
+ for micro in range(self.micro_num):
+ fw_layer_start.append([fw_norm_index[micro][0]])
+ bw_layer_end.append([bw_norm_index[micro][-1]])
+ recompute_fw.append([fw_norm_index[micro][3]])
+ else:
+ fw_per_micro_opt_num = fw_norm_index[2] - fw_norm_index[0]
+ bw_per_micro_opt_num = bw_norm_index[2] - bw_norm_index[0]
+
+ for micro in range(self.micro_num):
+ fw_layer_start.append([fw_norm_index[3 * micro]])
+ bw_layer_end.append([bw_norm_index[3 * (micro + 1) - 1]])
+ return fw_layer_start, bw_layer_end, recompute_fw, fw_per_micro_opt_num, bw_per_micro_opt_num
+
+ def _calculate_warm_micro_num(self):
+ if self.vpp != 1:
+ return self.search_cfg.pp * (self.vpp - 1) + 1 + (self.search_cfg.pp - self.stage_id - 1) * 2
+ else:
+ return self.search_cfg.pp - self.stage_id
+
+ def _calculate_fw_idx(self, fw_idx, i, micro):
+ if i * (self.vpp * self.search_cfg.pp) <= micro < i * (
+ self.vpp * self.search_cfg.pp) + self.search_cfg.pp and self.vpp > 1:
+ fw_idx += NumberConstant.FW_NORM_OP_NUM_ENABLE_PP_OTHER_STAGE
+ else:
+ fw_idx += NumberConstant.FW_NORM_OP_NUM_ENABLE_PP_LAST_STAGE
+ return fw_idx
+
+ def _calculate_bw_idx(self, bw_idx, i, micro):
+ if i * (self.vpp * self.search_cfg.pp) <= micro < i * (
+ self.vpp * self.search_cfg.pp) + self.search_cfg.pp or self.vpp == 1:
+ bw_idx += NumberConstant.FW_NORM_OP_NUM_ENABLE_PP_LAST_STAGE
+ else:
+ bw_idx += NumberConstant.FW_NORM_OP_NUM_ENABLE_PP_OTHER_STAGE
+ return bw_idx
+
+
+class ProfilingLayerInfo:
+ def __init__(self):
+ self.time = []
+ self.start_memory = []
+ self.peak_memory = []
+ self.reserved_memory = []
+ self.operator_info = []
+ self.communication_info = []
+
+ def extend_attr(self, new_layer):
+ for attr_name in self.__dict__.keys():
+ obj_attr = getattr(self, attr_name)
+ if isinstance(obj_attr, list):
+ target_attr = getattr(new_layer, attr_name, [])
+ obj_attr.extend(target_attr)
+ setattr(self, attr_name, obj_attr)
+
+
+class ProfilingModelInfo:
+ def __init__(self):
+ self.embedding = ProfilingLayerInfo()
+ self.forward = ProfilingLayerInfo()
+ self.loss = ProfilingLayerInfo()
+ self.backward = ProfilingLayerInfo()
+ self.optimizer = ProfilingLayerInfo()
+ self.hccl_memory = []
+ self.cann_and_driver_memory = []
+ self.communication_matrix = []
+ self.context_parallel_comm = []
+ self.pipeline_parallel_comm = []
+ self.data_parallel_comm = []
+ self.tensor_parallel_comm = []
+ self.expert_parallel_comm = []
+ self.search_cfg = None
+ self.stage_id = 0
+ self.mc2_total_time = []
+ self.matmul_total_time = []
+
+ def extend_stage_info(self, new_model):
+ for attr_name in self.__dict__.keys():
+ obj_attr = getattr(self, attr_name)
+ if isinstance(obj_attr, list):
+ target_attr = getattr(new_model, attr_name, [])
+ obj_attr.extend(target_attr)
+ setattr(self, attr_name, obj_attr)
+ elif isinstance(obj_attr, ProfilingLayerInfo):
+ target_attr = getattr(new_model, attr_name, None)
+ obj_attr.extend_attr(target_attr)
+
+
+class BaseParallelCommunication:
+ """
+ Basic parallel communication information.
+ """
+
+ def __init__(self):
+ self.total_time_ms: float = 0.0
+ self.wait_time_ms: float = 0.0
+ self.overlap_time_ms: float = 0.0
+ self.details: List[dict] = []
+
+
+class ExpertParallelCommunication(BaseParallelCommunication):
+ """
+ Expert parallel communication
+ """
+
+ def __init__(self):
+ super(ExpertParallelCommunication, self).__init__()
+ self.min_comm_time_ms: float = 0.0
+
+
+class TensorParallelCommunication(BaseParallelCommunication):
+ """
+ Tensor parallel communication
+ """
+
+ def __init__(self):
+ super(TensorParallelCommunication, self).__init__()
+ self.fixed_time_ms: float = 0.0
+ self.fixed_wait_time_ms: float = 0.0
+
+
+class ContextParallelCommunication(BaseParallelCommunication):
+ """
+ Context parallel communication
+ """
+
+ def __init__(self):
+ super(ContextParallelCommunication, self).__init__()
+ self.vector_time_ms: float = 0.0
+
+
+class DataParallelCommunication(BaseParallelCommunication):
+ """
+ Data parallel communication
+ """
+
+ def __init__(self):
+ super(DataParallelCommunication, self).__init__()
+ self.mlp_zero_time_ms: float = 0.0
+ self.mlp_ag_time_ms: float = 0.0
+ self.mlp_rs_time_ms: float = 0.0
+ self.other_zero_time_ms: float = 0.0
+ self.other_ag_time_ms: float = 0.0
+ self.other_rs_time_ms: float = 0.0
+
+
+class PipelineParallelCommunication(BaseParallelCommunication):
+ """
+ Pipeline parallel communication
+ """
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_constant.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_constant.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd731b04278e093662dff365da1085b2c6d06839
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_constant.py
@@ -0,0 +1,52 @@
+class NumberConstant:
+ """
+ Constant for number
+ """
+ CONVERSION_TIME = 1000.0
+ FW_NORM_OP_NUM_DISABLE_PP = 3
+ BW_NORM_OP_NUM_DISABLE_PP = 3
+ FW_NORM_OP_NUM_ENABLE_PP_LAST_STAGE = 3
+ FW_NORM_OP_NUM_ENABLE_PP_OTHER_STAGE = 2
+
+ @property
+ def conversion_time(self: any) -> float:
+ """
+ time conversion us to ms
+ :return: time conversion
+ """
+ return self.CONVERSION_TIME
+
+
+class OperatorDetails:
+ def __init__(self, name, type_, input_shapes, output_shapes, duration_us, wait_time_us, accelerator_core):
+ self.name: str = name
+ self.type: str = type_
+ self.input_shapes: str = input_shapes
+ self.output_shapes: str = output_shapes
+ self.duration_us: float = duration_us
+ self.wait_time_us: float = wait_time_us
+ self.accelerator_core: str = accelerator_core
+
+
+class SpecialOperatorName:
+ EMBEDDING = 'embedding'
+ FW_RMS_NORM_TYPE = 'RmsNorm'
+ BW_RMS_NORM_TYPE = 'RmsNormGrad'
+ FW_LAYER_NORM_TYPE = 'LayerNormV3WithImplMode'
+ BW_LAYER_NORM_TYPE = 'LayerNormBetaGammaBackpropV2'
+ RMS_NORM = 'rms_norm'
+ LAYER_NORM = 'layer_norm'
+ BACKWARD = 'backward'
+
+
+class SpecialKeyName:
+ NAME = 'Name'
+ COMPONENT = 'Component'
+ TOTAL_RESERVED = 'Total Reserved(MB)'
+ ALLOCATED_MEMORY = 'Allocation Total Allocated(MB)'
+ ACCELERATOR_CORE = 'Accelerator Core'
+ DURATION_US = 'Duration(us)'
+ START_TIME_US = 'Start Time(us)'
+ ELAPSE_TIME_MS = 'Elapse Time(ms)'
+ WAIT_TIME_MS = 'Wait Time(ms)'
+ IDLE_TIME_MS = 'Idle Time(ms)'
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_memory_parse.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_memory_parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..154a5777d6197c3679da2c8e0c889ee495e955f9
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_memory_parse.py
@@ -0,0 +1,143 @@
+from typing import List
+
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_meta_parse import StructureAnalyseTool
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_constant import SpecialOperatorName
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_config import ProfilingConfig
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_constant import SpecialKeyName
+
+
+class AnalyseMemoryMsg(ProfilingConfig):
+ """ Analyse memory massage. """
+
+ def __init__(self, rank_file_path, search_cfg, memory_details, stage_id=0):
+ super(AnalyseMemoryMsg, self).__init__(search_cfg)
+ self._rank_file_path = rank_file_path
+ self._memory_details = memory_details
+ self._update_norm_op()
+ self.fw_memory_indices: List[List[int]]
+ self.bw_memory_indices: List[List[int]]
+ self.fw_memory_per_micro_opt_num: int
+ self.bw_memory_per_micro_opt_num: int
+ self.stage_id = stage_id
+
+ @staticmethod
+ def compare_memory(row, start_memory, peak_memory):
+ """compare memory"""
+ if start_memory == 0:
+ start_memory = float(row[SpecialKeyName.ALLOCATED_MEMORY])
+ peak_memory = max(peak_memory, float(row[SpecialKeyName.ALLOCATED_MEMORY]))
+ return start_memory, peak_memory
+
+ @staticmethod
+ def analyse_cann_and_driver(memory_record_details):
+ app_mem = 0
+ pta_mem = None
+ for row in memory_record_details:
+ if row[SpecialKeyName.COMPONENT] == 'APP':
+ app_mem = row[SpecialKeyName.TOTAL_RESERVED]
+ elif not pta_mem and row[SpecialKeyName.COMPONENT] == 'PTA':
+ pta_mem = row[SpecialKeyName.TOTAL_RESERVED]
+ if app_mem and pta_mem:
+ break
+ return [float(app_mem) - float(pta_mem)]
+
+ def update_norm_indices(self):
+ fw_memory_indices, bw_memory_indices = self._analyse_norm_op()
+ if self.search_cfg.pp > 1:
+ self.fw_memory_indices, \
+ self.bw_memory_indices, \
+ recompute_fw, \
+ self.fw_memory_per_micro_opt_num, \
+ self.bw_memory_per_micro_opt_num = \
+ self.search_first_operator_idx_for_per_layer_enable_pp(fw_memory_indices, bw_memory_indices)
+ else:
+ self.fw_memory_indices, \
+ self.bw_memory_indices, \
+ recompute_fw, \
+ self.fw_memory_per_micro_opt_num, \
+ self.bw_memory_per_micro_opt_num = \
+ self.search_first_operator_idx_for_per_layer_disable_pp(fw_memory_indices, bw_memory_indices)
+
+ def analyse_embedding(self):
+ em_start_memory, em_peak_memory = 0, 0
+ if self.stage_id != 0:
+ return [em_start_memory], [em_peak_memory]
+ embedding_start_idx = 0
+ for idx, msg in enumerate(self._memory_details[1:], start=1):
+ op_name = msg[SpecialKeyName.NAME]
+ if self.norm_op in op_name:
+ break
+ if SpecialOperatorName.EMBEDDING in op_name:
+ embedding_start_idx = idx
+ em_start_memory, em_peak_memory = self.compare_memory(self._memory_details[idx - 1],
+ em_start_memory, em_peak_memory)
+ if idx > embedding_start_idx != 0:
+ em_start_memory, em_peak_memory = self.compare_memory(msg, em_start_memory, em_peak_memory)
+
+ return [em_start_memory], [em_peak_memory]
+
+ def analyse_forward(self):
+ fw_start_memory = [0.0 for _ in range(self.micro_num)]
+ fw_peak_memory = [0.0 for _ in range(self.micro_num)]
+ for micro in range(self.micro_num):
+ self.fw_memory_indices[micro].append(
+ self.fw_memory_indices[micro][-1] + self.fw_memory_per_micro_opt_num - 1)
+ fw_start_memory[micro] = float(
+ self._memory_details[self.fw_memory_indices[micro][0]][SpecialKeyName.ALLOCATED_MEMORY])
+ for msg in self._memory_details[self.fw_memory_indices[micro][0]: self.fw_memory_indices[micro][-1]]:
+ fw_start_memory[micro], fw_peak_memory[micro] = \
+ self.compare_memory(msg, fw_start_memory[micro], fw_peak_memory[micro])
+
+ return fw_start_memory, fw_peak_memory
+
+ def analyse_loss(self):
+ ls_start_memory, ls_peak_memory = 0, 0
+ if self.stage_id != self.search_cfg.pp - 1:
+ return [ls_start_memory], [ls_peak_memory]
+ for idx, msg in enumerate(
+ self._memory_details[self.fw_memory_indices[0][-1] + 1: self.bw_memory_indices[0][0]]):
+ if 'norm' in self._memory_details[idx + 1 + self.fw_memory_indices[0][-1] + 1][SpecialKeyName.NAME]:
+ continue
+ ls_start_memory, ls_peak_memory = self.compare_memory(msg, ls_start_memory, ls_peak_memory)
+ return [ls_start_memory], [ls_peak_memory]
+
+ def analyse_backward(self):
+ bw_start_memory = [0.0 for _ in range(self.micro_num)]
+ bw_peak_memory = [0.0 for _ in range(self.micro_num)]
+ for micro in range(self.micro_num):
+ self.bw_memory_indices[micro].insert(0,
+ self.bw_memory_indices[micro][-1] - self.bw_memory_per_micro_opt_num)
+ bw_start_memory[micro] = float(
+ self._memory_details[self.bw_memory_indices[micro][0]][SpecialKeyName.ALLOCATED_MEMORY])
+ for msg in self._memory_details[self.bw_memory_indices[micro][0]: self.bw_memory_indices[micro][-1]]:
+ bw_start_memory[micro], bw_peak_memory[micro] = \
+ self.compare_memory(msg, bw_start_memory[micro], bw_peak_memory[micro])
+
+ return bw_start_memory, bw_peak_memory
+
+ def analyse_optimizer(self):
+ op_start_memory, op_peak_memory = 0, 0
+ for msg in self._memory_details[self.bw_memory_indices[-1][-1] + 1:]:
+ op_start_memory, op_peak_memory = self.compare_memory(msg, op_start_memory, op_peak_memory)
+ return [op_start_memory], [op_peak_memory]
+
+ def _analyse_norm_op(self):
+ fw_memory_indices, bw_memory_indices = [], []
+ for index, row in enumerate(self._memory_details[1:], start=1):
+ if self.norm_op in self._memory_details[index - 1][SpecialKeyName.NAME]:
+ continue
+ if self.norm_op in row[SpecialKeyName.NAME] \
+ and SpecialOperatorName.BACKWARD not in row[SpecialKeyName.NAME]:
+ fw_memory_indices.append(index)
+ elif self.norm_op in row[SpecialKeyName.NAME] \
+ and SpecialOperatorName.BACKWARD in row[SpecialKeyName.NAME]:
+ bw_memory_indices.append(index)
+
+ return fw_memory_indices, bw_memory_indices
+
+ def _update_norm_op(self):
+ structure_cls = StructureAnalyseTool(self._rank_file_path, self._memory_details)
+ if structure_cls.fw_norm_op == SpecialOperatorName.FW_LAYER_NORM_TYPE:
+ self.norm_op = SpecialOperatorName.LAYER_NORM
+ else:
+ self.norm_op = SpecialOperatorName.RMS_NORM
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_meta_parse.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_meta_parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f6b9470426492d4d15204688b6a94f9a4c57017
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_meta_parse.py
@@ -0,0 +1,95 @@
+import csv
+import json
+import os
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_constant import SpecialOperatorName
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_constant import NumberConstant
+from mindspeed.auto_tuning.utils.file_utils import check_file_size
+
+
+class FileAnalyseTool:
+ """
+ support csv and json parse
+ """
+
+ @classmethod
+ def analyse_csv_info(cls, file_path: str, csv_name: str):
+ csv_path = os.path.join(file_path, csv_name)
+ try:
+ with open(csv_path, newline='') as csvfile:
+ check_file_size(csvfile)
+ reader = csv.DictReader(csvfile)
+ csv_details = list(reader)
+
+ except FileNotFoundError as e:
+ raise f"'Please check file name, {e}"
+ except csv.Error as e:
+ raise f"An error occurred while reading the CSV file: {e}"
+ return csv_details
+
+ @classmethod
+ def analyse_json_info(cls, file_path: str, json_name: str):
+ json_path = os.path.join(file_path, json_name)
+ json_details = {"p2p": {}, "collective": {}}
+ try:
+ with open(json_path, 'r') as f:
+ check_file_size(f)
+ details = json.load(f)
+ details_value = list(details.values())[0]
+ for name, info in details_value.get('p2p', {}).items():
+ comm_name = name.split("@")[0]
+ json_details['p2p'][comm_name] = info["Communication Time Info"]
+ for name, info in details_value.get('collective', {}).items():
+ comm_name = name.split("@")[0]
+ json_details['collective'][comm_name] = info["Communication Time Info"]
+ except KeyError as e:
+ raise f"'Please check file name, {e}"
+ except Exception as e:
+ raise f"Read communication file error: {e}"
+
+ return json_details
+
+
+class StructureAnalyseTool:
+ """
+ support structure parse
+ """
+
+ def __init__(self, rank_file_path, memory_details):
+ self._rank_file_path = rank_file_path
+ self._memory_details = memory_details
+ self.fw_norm_op = SpecialOperatorName.FW_RMS_NORM_TYPE
+ self.bw_norm_op = SpecialOperatorName.BW_RMS_NORM_TYPE
+ self._search_special_norm_op()
+
+ def analyse_norm_op(self):
+ """ Analyse the norm op details in kernel_details.csv. """
+ fw_norm_op_idx_list = []
+ bw_norm_op_idx_list = []
+ matmul_total_time = 0
+ mc2_total_time = 0
+ for idx, row in enumerate(self._memory_details):
+ if "Name" not in row or "Type" not in row:
+ continue
+ if row["Type"] == "MatMulCommon":
+ time = float(row["Duration(us)"]) / NumberConstant.CONVERSION_TIME
+ matmul_total_time += time
+ mc2_total_time += time
+ if row["Type"] == "AllGatherMatmul" or row["Type"] == "MatmulReduceScatter":
+ mc2_total_time += float(row["Duration(us)"]) / NumberConstant.CONVERSION_TIME
+ if row["Type"] == self.fw_norm_op:
+ fw_norm_op_idx_list.append(idx)
+ elif row["Type"] == self.bw_norm_op:
+ bw_norm_op_idx_list.append(idx)
+ return fw_norm_op_idx_list, bw_norm_op_idx_list, matmul_total_time, mc2_total_time
+
+ def get_fw_norm_op(self):
+ return self.fw_norm_op
+
+ def _search_special_norm_op(self):
+ """ Special norm op: rms_norm, layer_norm, rms_norm_grad """
+ op_statistic_details = FileAnalyseTool.analyse_csv_info(self._rank_file_path, 'op_statistic.csv')
+ for op in op_statistic_details:
+ if SpecialOperatorName.FW_LAYER_NORM_TYPE in op['OP Type']:
+ self.fw_norm_op = SpecialOperatorName.FW_LAYER_NORM_TYPE
+ self.bw_norm_op = SpecialOperatorName.BW_LAYER_NORM_TYPE
+ break
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_node_parse.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_node_parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d09869aef1afe7effce1d341a705bd2913f8c2a
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_node_parse.py
@@ -0,0 +1,94 @@
+import os
+import stat
+import pickle
+import subprocess
+import time
+
+import torch
+from mindspeed.auto_tuning.utils.logger import get_logger
+from mindspeed.auto_tuning.utils.restricted_unpickler import restricted_loads
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_config import ProfilingModelInfo
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_parse import ProfilingParser
+
+
+class GatherNodeProfiling:
+ """
+ Gather other node profiling result to rank0
+ """
+
+ def __init__(self, profiling_file_path):
+ self.profiling_file_path = profiling_file_path
+ self.fusion_model = ProfilingModelInfo()
+ self.stage_id_list = []
+ self.logger = get_logger('profiling_parser')
+
+ @staticmethod
+ def _extend_stage_lists(source, target):
+ source.time.extend(target.time)
+ source.start_memory.extend(target.start_memory)
+ source.peak_memory.extend(target.peak_memory)
+ source.communication_info.extend(target.communication_info)
+ source.operator_info.extend(target.operator_info)
+
+ def fuse_node_pkl(self):
+ """
+ Args:
+ pkl_path: str
+
+ Returns:
+ fusion_model: ProfilingModelInfo
+ """
+ pkl_path = os.path.join(self.profiling_file_path, 'pkl_path')
+ pkl_files = sorted(os.listdir(pkl_path))
+ if len(pkl_files) > 1:
+ self.logger.info(f'Get pp profiling parse result.')
+ for pkl_file in pkl_files:
+ node_pkl_path = os.path.join(pkl_path, pkl_file)
+ with open(node_pkl_path, 'rb') as f:
+ pkl_model = restricted_loads(f)
+ self._fuse_models(pkl_model)
+ else:
+ node_pkl_path = os.path.join(pkl_path, pkl_files[0])
+ with open(node_pkl_path, 'rb') as f:
+ pkl_model = restricted_loads(f)
+ self.fusion_model = pkl_model
+ return self.fusion_model
+
+ def parse_node_pkl(self, args):
+ parent_dir = os.path.dirname(self.profiling_file_path)
+ ootb_node_path = os.path.join(parent_dir, f'ootb_{args.node_rank}.pkl')
+ with open(ootb_node_path, 'rb') as f:
+ cfg = restricted_loads(f)
+ profiling_parser = ProfilingParser(self.profiling_file_path, search_cfg=cfg, args=args)
+ profiling_res = profiling_parser.parser()
+ if args.pipeline_model_parallel_size > 1 and profiling_parser.nodes > 1:
+ ranks = [i * profiling_parser.devices_per_node for i in range(profiling_parser.nodes)]
+ profiling_group = torch.distributed.new_group(ranks)
+ gather_objects = [None for _ in range(profiling_parser.nodes)]
+ torch.distributed.all_gather_object(gather_objects, profiling_res, group=profiling_group)
+ for i in range(profiling_parser.nodes):
+ pkl_path = os.path.join(self.profiling_file_path, 'pkl_path')
+ if not os.path.exists(pkl_path):
+ os.mkdir(pkl_path)
+ pkl_node_path = os.path.join(pkl_path, f'node_{i}.pkl')
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ with os.fdopen(os.open(pkl_node_path, flags, mode=mode), 'wb') as f:
+ pickle.dump(gather_objects[i], f)
+
+ torch.distributed.barrier(group=profiling_group)
+ torch.distributed.destroy_process_group(group=profiling_group)
+ else:
+ pkl_path = os.path.join(self.profiling_file_path, 'pkl_path')
+ if not os.path.exists(pkl_path):
+ os.mkdir(pkl_path)
+ pkl_node_path = os.path.join(pkl_path, f'node_{args.node_rank}.pkl')
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ with os.fdopen(os.open(pkl_node_path, flags, mode=mode), 'wb') as f:
+ pickle.dump(profiling_res, f)
+
+ def _fuse_models(self, new_model):
+ if new_model.stage_id not in self.stage_id_list:
+ self.stage_id_list.append(new_model.stage_id)
+ self.fusion_model.extend_stage_info(new_model)
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_operator_parse.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_operator_parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..8726824568a715afe31ce2b407d28b9cd745255f
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_operator_parse.py
@@ -0,0 +1,39 @@
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_constant import OperatorDetails
+
+
+class AnalyseOperatorMsg:
+ """ Analyse operator message. """
+
+ def __init__(self, operator_details):
+ self._operator_details = operator_details
+
+ def analyse_embedding(self, start_idx, end_idx):
+ return self._analyse_operators(start_idx, end_idx)
+
+ def analyse_forward(self, start_idx, end_idx):
+ return self._analyse_operators(start_idx, end_idx)
+
+ def analyse_loss(self, start_idx, end_idx):
+ return self._analyse_operators(start_idx, end_idx)
+
+ def analyse_backward(self, start_idx, end_idx):
+ return self._analyse_operators(start_idx, end_idx)
+
+ def analyse_optimizer(self, start_idx, end_idx):
+ return self._analyse_operators(start_idx, end_idx)
+
+ def _analyse_operators(self, start_idx, end_idx):
+ details_list = []
+ for i in range(start_idx, end_idx):
+ detail = self._operator_details[i]
+ op_detail = OperatorDetails(
+ name=detail['Name'],
+ type_=detail['Type'],
+ input_shapes=detail['Input Shapes'],
+ output_shapes=detail['Output Shapes'],
+ duration_us=detail['Duration(us)'],
+ wait_time_us=detail['Wait Time(us)'],
+ accelerator_core=detail['Accelerator Core']
+ )
+ details_list.append(op_detail)
+ return details_list
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_parse.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1d7f57854a2c492ef9db702a5cd83fa556c0b12
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/profiling_parse/profiling_parse.py
@@ -0,0 +1,200 @@
+import math
+import os
+import re
+from mindspeed.auto_tuning.utils.logger import get_logger
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_config import ProfilingConfig, \
+ ProfilingModelInfo
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_meta_parse import StructureAnalyseTool
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_operator_parse import AnalyseOperatorMsg
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_communication_parse import \
+ AnalyseCommunicationMsg
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_memory_parse import AnalyseMemoryMsg
+from mindspeed.auto_tuning.module.parse.profiling_parse.profiling_meta_parse import FileAnalyseTool
+
+
+class ProfilingParser(ProfilingConfig):
+ def __init__(self, root_path, search_cfg=None, args=None):
+ super(ProfilingParser, self).__init__(search_cfg, args)
+ self._root_path = root_path
+ self._ascend_operator_details = None
+ self.stage_id = 0
+ self.rank_file_path = None
+ self.model = ProfilingModelInfo()
+ self.logger = get_logger('profiling_parser')
+
+ def parse_fw_bw_structure(self, fw_norm_op_idx_list, bw_norm_op_idx_list):
+ if self.search_cfg.pp > 1:
+ fw_layer_start_index, bw_layer_start_index, recompute_fw, fw_per_micro_opt_num, bw_per_micro_opt_num = \
+ self.search_first_operator_idx_for_per_layer_enable_pp(fw_norm_op_idx_list, bw_norm_op_idx_list)
+ else:
+ fw_layer_start_index, bw_layer_start_index, recompute_fw, fw_per_micro_opt_num, bw_per_micro_opt_num = \
+ self.search_first_operator_idx_for_per_layer_disable_pp(fw_norm_op_idx_list, bw_norm_op_idx_list)
+ for micro in range(self.micro_num):
+ if self.per_micro_layer != 1:
+ fw_per_micro_opt_num = fw_layer_start_index[micro][-1] - fw_layer_start_index[micro][-2]
+ bw_per_micro_opt_num = bw_layer_start_index[micro][-1] - bw_layer_start_index[micro][-2]
+ fw_layer_start_index[micro].append(fw_layer_start_index[micro][-1] + fw_per_micro_opt_num - 1)
+ bw_layer_start_index[micro].insert(0, bw_layer_start_index[micro][-1] - bw_per_micro_opt_num)
+ return fw_layer_start_index, bw_layer_start_index
+
+ def parse_model_structure(self):
+ self._update_profiling_file_path()
+ kernel_details = FileAnalyseTool.analyse_csv_info(self.rank_file_path, 'kernel_details.csv')
+ communication_details = FileAnalyseTool.analyse_json_info(self.rank_file_path, 'communication.json')
+ memory_details = FileAnalyseTool.analyse_csv_info(self.rank_file_path, 'operator_memory.csv')
+ memory_record_details = FileAnalyseTool.analyse_csv_info(self.rank_file_path, 'memory_record.csv')
+ structure_cls = StructureAnalyseTool(self.rank_file_path, kernel_details)
+ fw_norm_op_idx_list, bw_norm_op_idx_list, matmul_total_time, mc2_total_time = structure_cls.analyse_norm_op()
+ self.model.matmul_total_time = [matmul_total_time]
+ self.model.mc2_total_time = [mc2_total_time]
+ fw_layer_start_index, bw_layer_start_index = self.parse_fw_bw_structure(fw_norm_op_idx_list,
+ bw_norm_op_idx_list)
+
+ self._parse_operator_info(kernel_details, fw_layer_start_index, bw_layer_start_index)
+ self._parse_communication_info(communication_details, kernel_details)
+ self._parse_memory_info(memory_details, memory_record_details)
+
+ def parser(self):
+ """
+ Parse profiling files.
+ Returns:
+ model: ProfilingModelInfo
+ """
+ self.logger.info('>>>> Profiling parse starting!')
+ self._parse_each_node()
+ self.logger.info('>>>> Profiling parse success!')
+ return self.model
+
+ def _validate_file_path(self, filename, attr_name):
+ file_path = os.path.join(self.rank_file_path, filename)
+ if not os.path.exists(file_path):
+ raise FileNotFoundError(f"The file {file_path} was not found.")
+ setattr(self, attr_name, file_path)
+
+ def _update_profiling_file_path(self):
+ self._validate_file_path('kernel_details.csv', '_kernel_details_csv_path')
+ self._validate_file_path('memory_record.csv', '_memory_record_csv_path')
+ self._validate_file_path('operator_memory.csv', '_operator_memory_csv_path')
+ self._validate_file_path('npu_module_mem.csv', '_npu_module_mem_csv_path')
+ self._validate_file_path('communication.json', '_communication_json_path')
+ self._validate_file_path('op_statistic.csv', '_op_statistic_csv_path')
+
+ def _extract_rank_file_path(self):
+ """
+ Get all rank file path, the profiling process generates the profiler_info_{rank_id}.json file.
+ Returns:
+ rank_file_path: Dict[rank_id] = path
+ """
+
+ def extract_rankid_from_filename(filename):
+ match = re.search(r'profiler_info_(\d+)\.json', filename)
+ if match:
+ return int(match.group(1))
+ else:
+ return None
+
+ rank_file_path = {}
+ for ascend_dir in os.listdir(self._root_path):
+ profiling_path = os.path.join(self._root_path, ascend_dir)
+ if os.path.isdir(profiling_path) and 'ascend' in ascend_dir:
+ json_files = [f
+ for f in os.listdir(profiling_path)
+ if f.endswith('.json') and f.startswith('profiler_info_')]
+ if not json_files:
+ raise ValueError(f"Args profile error, JSON is not exist in {ascend_dir}.")
+
+ rank_id = extract_rankid_from_filename(json_files[0])
+ if rank_id is not None:
+ rank_file_path[rank_id] = profiling_path
+ return rank_file_path
+
+ def _join_rank_ascend_path(self, file_name):
+ rank_file_path = os.path.join(self._root_path, file_name, "ASCEND_PROFILER_OUTPUT")
+ if not os.path.exists(rank_file_path):
+ raise f" {rank_file_path} is not exist."
+ return rank_file_path
+
+ def _get_first_rank_and_stage_id_of_each_stage(self, node_first_rank_id, devices_each_stage, rank_file_path):
+ """
+ Get the rank file path based on the number of devices each stage. For example:
+ devices_each_node devices_each_stage node pp
+ 1. 8 16 2 1
+ 2. 8 8 2 2
+ 3. 8 4 2 4
+ """
+ if devices_each_stage == self.devices_per_node:
+ return self._join_rank_ascend_path(rank_file_path[node_first_rank_id]), self.node_rank
+ elif devices_each_stage < self.devices_per_node:
+ paths_and_ids = []
+ stage_num_each_node = math.ceil(len(rank_file_path) / devices_each_stage)
+ for i in range(stage_num_each_node):
+ cur_stage_rank = i * devices_each_stage + node_first_rank_id
+ cur_stage_id = i + self.node_rank * stage_num_each_node
+ paths_and_ids.append((self._join_rank_ascend_path(rank_file_path[cur_stage_rank]), cur_stage_id))
+ return paths_and_ids
+ else:
+ return self._join_rank_ascend_path(rank_file_path[node_first_rank_id]), self.node_rank // (
+ self.nodes // self.search_cfg.pp)
+
+ def _parse_first_rank_of_each_stage(self, rank_file_path: dict):
+ """Parses the first rank file of each stage."""
+ node_first_rank_id = self.node_rank * self.devices_per_node
+ devices_each_stage = self.nodes * self.devices_per_node // self.search_cfg.pp
+ paths_and_ids = self._get_first_rank_and_stage_id_of_each_stage(node_first_rank_id, devices_each_stage,
+ rank_file_path)
+ if isinstance(paths_and_ids, list):
+ for path, stage_id in paths_and_ids:
+ self.rank_file_path = path
+ self.stage_id = stage_id
+ self.model.stage_id = stage_id
+ self.parse_model_structure()
+ else:
+ self.rank_file_path, self.stage_id = paths_and_ids
+ self.model.stage_id = self.stage_id
+ self.parse_model_structure()
+
+ def _parse_each_node(self):
+ rank_file_path = self._extract_rank_file_path()
+ self._parse_first_rank_of_each_stage(rank_file_path)
+
+ def _parse_operator_info(self, kernel_details, fw_layer_start_index, bw_layer_start_index):
+ operator = AnalyseOperatorMsg(kernel_details)
+ embedding_operator = operator.analyse_embedding(0, fw_layer_start_index[0][0] - 1)
+ forward_operator = operator.analyse_forward(fw_layer_start_index[0][0], fw_layer_start_index[0][-1])
+ loss_operator = operator.analyse_loss(fw_layer_start_index[0][-1], bw_layer_start_index[0][0] - 1)
+ backward_operator = operator.analyse_backward(bw_layer_start_index[0][0], bw_layer_start_index[0][-1])
+ optimizer_operator = operator.analyse_optimizer(bw_layer_start_index[0][-1] + 1, len(kernel_details) - 1)
+ self.model.embedding.operator_info.append(embedding_operator)
+ self.model.forward.operator_info.append(forward_operator)
+ self.model.loss.operator_info.append(loss_operator)
+ self.model.backward.operator_info.append(backward_operator)
+ self.model.optimizer.operator_info.append(optimizer_operator)
+
+ def _parse_memory_info(self, memory_details, memory_record_details):
+ memory_cls = AnalyseMemoryMsg(self.rank_file_path, self.search_cfg, memory_details, stage_id=self.stage_id)
+ memory_cls.update_norm_indices()
+ embedding_start, embedding_peak = memory_cls.analyse_embedding()
+ self.model.embedding.start_memory.append(embedding_start)
+ self.model.embedding.peak_memory.append(embedding_peak)
+ fw_start, fw_peak = memory_cls.analyse_forward()
+ self.model.forward.start_memory.append(fw_start)
+ self.model.forward.peak_memory.append(fw_peak)
+ loss_start, loss_peak = memory_cls.analyse_loss()
+ self.model.loss.start_memory.append(loss_start)
+ self.model.loss.peak_memory.append(loss_peak)
+ bw_start, bw_peak = memory_cls.analyse_backward()
+ self.model.backward.start_memory.append(bw_start)
+ self.model.backward.peak_memory.append(bw_peak)
+ optimizer_start, optimizer_peak = memory_cls.analyse_optimizer()
+ self.model.optimizer.start_memory.append(optimizer_start)
+ self.model.optimizer.peak_memory.append(optimizer_peak)
+ self.model.cann_and_driver_memory = memory_cls.analyse_cann_and_driver(memory_record_details)
+
+ def _parse_communication_info(self, communication_details, kernel_details):
+ communication_cls = AnalyseCommunicationMsg(self.search_cfg, communication_details, kernel_details)
+ communication_cls.analyse_parallel_comm()
+ self.model.tensor_parallel_comm.append(communication_cls.get_tp_comm())
+ self.model.pipeline_parallel_comm.append(communication_cls.get_pp_comm())
+ self.model.data_parallel_comm.append(communication_cls.get_dp_comm())
+ self.model.context_parallel_comm.append(communication_cls.get_cp_comm())
+ self.model.expert_parallel_comm.append(communication_cls.get_ep_comm())
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/recompute_module_info.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/recompute_module_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c02337d64ce26d0a3d243bf6f03374069833cb8
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/recompute_module_info.py
@@ -0,0 +1,12 @@
+from typing import Dict
+
+
+class ModuleRecomputeInfo:
+ def __init__(self, context: Dict):
+ self.name = context.get("name")
+ self.prefix_name = context.get("prefix_name")
+ self.full_name = self.prefix_name + '.' + self.name
+ self.memory = context.get("memory")
+ self.input_size = context.get("input")
+ self.time = context.get("time")
+ self.recompute = False
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/recompute_parser.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/recompute_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e34c6dbae4c5c08b44f6324c96a3042eb750b64
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/parse/recompute_parser.py
@@ -0,0 +1,286 @@
+import os
+import stat
+
+from functools import wraps
+from collections.abc import Iterable
+from typing import Dict, List
+import pickle
+import acl
+import torch
+import torch.nn
+from megatron.training.global_vars import get_args
+
+from mindspeed.core.memory.adaptive_recomputing.swap_manager import get_tensor_mem_size
+
+
+class RecomputeParser:
+ recompute_parser = None
+
+ def __init__(self):
+ # layer profiling info
+ self.context = {
+ 'module': []
+ }
+ self.models = None
+ # record allowed recomputing module
+ self.allowed_recomputing_module = []
+ # profiling prefix
+ self.profiling_prefix = ""
+ # save modules hook, remove it after apply policy
+ self.modules_hooks = []
+ # current profiling step
+ self.profiling_step = 0
+ # step skip profiling, default is 3
+ self.skip_profiling_step = 3
+ # step for stop profiling, default is 6
+ self.stop_profiling_step = 6
+ # unit for device memory size(MB)
+ self.unit_mb = 1024 * 1024
+ # store all module event
+ '''
+ {
+ full_name1: [[, ][, ][, ][, ]]
+ full_name2: [[, ][, ][, ]]
+ full_name3: [[, ][, ]]
+ }
+ '''
+ self.event_dict: Dict[str, List] = {}
+
+ @staticmethod
+ def get_memory_status():
+ free, all_memory, _ = acl.rt.get_mem_info(1)
+ memory_info = {
+ "free": free,
+ "all_memory": all_memory,
+ "used_memory": torch.npu.memory_allocated(),
+ "reserved_memory": torch.npu.memory_reserved(),
+ "max_memory_allocated": torch.npu.max_memory_allocated()
+ }
+
+ return memory_info
+
+ def pre_hook_func(self, state, *args, **kargs):
+ if 'memory' not in state:
+ state['memory'] = 0
+ state['input'] = self.cal_input_output_size(args)
+ if self.profiling_step == self.stop_profiling_step:
+ state['memory'] = torch.npu.memory_allocated() - state['input'] * self.unit_mb
+ print(f"success print pre hook memory = {state['memory']}")
+ cur_module_full_name = state['prefix_name'] + '.' + state['name']
+ if cur_module_full_name not in self.event_dict.keys():
+ self.event_dict[cur_module_full_name] = []
+ if self.profiling_step < self.stop_profiling_step:
+ start_event = torch.npu.Event(enable_timing=True)
+ self.event_dict[cur_module_full_name].append([start_event])
+ start_event.record()
+
+ def post_hook_func(self, state, args, output):
+ if self.profiling_step < self.stop_profiling_step:
+ cur_module_full_name = state['prefix_name'] + '.' + state['name']
+ end_event = torch.npu.Event(enable_timing=True)
+ end_event.record()
+ # add end_event to corresponding position of list
+ for item in reversed(self.event_dict[cur_module_full_name]):
+ if len(item) == 1:
+ item.append(end_event)
+ break
+
+ if self.profiling_step == self.stop_profiling_step:
+ output_memory = self.cal_input_output_size(output)
+ state['memory'] = (torch.npu.memory_allocated() - state['memory']) // self.unit_mb
+ print(f"success print post hook memory = {state['memory']} and output_memory = {output_memory}")
+ state['input'] += output_memory
+
+ def forward_pre_hook(self, ctx):
+ def hook(module, *args, **kargs):
+ if 'module' in self.context:
+ self.context['module'].append(ctx)
+ self.pre_hook_func(ctx, *args, **kargs)
+
+ return hook
+
+ def forward_post_hook(self, ctx):
+ def hook(module, args, output):
+ self.post_hook_func(ctx, args, output)
+ if 'module' in self.context:
+ self.context['module'].pop()
+
+ return hook
+
+ def construct_context_recursive(self, prefix_name, model, ctx, have_allowed_recomputing):
+ # 1.construct context
+ next_have_allowed_recomputing = have_allowed_recomputing
+ for name, module in model.named_children():
+ if 'layers' not in ctx:
+ ctx['layers'] = []
+
+ current_ctx = {'name': name, 'prefix_name': prefix_name}
+ if 'layers' in ctx:
+ ctx['layers'].append(current_ctx)
+
+ next_name = prefix_name + "." + name if prefix_name != "" else name
+
+ # 2.tag allowed_recomputing module
+ if have_allowed_recomputing:
+ for allowed_recomputing_module in self.allowed_recomputing_module:
+ if isinstance(module, allowed_recomputing_module):
+ current_ctx['allowed_recomputing'] = True
+ if isinstance(model, torch.nn.ModuleList):
+ ctx['is_module_list'] = True
+ ctx['is_recomputing_layer'] = True
+ else:
+ current_ctx['is_recomputing_layer'] = True
+ next_have_allowed_recomputing = False
+ self.construct_context_recursive(next_name, module, current_ctx, next_have_allowed_recomputing)
+
+ def register_recursive_hook(self, model, ctx, profiling_prefix, layer_index=0):
+ index = layer_index or 0
+ for module in model.children():
+ if 'layers' not in ctx:
+ continue
+ current_ctx = ctx['layers'][index]
+ prefix_name = current_ctx['prefix_name']
+ name = current_ctx['name']
+
+ is_recomputing_layer = not isinstance(module, torch.nn.ModuleList) and 'is_recomputing_layer' in current_ctx
+ is_allowed_recomputing = 'allowed_recomputing' in current_ctx and index == 0
+ if is_recomputing_layer or is_allowed_recomputing:
+ profiling_prefix = prefix_name + "." + name
+ pre_hook = module.register_forward_pre_hook(self.forward_pre_hook(current_ctx))
+ post_hook = module.register_forward_hook(self.forward_post_hook(current_ctx))
+ self.modules_hooks.append(pre_hook)
+ self.modules_hooks.append(post_hook)
+ elif profiling_prefix and prefix_name.startswith(profiling_prefix):
+ pre_hook = module.register_forward_pre_hook(self.forward_pre_hook(current_ctx))
+ post_hook = module.register_forward_hook(self.forward_post_hook(current_ctx))
+ self.modules_hooks.append(pre_hook)
+ self.modules_hooks.append(post_hook)
+ self.register_recursive_hook(module, current_ctx, profiling_prefix)
+ index += 1
+
+ def reset_modules(self):
+ if torch.distributed.get_rank() % 8 == 0:
+ ootb_context_path = get_args().profile_save_path
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ ootb_context_path_json = f'{ootb_context_path}.json'
+ with os.fdopen(os.open(ootb_context_path_json, flags, mode=mode), "wb") as file:
+ file.write(pickle.dumps(self.context))
+
+ def hook_step_func(self, step_func, models):
+ def custom_step_func(*args, **kargs):
+ result = step_func(*args, **kargs)
+ if self.profiling_step >= self.stop_profiling_step + 1:
+ return result
+ memory_info = self.get_memory_status()
+ try:
+ self.context['used_mem'] = memory_info["used_memory"] // self.unit_mb
+ self.context['max_device_memory'] = memory_info["all_memory"] // self.unit_mb
+ except KeyError:
+ print("[ERROR] Some of these keys don't exist.")
+ self.profiling_step += 1
+ torch.npu.synchronize()
+ # record module time
+ cal_module_forward_time(self.context, self.event_dict)
+
+ # reset modules
+ if self.profiling_step == self.stop_profiling_step + 1:
+ self.reset_modules()
+ return result
+ return custom_step_func
+
+ def add_allowed_recomputing_module(self, module):
+ if module not in self.allowed_recomputing_module:
+ self.allowed_recomputing_module.append(module)
+ print(f"after append self.allowed_recomputing_module = {self.allowed_recomputing_module} and module = {module}")
+
+ def cal_input_output_size(self, args):
+ size = 0
+ if isinstance(args, torch.Tensor):
+ size += get_tensor_mem_size(args)
+ return size // self.unit_mb
+ for arg in args:
+ if isinstance(arg, torch.Tensor):
+ size += get_tensor_mem_size(arg)
+ elif isinstance(arg, Iterable):
+ for t in arg:
+ if isinstance(t, torch.Tensor):
+ size += get_tensor_mem_size(t)
+ elif t is None:
+ pass
+ else:
+ print(f"warning: unknown input/output type {str(type(t))}")
+ elif arg is None:
+ pass
+ else:
+ print(f"warning: unknown input/output type {str(type(arg))}")
+ return size // self.unit_mb
+
+
+def get_recompute_parser():
+ if RecomputeParser.recompute_parser is None:
+ RecomputeParser.recompute_parser = RecomputeParser()
+ return RecomputeParser.recompute_parser
+
+
+def setup_model_and_optimizer_decorator(setup_model_and_optimizer):
+ @wraps(setup_model_and_optimizer)
+ def wrapper(*args, **kargs):
+ models, optimizer, opt_param_scheduler = setup_model_and_optimizer(*args, **kargs)
+ if os.getenv('OOTB_OPTIMIZER_PROFILING', 'FALSE') != 'TRUE':
+ print("OOTB_OPTIMIZER_PROFILING wrapper Error!")
+ return models, optimizer, opt_param_scheduler
+ print("OOTB_OPTIMIZER_PROFILING wrapper success!")
+ recompute_parser = get_recompute_parser()
+ recompute_parser.models = models
+ optimizer.step = recompute_parser.hook_step_func(optimizer.step, models)
+
+ if isinstance(models, list):
+ for model in models:
+ recompute_parser.construct_context_recursive("module", model, recompute_parser.context, True)
+ else:
+ recompute_parser.construct_context_recursive("module", models, recompute_parser.context, True)
+ print("OOTB_OPTIMIZER-MODEL-PARSER: successfully hooking module")
+ return models, optimizer, opt_param_scheduler
+
+ return wrapper
+
+
+def call_hook_func():
+ print("success enter call_hook_func")
+ recompute_parser = get_recompute_parser()
+ models = recompute_parser.models
+ if isinstance(models, list):
+ for index, model in enumerate(models):
+ recompute_parser.register_recursive_hook(model, recompute_parser.context,
+ recompute_parser.profiling_prefix, index)
+ else:
+ recompute_parser.register_recursive_hook(models, recompute_parser.context,
+ recompute_parser.profiling_prefix)
+
+
+def allowed_recompute_parser_module_wrapper(allowed_recomputing_module):
+ recomputing = get_recompute_parser()
+ recomputing.add_allowed_recomputing_module(allowed_recomputing_module)
+
+
+def cal_module_forward_time(context, event_dict: Dict[str, List]):
+ cur_module_full_name = context.get('prefix_name', "") + '.' + context.get('name', "")
+ if "memory" in context and cur_module_full_name in event_dict.keys():
+ cur_module_event_list = event_dict.get(cur_module_full_name, [])
+ for cur_level_event_list in cur_module_event_list:
+ start_event = cur_level_event_list[0]
+ end_event = cur_level_event_list[1]
+ total_time = start_event.elapsed_time(end_event)
+
+ context['forward_cnt'] = context.get('forward_cnt', 0) + 1
+ context['pre_total_time'] = context.get('pre_total_time', 0) + total_time
+ try:
+ context['time'] = context['pre_total_time'] / context['forward_cnt']
+ except ZeroDivisionError:
+ context['time'] = 0
+
+ if "layers" not in context:
+ return
+ for sub_layer_context in context["layers"]:
+ cal_module_forward_time(sub_layer_context, event_dict)
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/search/__init__.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/search/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/search/recompute_solver.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/search/recompute_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..c202d20adea4eb58b0fb750042480da08ae203ac
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/search/recompute_solver.py
@@ -0,0 +1,304 @@
+from copy import deepcopy
+from typing import List, Dict
+from mindspeed.auto_tuning.config.search_config import SearchConfig
+from mindspeed.auto_tuning.module.parse.recompute_module_info import ModuleRecomputeInfo
+
+
+class RecomputeSolver:
+
+ def __init__(self, first_layer_context, perf, static_memory, memory_limit, search_cfg: SearchConfig, model_config):
+ self.num_layers_per_pp = model_config.num_layers // search_cfg.pipeline_model_parallel_size
+ self.layer_num_per_chunk = 0
+ self.virtual_pipeline_model_parallel_size = 1 if not search_cfg.num_layers_per_virtual_pipeline_stage \
+ else (search_cfg.num_layers // search_cfg.num_layers_per_virtual_pipeline_stage //
+ search_cfg.pipeline_model_parallel_size)
+ self.search_config = search_cfg
+ self.model_config = model_config
+ self.module_layers: List[ModuleRecomputeInfo] = []
+ self.parent_layers: List[ModuleRecomputeInfo] = []
+ self.parent_children_dict: Dict[str, List[ModuleRecomputeInfo]] = {}
+
+ self.first_layer_context = first_layer_context
+ self.first_layer_recompute_info = ModuleRecomputeInfo(self.first_layer_context)
+ self.full_recompute_performance = perf
+ self.static_memory = static_memory
+ self.memory_limit = memory_limit
+
+ self.recompute_module: Dict[str, ModuleRecomputeInfo] = {}
+
+ self.layers_combination: List[LayerCombination] = []
+ self.layer_full_recompute_combination: LayerCombination = None
+ self.layer_without_recompute_combination: LayerCombination = None
+ self.layer_recompute_one_combination: LayerCombination = None
+
+ self.node_split_flag = ','
+
+ self.num_warmup_micro_batches_per_chunk = []
+ self.num_micro_batches = 0
+
+ if search_cfg.num_layers_per_virtual_pipeline_stage:
+ self.num_model_chunks = (search_cfg.num_layers // search_cfg.num_layers_per_virtual_pipeline_stage //
+ search_cfg.pipeline_model_parallel_size)
+ else:
+ self.num_model_chunks = 1
+
+ def get_num_warmup_micro_batches(self):
+ pipeline_parallel_size = self.search_config.pipeline_model_parallel_size
+ data_parallel_size = self.search_config.data_parallel_size
+ self.num_micro_batches = self.model_config.global_batch_size // self.model_config.micro_batch_size // data_parallel_size
+ if pipeline_parallel_size <= 1:
+ self.num_warmup_micro_batches_per_chunk.append(1)
+ return
+ pipeline_parallel_rank = 0
+ total_num_micro_batches = self.num_micro_batches * self.num_model_chunks
+ if self.num_model_chunks == 1:
+ num_warmup_micro_batches = pipeline_parallel_size - pipeline_parallel_rank - 1
+ num_warmup_micro_batches += 1
+ self.num_warmup_micro_batches_per_chunk.append(num_warmup_micro_batches)
+ else:
+ num_warmup_micro_batches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
+ num_warmup_micro_batches += (self.num_model_chunks - 1) * pipeline_parallel_size
+ num_warmup_micro_batches += 1
+ num_warmup_micro_batches = min(num_warmup_micro_batches, total_num_micro_batches)
+ remain_batch_num = (num_warmup_micro_batches - pipeline_parallel_size * self.num_model_chunks)
+ for i in range(self.num_model_chunks):
+ if i == 0:
+ self.num_warmup_micro_batches_per_chunk.append(pipeline_parallel_size + max(0, remain_batch_num))
+ elif i == self.num_model_chunks - 1:
+ self.num_warmup_micro_batches_per_chunk.append(pipeline_parallel_size + min(0, remain_batch_num))
+ else:
+ self.num_warmup_micro_batches_per_chunk.append(pipeline_parallel_size)
+
+ def build_solver_info(self):
+ self.prune_no_recompute_layer()
+ self.layers_combination_init(0)
+ self.get_num_warmup_micro_batches()
+ return self.knapsack_best()
+
+ def get_recompute_op(self):
+ recompute_nodes = []
+ parent_node_list = []
+ for module_recompute_info in self.module_layers:
+ if not module_recompute_info.recompute:
+ continue
+ name = module_recompute_info.full_name
+ recompute_nodes.append(name)
+ separate_node_name_list = name.split(".")
+ for i in range(1, len(separate_node_name_list)):
+ parent_node_name = ".".join(separate_node_name_list[:-i])
+ if parent_node_name not in parent_node_list:
+ parent_node_list.append(parent_node_name)
+
+ for n in parent_node_list:
+ if n in recompute_nodes:
+ recompute_nodes.clear()
+ return recompute_nodes
+ return self.remove_full_selective_node(recompute_nodes)
+
+ def prune_no_recompute_layer(self):
+ module_layers = []
+ parent_layers = [self.first_layer_recompute_info]
+ children_module_list = []
+ self.recursive_prune_modules(self.first_layer_context, module_layers, parent_layers, children_module_list)
+ cur_layer_name = self.first_layer_recompute_info.full_name
+ self.parent_children_dict.update({cur_layer_name: children_module_list})
+ self.parent_layers = parent_layers
+ self.module_layers = module_layers
+
+ def recursive_prune_modules(self, parent_module, module_layers: List, parent_layers: List,
+ children_module_list: List):
+ if "layers" not in parent_module:
+ return
+ parent_modules = parent_module['layers']
+ parent_module_recompute_info = ModuleRecomputeInfo(parent_module)
+ if len(parent_modules) == 0:
+ return
+ parent_module_memory_time_rate = get_module_memory_time_rate(parent_module_recompute_info)
+ cur_sub_module_list = []
+ for sub_layer in parent_modules:
+ sub_layer_recompute_info = ModuleRecomputeInfo(sub_layer)
+ cur_layer_name = sub_layer_recompute_info.full_name
+ cur_sub_module_list.append(sub_layer_recompute_info)
+ children_layer_name = []
+ self.recursive_prune_modules(sub_layer, module_layers, parent_layers, children_layer_name)
+ if children_layer_name:
+ self.parent_children_dict.update({cur_layer_name: children_layer_name})
+ parent_layers.append(sub_layer_recompute_info)
+ sub_layer_memory_time_rate = get_module_memory_time_rate(sub_layer_recompute_info)
+ if sub_layer_memory_time_rate < parent_module_memory_time_rate:
+ continue
+ if not sub_layer_recompute_info.memory or len(children_layer_name) == 1 and children_layer_name[0].memory == sub_layer.get("memory"):
+ continue
+ module_layers.append(sub_layer_recompute_info)
+ self.recompute_module.update({cur_layer_name: sub_layer_recompute_info})
+
+ children_module_list.extend(cur_sub_module_list)
+
+ def remove_full_selective_node(self, recompute_nodes):
+ if len(recompute_nodes) == 0:
+ return recompute_nodes
+ try:
+ for parent_module in self.parent_layers:
+ parent_module_name = parent_module.full_name
+ if parent_module_name not in self.parent_children_dict.keys():
+ continue
+ sub_layers_recompute_count = 0
+ for sub_layer in self.parent_children_dict[parent_module_name]:
+ if sub_layer.full_name in recompute_nodes:
+ sub_layers_recompute_count += 1
+ if sub_layers_recompute_count == len(self.parent_children_dict[parent_module_name]):
+ recompute_nodes.clear()
+ break
+ except KeyError:
+ print("[ERROR] Some of these keys don't exist.")
+ return recompute_nodes
+
+ def layers_combination_init(self, idx):
+ if idx == 0:
+ self.layer_full_recompute_combination = LayerCombination({
+ "name": "full_recompute",
+ "memory": self.first_layer_recompute_info.input_size,
+ "cost": self.first_layer_recompute_info.time,
+ "policy_name": "n_full"
+ })
+ self.layers_combination.append(self.layer_full_recompute_combination)
+ self.layer_without_recompute_combination = LayerCombination({
+ "name": "without_recompute",
+ "memory": self.first_layer_recompute_info.memory,
+ "cost": 0,
+ "policy_name": "n_without"
+ })
+ self.layers_combination.append(self.layer_without_recompute_combination)
+ try:
+ if idx >= len(self.module_layers):
+ recompute_nodes = self.get_recompute_op()
+ if len(recompute_nodes) == 0:
+ return
+
+ stash_mem_per_layer = (self.first_layer_recompute_info.memory -
+ self.first_layer_recompute_info.input_size)
+ recompute_cost = 0
+ for recompute_module in recompute_nodes:
+ stash_mem_per_layer -= (self.recompute_module.get(recompute_module).memory -
+ self.recompute_module.get(recompute_module).input_size)
+ recompute_cost += self.recompute_module.get(recompute_module).time
+ self.layer_recompute_one_combination = LayerCombination({
+ "name": self.node_split_flag.join(recompute_nodes),
+ "memory": stash_mem_per_layer,
+ "cost": recompute_cost,
+ "policy_name": "n_selective"
+ })
+ self.layers_combination.append(self.layer_recompute_one_combination)
+ return
+ except KeyError:
+ print("[ERROR] The key \"module_layers\" doesn't exist.")
+ if self.module_layers[idx].memory > self.module_layers[idx].input_size:
+ self.module_layers[idx].recompute = True
+ self.layers_combination_init(idx + 1)
+ self.module_layers[idx].recompute = False
+ self.layers_combination_init(idx + 1)
+
+ def get_max_goods_value(self, idx, ans):
+ i, j, k = idx[0], idx[1], idx[2]
+ pre_step_ans = ans[i - 1][j - k]
+ if k == 0:
+ return deepcopy(pre_step_ans)
+
+ goods_value = ans[i][j]
+ memory = pre_step_ans.memory
+ pre_layer_num = j - k
+ for index in range(k):
+ cur_layer_index = pre_layer_num + index
+ cur_layer_chunk_rank = cur_layer_index // self.layer_num_per_chunk
+ memory += self.num_warmup_micro_batches_per_chunk[cur_layer_chunk_rank] * self.layers_combination[i].memory
+ cost = pre_step_ans.cost + k * self.layers_combination[i].cost * self.num_micro_batches
+ if pre_step_ans.cost == float('inf'):
+ cost = k * self.layers_combination[i].cost * self.num_micro_batches
+
+ device_memory = self.memory_limit
+
+ if device_memory >= memory and cost <= goods_value.cost and (len(pre_step_ans.layer_names) + k) == j:
+ goods_value.memory = memory
+ goods_value.cost = cost
+ goods_value.layer_names.clear()
+ if len(pre_step_ans.layer_names) > 0:
+ goods_value.layer_names.extend(pre_step_ans.layer_names)
+ goods_value.layer_names.extend(self.layers_combination[i].name for _ in range(k))
+
+ return goods_value
+
+ def knapsack_best(self):
+ combination_num = len(self.layers_combination)
+ base_memory = (self.static_memory - self.num_layers_per_pp / self.num_model_chunks * sum(self.num_warmup_micro_batches_per_chunk) *
+ self.first_layer_recompute_info.input_size)
+ base_cost = (self.full_recompute_performance - self.num_layers_per_pp * self.num_micro_batches *
+ self.first_layer_recompute_info.time)
+ ans = [[GoodsValue(base_memory, base_cost) for _ in range(self.num_layers_per_pp + 1)] for _ in range(combination_num)]
+ self.layer_num_per_chunk = self.num_layers_per_pp // self.num_model_chunks
+ for i in range(1, self.num_layers_per_pp + 1):
+ ans[0][i].cost += self.first_layer_recompute_info.time * self.num_micro_batches * i
+ for j in range(i):
+ cur_layer_chunk_rank = j // self.layer_num_per_chunk
+ ans[0][i].memory += (self.first_layer_recompute_info.input_size *
+ self.num_warmup_micro_batches_per_chunk[cur_layer_chunk_rank])
+ ans[0][i].layer_names.extend([self.layer_full_recompute_combination.name for _ in range(i)])
+
+ for i in range(1, combination_num):
+ for j in range(1, self.num_layers_per_pp + 1):
+ k = 0
+ while k <= j:
+ ans[i][j] = self.get_max_goods_value([i, j, k], ans)
+ k += 1
+
+ best_goods_value = ans[combination_num - 1][self.num_layers_per_pp]
+ print(f"after solve, current memory is {best_goods_value.memory} and current perf = {best_goods_value.cost} "
+ f"and cur_recompute_combination is {best_goods_value.layer_names}")
+ need_recompute = False
+ for combination_name in best_goods_value.layer_names:
+ if combination_name != self.layer_without_recompute_combination.name:
+ need_recompute = True
+ break
+ return need_recompute, best_goods_value.memory, best_goods_value.cost
+
+
+class LayerCombination:
+ def __init__(self, config):
+ self.name = config["name"]
+ self.memory = config["memory"]
+ self.cost = config["cost"]
+ self.policy_name = config["policy_name"]
+
+
+class GoodsValue:
+ def __init__(self, memory, cost=float('inf')):
+ self.layer_names = []
+ self.memory = memory
+ self.cost = cost
+
+
+def get_module_memory_time_rate(module_recompute_info: ModuleRecomputeInfo):
+ if module_recompute_info.memory:
+ return (module_recompute_info.memory - module_recompute_info.input_size) / module_recompute_info.time
+ return -1
+
+
+def module_name_split_num(module):
+ module_name_split_list = module['prefix_name'].split(".")
+ return len(module_name_split_list)
+
+
+def get_parent_module_list(sub_layer_full_name):
+ parent_module_list = []
+ separate_node_name_list = sub_layer_full_name.split(".")
+ for i in range(1, len(separate_node_name_list)):
+ parent_node_name = ".".join(separate_node_name_list[:-i])
+ parent_module_list.append(parent_node_name)
+
+ return parent_module_list
+
+
+def is_cur_module_tag_recompute(recompute_module_name_list, cur_module_name):
+ if cur_module_name in recompute_module_name_list:
+ return True
+ else:
+ return False
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/search/search_engine.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/search/search_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..b964006313a5059fd77e86c2dbcdd5988a980c85
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/search/search_engine.py
@@ -0,0 +1,272 @@
+from typing import Deque, List, Optional, Tuple
+from collections import deque
+from copy import deepcopy
+import pickle
+import os
+import sys
+import traceback as tb
+
+
+
+from mindspeed.auto_tuning.utils.logger import get_logger
+from mindspeed.auto_tuning.module.memory.memory_modeling import MemoryModeling
+from mindspeed.auto_tuning.config.search_config import SearchConfig
+from mindspeed.auto_tuning.module.search.stage_1_prune import stage_1_discrete_search_space_prune
+from mindspeed.auto_tuning.config.model_config import ModelConfig
+from mindspeed.auto_tuning.module.hardware import Hardware
+from mindspeed.auto_tuning.utils.utils import get_prof_dir
+from mindspeed.auto_tuning.utils.restricted_unpickler import restricted_loads
+from mindspeed.auto_tuning.config.generate_profiling_configs import generate_profiling_configs
+
+
+_logger = get_logger("search")
+
+
+def search_demo(model_config: ModelConfig,
+ perf_obj_function,
+ working_dir: str,
+ re_profiling_flag=True,
+ recomp_cfg_list=None) -> [List[Optional[SearchConfig]], tuple]:
+ device_mem_cap = Hardware().memory_limit
+ _logger.info(f"Search: total_device_num: {Hardware().num_devices}")
+ _logger.info(f"Search: device_mem_cap: {device_mem_cap}")
+ best_perf_cfg_map: Deque[Tuple[float, Optional[SearchConfig]]] = deque([(float("inf"), None)] * 3, 3)
+
+ stage_1_valid_ptd_configs = stage_1_discrete_search_space_prune(
+ model_config,
+ pod_limit=8
+ )
+
+ _logger.info(f"Stage [1] pruned result: number of valid PTD configurations [{len(stage_1_valid_ptd_configs)}]")
+ for cfg in stage_1_valid_ptd_configs:
+ _logger.info(f"Stage [1] pruned config: TP=[{cfg.tp}] PP=[{cfg.pp}] LAYERS_PER_VPP=[{cfg.layers_per_vpp}] DP=[{cfg.dp}] CP=[{cfg.cp}] EP=[{cfg.ep}] ZeRO=[{cfg.zero1}]")
+
+ base_context = ""
+ base_search_cfg = None
+ for cfg in generate_profiling_configs(model_config):
+ json_path = os.path.join(working_dir, f'{get_prof_dir(cfg)}.json')
+ # find ep = 1 config
+ if (not os.path.exists(json_path) or cfg.expert_model_parallel_size and
+ cfg.expert_model_parallel_size != 1):
+ continue
+ try:
+ with open(json_path, "rb") as file:
+ base_context = restricted_loads(file)
+ base_search_cfg = cfg
+ except pickle.UnpicklingError as e:
+ _logger.warning(f"Incorrect pickle format. UnpicklingError: {e}")
+ raise e
+ if base_context:
+ break
+
+ _logger.debug(f"success print base_context = {base_context}")
+ uncovered_prof = []
+ profile_count = [0]
+
+ for cfg in stage_1_valid_ptd_configs:
+ _logger.info("====================")
+ _logger.info(f"Looking at:\n\n{cfg}")
+ mem_estimated, _ = MemoryModeling.estimate(cfg)
+ if mem_estimated <= device_mem_cap:
+ try:
+ perf, uncovered_prof, use_mc2 = perf_obj_function(cfg, working_dir, profile_count, re_profiling_flag)
+ except Exception as err:
+ _logger.warning(f"Search: ERROR during perf_modeling_calculation: {type(err).__name__}")
+ tb.print_exc()
+
+ context = ""
+ json_path = os.path.join(working_dir, f'{get_prof_dir(cfg)}.json')
+ if not os.path.exists(json_path):
+ _logger.debug("success modeling context…………")
+ context = get_context_by_ptd_config(base_context, base_search_cfg, cfg, model_config)
+ else:
+ try:
+ with open(json_path, "rb") as file:
+ context = restricted_loads(file)
+ except pickle.UnpicklingError as e:
+ _logger.warning(f"Incorrect pickle format. UnpicklingError: {e}")
+ raise e
+ _logger.debug(f"before recompute, perf = {perf} and memory = {mem_estimated}")
+ _logger.debug(f"success enter recompute_solver and tp = {cfg.tensor_model_parallel_size} "
+ f"pp = {cfg.pipeline_model_parallel_size} "
+ f"layers_per_vpp={cfg.num_layers_per_virtual_pipeline_stage} "
+ f"dp = {cfg.data_parallel_size} cp = {cfg.context_parallel_size} "
+ f"ep = {cfg.expert_model_parallel_size} zero = {cfg.use_distributed_optimizer}")
+ need_recompute, new_perf, add_mem, recompute_layer = full_recompute_solver(device_mem_cap - mem_estimated, context,
+ model_config, perf, cfg)
+ new_memory = add_mem + mem_estimated
+ _logger.debug(f"after recompute, perf = {new_perf} and need_recompute = {need_recompute}")
+ _logger.debug(f"cur mem_estimated = {new_memory}, recompute_layer = {recompute_layer}")
+
+ better_found = False
+ for i, perf_cfg in enumerate(best_perf_cfg_map):
+ if new_perf < perf_cfg[0]:
+ better_found = True
+ cfg.adaptive_recompute_device_swap = need_recompute
+ cfg.performance = new_perf
+ cfg.memory = new_memory
+ cfg.recompute_num_layers = recompute_layer
+ cfg.use_ascend_mc2 = use_mc2 if cfg.tensor_model_parallel_size > 1 else False
+ _logger.info(f"Search: SUCCESSFUL Better #{i} Config Found.")
+ _logger.debug(f"Performance Estimation: {new_perf}.")
+ best_perf_cfg_map.pop()
+ best_perf_cfg_map.insert(i, (new_perf, deepcopy(cfg)))
+ break
+ if not better_found:
+ _logger.info(f"Sub-optimal performance, next!")
+
+ else:
+ _logger.info(f"OOM found, next!")
+
+ return [cfg for _, cfg in best_perf_cfg_map], uncovered_prof
+
+
+def get_context_by_ptd_config(base_context, base_search_cfg, search_cfg, model_config):
+ cur_cfg_seq_multi_mbs_div_tp_cp = (search_cfg.seq_length / search_cfg.tensor_model_parallel_size /
+ search_cfg.context_parallel_size) * search_cfg.micro_batch_size
+ base_cfg_seq_multi_mbs_div_tp_cp = (base_search_cfg.seq_length / base_search_cfg.tensor_model_parallel_size /
+ base_search_cfg.context_parallel_size) * base_search_cfg.micro_batch_size
+ cur_cfg_resize_time = cur_cfg_seq_multi_mbs_div_tp_cp / base_cfg_seq_multi_mbs_div_tp_cp
+ context = deepcopy(base_context)
+
+ cur_experts_num = 0 if model_config.num_experts is None \
+ else model_config.num_experts // search_cfg.expert_model_parallel_size
+ recursive_change_context(context, cur_cfg_resize_time, cur_experts_num)
+
+ return context
+
+
+def recursive_change_context(context, cur_cfg_resize_time, cur_experts_num):
+ if "memory" in context:
+ context['memory'] *= cur_cfg_resize_time
+ if 'input' in context:
+ context['input'] *= cur_cfg_resize_time
+ if 'time' in context:
+ context['time'] *= cur_cfg_resize_time
+
+ check_prefix_name = 'prefix_name' in context and 'mlp' in context.get('prefix_name')
+ check_layer = 'layers' in context and context['layers'][0]['name'] == '0'
+ if check_prefix_name and check_layer:
+ context['layers'] = context['layers'][:cur_experts_num]
+ if "layers" not in context:
+ return
+ for layer_context in context["layers"]:
+ recursive_change_context(layer_context, cur_cfg_resize_time, cur_experts_num)
+
+
+class ToyModel(object):
+ def __init__(self):
+ return
+
+
+def perf_test_obj_function(search_config):
+ return
+
+
+def mem_test_toy_function(search_config):
+ return
+
+
+def get_first_layer_context(context):
+ if "memory" in context:
+ return context
+
+ if "layers" not in context:
+ return None
+ for layer_context in context["layers"]:
+ first_layer_context = get_first_layer_context(layer_context)
+ if first_layer_context is not None:
+ return first_layer_context
+ return None
+
+
+def memory_time_rate(ele):
+ if ele["memory"] - ele["input"] == 0:
+ return sys.maxsize
+ return ele["time"] / (ele["memory"] - ele["input"])
+
+
+def full_recompute_solver(oom_cap, model_context, model_cfg, perf, search_config):
+ if search_config.layers_per_vpp:
+ num_model_chunks = search_config.num_layers // search_config.layers_per_vpp // search_config.pp
+ layers_per_vpp = search_config.layers_per_vpp
+ else:
+ num_model_chunks = 1
+ layers_per_vpp = model_cfg.num_layers // search_config.pp
+ warmup_micro_batchs, total_num_micro_batches = get_num_warmup_micro_batches(num_model_chunks, search_config,
+ model_cfg)
+ ret_list = []
+ find_recompute_layer(model_context, ret_list)
+ layer_module = ret_list[0]
+
+ release_mem = 0
+ time_cost = 0
+ num_layers = model_cfg.num_layers // search_config.pp
+ ret_list.sort(key=memory_time_rate, reverse=True)
+ need_recompute = True
+ memory_per_layer = layer_module["memory"] - layer_module["input"]
+ # 1.No full recompute
+ max_release_mem = warmup_micro_batchs * layers_per_vpp * memory_per_layer - memory_per_layer
+
+ if max_release_mem <= oom_cap:
+ return False, perf - total_num_micro_batches * num_layers * layer_module["time"], max_release_mem, 0
+
+ if search_config.layers_per_vpp:
+ # 2.Situation under per pp stage and per mbs recompute layers <= layers_per_vpp
+ max_release_mem = (num_model_chunks - 1) * search_config.pp * layers_per_vpp * memory_per_layer
+ if max_release_mem <= oom_cap:
+ layer_calculate = (oom_cap - max_release_mem) // ((2 * search_config.pp - 1) * memory_per_layer)
+ release_mem += (2 * search_config.pp - 1) * layer_calculate * memory_per_layer + max_release_mem - memory_per_layer
+ time_cost += (num_layers - layers_per_vpp + layer_calculate) * total_num_micro_batches * layer_module["time"]
+ return True, perf - time_cost, release_mem, layers_per_vpp - layer_calculate
+
+ # Only consider layers temporarily
+ layer_calculate = (oom_cap // (memory_per_layer * search_config.pp))
+ release_mem += layer_calculate * memory_per_layer * search_config.pp
+ if layer_calculate < num_layers:
+ release_mem -= memory_per_layer
+ time_cost += total_num_micro_batches * layer_calculate * layer_module["time"]
+ return need_recompute, perf - time_cost, release_mem, num_layers - layer_calculate
+
+ else:
+ layer_calculate = (oom_cap // (memory_per_layer * search_config.pp))
+ release_mem += layer_calculate * memory_per_layer * search_config.pp
+ if layer_calculate < num_layers:
+ release_mem -= memory_per_layer
+ time_cost += total_num_micro_batches * layer_calculate * layer_module["time"]
+ return need_recompute, perf - time_cost, release_mem, num_layers - layer_calculate
+
+
+def get_num_warmup_micro_batches(num_model_chunks, search_config, model_cfg):
+ pipeline_parallel_size = search_config.pp
+ data_parallel_size = search_config.dp
+ num_microbatches = model_cfg.gbs // (search_config.mbs * data_parallel_size)
+
+ if pipeline_parallel_size <= 1:
+ return 1, num_microbatches
+
+ pipeline_parallel_size = pipeline_parallel_size
+ pipeline_parallel_rank = 0
+ total_num_micro_batches = num_microbatches * num_model_chunks
+ if num_model_chunks == 1:
+ num_warmup_micro_batches = pipeline_parallel_size - pipeline_parallel_rank - 1
+
+ else:
+ num_warmup_micro_batches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
+ num_warmup_micro_batches += (num_model_chunks - 1) * pipeline_parallel_size
+ num_warmup_micro_batches += 1
+ num_warmup_micro_batches = min(num_warmup_micro_batches, total_num_micro_batches)
+ return num_warmup_micro_batches, num_microbatches
+
+
+def find_recompute_layer(context, ret_list):
+ if "memory" in context and context["input"] <= context["memory"]:
+ layer_dict = {"memory": context["memory"], "time": context["time"],
+ "input": context["input"], "prefix_name": context["prefix_name"], "name": context["name"]}
+ ret_list.append(layer_dict)
+
+ # layer module the first to be appened
+ if "layers" not in context:
+ return
+ for layer_context in context["layers"]:
+ find_recompute_layer(layer_context, ret_list)
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/module/search/stage_1_prune.py b/model/train/yoco_moe/mindspeed/auto_tuning/module/search/stage_1_prune.py
new file mode 100644
index 0000000000000000000000000000000000000000..16236ea16fc3d97c59b4688de61895300d4e1cf9
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/module/search/stage_1_prune.py
@@ -0,0 +1,127 @@
+from typing import List
+from dataclasses import replace
+
+from mindspeed.auto_tuning.module.hardware import Hardware
+from mindspeed.auto_tuning.config.model_config import ModelConfig
+from mindspeed.auto_tuning.config.search_config import SearchConfig
+
+
+def stage_1_discrete_search_space_prune(
+ mcfg: ModelConfig,
+ pod_limit=0,
+ model_in_pod=False,
+ device_fluctuation_down_ratio=0
+) -> List[SearchConfig]:
+ """
+ Stage 1 prune is without any modeling.
+ This function prunes the search space for a distributed training job based on given constraints.
+
+ Parameters:
+ layer_number (int): The total number of layers.
+ total_device_number (int): The total number of devices.
+ micro_batch_number (int): The number of micro-batches.
+ expert_number (int): The number of experts.
+ pod_limit (int, optional): The maximum number of devices in a super pod. Default is 0.
+ model_in_pod (bool, optional): If True, the product of tp and pp should be less than or equal to pod_limit. Default is False.
+ device_fluctuation_ratio (float, optional): The ratio of device fluctuation. Must be between 0 and 1. Default is 0.
+
+ Returns:
+ list of dict: A list of valid configurations (tp, cp, pp, dp, ep, zero which stored as a dict) that satisfy all constraints.
+ """
+
+ num_devices = mcfg.global_world_size
+ device_type = Hardware().device_type
+
+ valid_configs: List[SearchConfig] = list()
+
+ # Iterate over all possible combinations of tp, cp, pp, dp, ep and zero
+ # Prune tp based on device_type, tp = 1 or 8 only if running on 910B
+ tp_search_list = [2 ** i for i in range(num_devices + 1)]
+ if "910B" in device_type:
+ tp_search_list = [1, 8]
+ for tp in tp_search_list:
+
+ # Check if tp is less than or equal to pod_limit
+ if 0 < pod_limit < tp:
+ continue
+
+ for cp in range(1, num_devices // tp + 1):
+
+ # Check cp long sequence based on device_type
+ if cp > 1:
+ if ("910B" in device_type) and \
+ ((mcfg.seq_length // cp) < 8 * 1024):
+ continue
+ if ("910_9" in device_type) and \
+ ((mcfg.seq_length // cp) < 4 * 1024):
+ continue
+
+ for pp in range(1, num_devices // (tp * cp) + 1):
+
+ # Check if tp * pp is less than or equal to pod_limit
+ if model_in_pod and tp * pp > pod_limit:
+ continue
+ # Check if layer_number is divisible by pp
+ if mcfg.num_layers % pp != 0:
+ continue
+
+ for dp in range(1, num_devices // (tp * cp * pp) + 1):
+
+ # Check device number compatibility
+ if device_fluctuation_down_ratio > 0:
+ if not ((1 - device_fluctuation_down_ratio) * num_devices < tp * cp * pp * dp <= num_devices):
+ continue
+ else:
+ if tp * cp * pp * dp != num_devices:
+ continue
+ # Check if micro_batch_number is divisible by dp
+ if mcfg.num_micro_batches % dp != 0:
+ continue
+ # Check if micro_batch_number / (pp * dp) is greater than 1
+ if mcfg.num_micro_batches // (pp * dp) <= 1:
+ continue
+
+ num_experts = mcfg.num_experts if mcfg.num_experts else 1
+ for ep in range(1, min(cp * dp, num_experts) + 1):
+
+ # Check if (ep | cp * dp) and (ep | expert_number)
+ if ((cp * dp) % ep != 0) or (num_experts % ep != 0):
+ continue
+
+ layers_per_vpp_search_domain = [None]
+ # Search vpp only if pp is enabled
+ if pp > 1:
+ # Search domain drops the last possible value (layer_number // pp)
+ # due to the constraint $layers_per_vpp * pp != layer_number$
+ layers_per_vpp_search_domain += \
+ [x for x in range(1, mcfg.num_layers // pp)]
+ for layers_per_vpp in layers_per_vpp_search_domain:
+
+ # Check if $layers_per_vpp$ not None and $layers_per_vpp * pp | layer_number$
+ if layers_per_vpp and \
+ mcfg.num_layers % (layers_per_vpp * pp) != 0:
+ continue
+
+ for mbs in [1, 2]:
+ cfg_zero0 = SearchConfig()
+ cfg_zero0.copy_from_config(mcfg)
+ cfg_zero0.tensor_model_parallel_size = tp
+ cfg_zero0.context_parallel_size = cp
+ cfg_zero0.pipeline_model_parallel_size = pp
+ cfg_zero0.num_layers_per_virtual_pipeline_stage = \
+ layers_per_vpp
+ cfg_zero0.use_distributed_optimizer = False
+ cfg_zero0.micro_batch_size = mbs
+ if mcfg.is_moe():
+ cfg_zero0.expert_model_parallel_size = ep
+ cfg_zero0.normalize()
+
+ valid_configs.append(cfg_zero0)
+
+ # When (dp * cp > 1), zero can be 1; add this config to the list
+ if dp * cp > 1:
+ cfg_zero1 = replace(cfg_zero0,
+ use_distributed_optimizer=True)
+ valid_configs.append(cfg_zero1)
+
+ return valid_configs
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/utils/__init__.py b/model/train/yoco_moe/mindspeed/auto_tuning/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/utils/dtype.py b/model/train/yoco_moe/mindspeed/auto_tuning/utils/dtype.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fd3885978cea527f92fb7cda08eefd628d68661
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/utils/dtype.py
@@ -0,0 +1,7 @@
+from enum import Enum
+
+
+class DTYPE(Enum):
+ fp16 = ("fp16", 2)
+ fp32 = ("fp32", 4)
+ bf16 = ("bf16", 2)
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/utils/file_utils.py b/model/train/yoco_moe/mindspeed/auto_tuning/utils/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..798fa12deb585802606535eb8d590d98127ca6d1
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/utils/file_utils.py
@@ -0,0 +1,9 @@
+import os
+
+
+def check_file_size(file):
+ max_file_size = 5 * 1024 * 1024 * 1024
+ if os.fstat(file.fileno()).st_size <= max_file_size:
+ return
+ else:
+ raise IOError("file too large to read")
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/utils/logger.py b/model/train/yoco_moe/mindspeed/auto_tuning/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2685083d020dab671bb2bbee139bfb6ffc2cd7a
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/utils/logger.py
@@ -0,0 +1,41 @@
+from typing import Optional, Set
+import logging
+import os
+from sys import stdout
+
+_LOGGERS: Set[str] = set()
+_LOG_FMT = "[%(levelname)s] %(name)s: %(message)s"
+_LOG_LEVEL = logging.INFO
+_LOGGER_NAME_PREFIX = "auto-tuning"
+
+
+def init_logger(level: str = "info"):
+ global _LOG_LEVEL
+ if level == "warning":
+ _LOG_LEVEL = logging.WARNING
+ elif level == "debug":
+ _LOG_LEVEL = logging.DEBUG
+ else:
+ _LOG_LEVEL = logging.INFO
+
+ for name in _LOGGERS:
+ logger_name = f"{_LOGGER_NAME_PREFIX}.{name}"
+ logger = logging.getLogger(name=logger_name)
+ logger.setLevel(_LOG_LEVEL)
+ for handler in logger.handlers:
+ handler.setFormatter(logging.Formatter(fmt=_LOG_FMT))
+
+
+def get_logger(name: str):
+ global _LOGGERS
+ logger_name = f"{_LOGGER_NAME_PREFIX}.{name}"
+ if name not in _LOGGERS:
+ logger = logging.getLogger(name=logger_name)
+ logger.propagate = False
+ logger.setLevel(_LOG_LEVEL)
+ logger.addHandler(logging.StreamHandler(stream=stdout))
+ for handler in logger.handlers:
+ handler.setFormatter(logging.Formatter(fmt=_LOG_FMT))
+ _LOGGERS.add(name)
+ return logger
+ return logging.getLogger(logger_name)
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/utils/mem_utils.py b/model/train/yoco_moe/mindspeed/auto_tuning/utils/mem_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5047b7c7ee6d3c36f477b4ceea2983fa5c07069
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/utils/mem_utils.py
@@ -0,0 +1,29 @@
+from typing import Callable, List
+
+
+def mem_b_to_kb(n: float, d: int = 2) -> float:
+ return round(n / pow(1024, 1), d)
+
+
+def mem_kb_to_b(n: float, d: int = 2) -> float:
+ return round(n * pow(1024, 1), d)
+
+
+def mem_b_to_mb(n: float, d: int = 2) -> float:
+ return round(n / pow(1024, 2), d)
+
+
+def mem_mb_to_b(n: float, d: int = 2) -> float:
+ return round(n * pow(1024, 2), d)
+
+
+def mem_b_to_gb(n: float, d: int = 2) -> float:
+ return round(n / pow(1024, 3), d)
+
+
+def mem_gb_to_b(n: float, d: int = 2) -> float:
+ return round(n * pow(1024, 3), d)
+
+
+def mem_convert_list(ns: List[float], func: Callable[[float, int], float], d: int = 2) -> List[float]:
+ return [func(n, d) for n in ns]
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/utils/restricted_unpickler.py b/model/train/yoco_moe/mindspeed/auto_tuning/utils/restricted_unpickler.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a5be908e1d80cdd228171c20957a1888c9e613c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/utils/restricted_unpickler.py
@@ -0,0 +1,18 @@
+import pickle
+import importlib
+from mindspeed.auto_tuning.utils.file_utils import check_file_size
+
+
+class _RestrictedUnpickler(pickle.Unpickler):
+ def find_class(self, module: str, name: str):
+ if module.startswith("mindspeed.auto_tuning"):
+ module_ = importlib.import_module(module)
+ clazz = getattr(module_, name)
+ if isinstance(clazz, type):
+ return clazz
+ raise pickle.UnpicklingError("global '%s.%s' is forbidden" % (module, name))
+
+
+def restricted_loads(s):
+ check_file_size(s)
+ return _RestrictedUnpickler(s).load()
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/utils/runner/__init__.py b/model/train/yoco_moe/mindspeed/auto_tuning/utils/runner/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/utils/runner/irunner.py b/model/train/yoco_moe/mindspeed/auto_tuning/utils/runner/irunner.py
new file mode 100644
index 0000000000000000000000000000000000000000..f88d46f6cf747d2e4877fae355ef91f879d7ec75
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/utils/runner/irunner.py
@@ -0,0 +1,22 @@
+from typing import Dict
+import abc
+
+_Env = Dict[str, str]
+
+
+class IRunner(metaclass=abc.ABCMeta):
+
+ @classmethod
+ def __subclasshook__(cls, subclass: type) -> bool:
+ if cls is IRunner:
+ return hasattr(subclass, "get_base_env") and \
+ hasattr(subclass, "run")
+ return NotImplemented
+
+ @abc.abstractmethod
+ def get_base_env(self) -> _Env:
+ return NotImplemented
+
+ @abc.abstractmethod
+ def run(self, env: _Env) -> int:
+ return NotImplemented
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/utils/runner/model_executor.py b/model/train/yoco_moe/mindspeed/auto_tuning/utils/runner/model_executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..88bd15af7ea0e82794682cd848116293059eddf1
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/utils/runner/model_executor.py
@@ -0,0 +1,267 @@
+from typing import Dict, Optional, Tuple
+import os
+import stat
+from enum import Enum
+
+import pickle
+
+from mindspeed.auto_tuning.config.search_config import SearchConfig
+from mindspeed.auto_tuning.utils.runner.irunner import _Env, IRunner
+
+
+_Argv = Dict[str, Optional[str]]
+
+
+class ExecutorFlag(Enum):
+ RUN = 0
+ PARSE_ARGS = 1
+ PARSE_MODEL = 2
+ PROFILE = 3
+
+
+class ModelExecutor:
+ """
+ Execute the model with different configs.
+ """
+ MODIFIED_ARGV_FILENAME = "auto_tuning_modified_argv.json"
+ PARSE_ARGS_ENV = "OOTB_OPTIMIZER_PARSE_ARGS"
+ PARSE_MODEL_ENV = "OOTB_OPTIMIZER_PARSE_MODEL"
+ PROFILING_ENV = "OOTB_OPTIMIZER_PROFILING"
+ MODIFIED_ARGV_PATH_ENV = "OOTB_OPTIMIZER_MODIFIED_ARGV_PATH"
+ ENABLED_ENV_MARKER = "TRUE"
+
+ def __init__(self,
+ runner: IRunner,
+ num_layers_config="--num-layers",
+ num_experts_config="--num-experts",
+ seq_length_config="--seq-length",
+ max_position_embeddings_config="--max-position-embeddings",
+ micro_batch_size_config="--micro-batch-size",
+ global_batch_size_config="--global-batch-size",
+ recompute_granularity_config="--recompute-granularity",
+ recompute_method_config="--recompute-method",
+ recompute_num_layers_config="--recompute-num-layers",
+ adaptive_recompute_device_swap_config="--adaptive-recompute-device-swap",
+ enable_token_rearrange_opt_config="--enable-token-rearrange-opt",
+ tensor_model_parallel_size_config="--tensor-model-parallel-size",
+ pipeline_model_parallel_size_config="--pipeline-model-parallel-size",
+ num_layers_per_virtual_pipeline_stage_config="--num-layers-per-virtual-pipeline-stage",
+ expert_model_parallel_size_config="--expert-model-parallel-size",
+ context_parallel_size_config="--context-parallel-size",
+ use_distributed_optimizer_config="--use-distributed-optimizer",
+ use_ascend_mc2_config="--use-ascend-mc2",
+ train_iters_config="--train-iters",
+ profile_config="--profile",
+ profile_step_start_config="--profile-step-start",
+ profile_step_end_config="--profile-step-end",
+ profile_ranks_config="--profile-ranks",
+ profile_level_config="--profile-level",
+ profile_with_cpu_config="--profile-with-cpu",
+ profile_with_stack_config="--profile-with-stack",
+ profile_with_memory_config="--profile-with-memory",
+ profile_record_shapes_config="--profile-record-shapes",
+ profile_save_path_config="--profile-save-path"
+ ) -> None:
+ self.runner = runner
+ self.num_layers_config = num_layers_config
+ self.num_experts_config = num_experts_config
+ self.seq_length_config = seq_length_config
+ self.max_position_embeddings_config = max_position_embeddings_config
+ self.micro_batch_size_config = micro_batch_size_config
+ self.global_batch_size_config = global_batch_size_config
+ self.recompute_granularity_config = recompute_granularity_config
+ self.recompute_method_config = recompute_method_config
+ self.recompute_num_layers_config = recompute_num_layers_config
+ self.adaptive_recompute_device_swap_config = adaptive_recompute_device_swap_config
+ self.enable_token_rearrange_opt_config = enable_token_rearrange_opt_config
+ self.tensor_model_parallel_size_config = tensor_model_parallel_size_config
+ self.pipeline_model_parallel_size_config = pipeline_model_parallel_size_config
+ self.num_layers_per_virutal_pipeline_stage_config = num_layers_per_virtual_pipeline_stage_config
+ self.expert_model_parallel_size_config = expert_model_parallel_size_config
+ self.context_parallel_size_config = context_parallel_size_config
+ self.use_distributed_optimizer_config = use_distributed_optimizer_config
+ self.use_ascend_mc2_config = use_ascend_mc2_config
+ self.train_iters_config = train_iters_config
+ self.profile_config = profile_config
+ self.profile_step_start_config = profile_step_start_config
+ self.profile_step_end_config = profile_step_end_config
+ self.profile_ranks_config = profile_ranks_config
+ self.profile_level_config = profile_level_config
+ self.profile_with_cpu_config = profile_with_cpu_config
+ self.profile_with_stack_config = profile_with_stack_config
+ self.profile_with_memory_config = profile_with_memory_config
+ self.profile_record_shapes_config = profile_record_shapes_config
+ self.profile_save_path_config = profile_save_path_config
+
+ def execute(self,
+ working_dir: str,
+ output_filename: str = str(),
+ cfg: Optional[SearchConfig] = None,
+ flag: ExecutorFlag = ExecutorFlag.RUN
+ ) -> int:
+ env = self.runner.get_base_env()
+ self._prepare_envvars(env, flag)
+
+ modified_argv_path = os.path.join(working_dir, self.MODIFIED_ARGV_FILENAME)
+
+ self._prepare_modified_argv_envvars(env, modified_argv_path)
+
+ modified_argv = self._prepare_modified_argv(cfg, working_dir, output_filename, flag)
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ with os.fdopen(os.open(modified_argv_path, flags, mode=mode), 'wb') as f:
+ pickle.dump(modified_argv, f)
+
+ returncode = self.runner.run(env)
+
+ return returncode
+
+ def _prepare_envvars(self,
+ env: _Env,
+ flag: ExecutorFlag
+ ) -> _Env:
+ env.pop(self.PARSE_ARGS_ENV, None)
+ env.pop(self.PARSE_MODEL_ENV, None)
+ env.pop(self.PROFILING_ENV, None)
+
+ if flag == ExecutorFlag.PARSE_ARGS:
+ env.update({self.PARSE_ARGS_ENV: self.ENABLED_ENV_MARKER})
+ elif flag == ExecutorFlag.PARSE_MODEL:
+ env.update({self.PARSE_MODEL_ENV: self.ENABLED_ENV_MARKER})
+ elif flag == ExecutorFlag.PROFILE:
+ env.update({self.PROFILING_ENV: self.ENABLED_ENV_MARKER})
+
+ return env
+
+ def _prepare_modified_argv_envvars(self,
+ env: _Env,
+ modified_argv_path: str
+ ) -> _Env:
+ env.update({self.MODIFIED_ARGV_PATH_ENV: modified_argv_path})
+
+ return env
+
+ def _prepare_modified_argv(
+ self,
+ cfg: Optional[SearchConfig],
+ working_dir: str,
+ output_filename: str,
+ flag: ExecutorFlag
+ ) -> Tuple[_Argv, _Argv]:
+ enabled_argv: _Argv = dict()
+ disabled_argv: _Argv = dict()
+ if cfg:
+ cfg.normalize()
+
+ def _modify_model_argv():
+ if self.recompute_granularity_config and self.recompute_method_config and self.recompute_num_layers_config:
+ if cfg.is_full_recompute():
+ enabled_argv.update({self.recompute_granularity_config: cfg.recompute_granularity})
+ enabled_argv.update({self.recompute_method_config: cfg.recompute_method})
+ enabled_argv.update({self.recompute_num_layers_config: str(cfg.recompute_num_layers)})
+ else:
+ disabled_argv.update({self.recompute_granularity_config: str()})
+ disabled_argv.update({self.recompute_method_config: str()})
+ disabled_argv.update({self.recompute_num_layers_config: str()})
+
+ if self.num_layers_config:
+ enabled_argv.update({self.num_layers_config: str(cfg.num_layers)})
+
+ if self.num_experts_config:
+ if cfg.num_experts:
+ enabled_argv.update({self.num_experts_config: str(cfg.num_experts)})
+ else:
+ disabled_argv.update({self.num_experts_config: str()})
+
+ if self.seq_length_config:
+ enabled_argv.update({self.seq_length_config: str(cfg.seq_length)})
+ enabled_argv.update({self.max_position_embeddings_config: str(cfg.seq_length)})
+
+ if self.micro_batch_size_config:
+ enabled_argv.update({self.micro_batch_size_config: str(cfg.micro_batch_size)})
+
+ if self.global_batch_size_config:
+ enabled_argv.update({self.global_batch_size_config: str(cfg.global_batch_size)})
+
+ if self.adaptive_recompute_device_swap_config:
+ if cfg.adaptive_recompute_device_swap:
+ enabled_argv.update({self.adaptive_recompute_device_swap_config: None})
+ else:
+ disabled_argv.update({self.adaptive_recompute_device_swap_config: None})
+
+ if self.enable_token_rearrange_opt_config:
+ if cfg.enable_token_rearrange_opt:
+ enabled_argv.update({self.enable_token_rearrange_opt_config: None})
+ else:
+ disabled_argv.update({self.enable_token_rearrange_opt_config: None})
+
+ if self.use_ascend_mc2_config:
+ if cfg.use_ascend_mc2:
+ enabled_argv.update({self.use_ascend_mc2_config: None})
+ else:
+ disabled_argv.update({self.use_ascend_mc2_config: None})
+
+ def _modify_parallel_argv():
+ if self.tensor_model_parallel_size_config:
+ enabled_argv.update({self.tensor_model_parallel_size_config: str(cfg.tensor_model_parallel_size)})
+
+ if self.pipeline_model_parallel_size_config:
+ enabled_argv.update({self.pipeline_model_parallel_size_config: str(cfg.pipeline_model_parallel_size)})
+
+ if self.num_layers_per_virutal_pipeline_stage_config:
+ if cfg.num_layers_per_virtual_pipeline_stage:
+ enabled_argv.update({self.num_layers_per_virutal_pipeline_stage_config:
+ str(cfg.num_layers_per_virtual_pipeline_stage)})
+ else:
+ disabled_argv.update({self.num_layers_per_virutal_pipeline_stage_config: str()})
+
+ if self.expert_model_parallel_size_config:
+ if cfg.expert_model_parallel_size:
+ enabled_argv.update({self.expert_model_parallel_size_config: str(cfg.expert_model_parallel_size)})
+ else:
+ disabled_argv.update({self.expert_model_parallel_size_config: str()})
+
+ if self.context_parallel_size_config:
+ enabled_argv.update({self.context_parallel_size_config: str(cfg.context_parallel_size)})
+
+ if self.use_distributed_optimizer_config:
+ if cfg.use_distributed_optimizer:
+ enabled_argv.update({self.use_distributed_optimizer_config: None})
+ else:
+ disabled_argv.update({self.use_distributed_optimizer_config: None})
+
+ def _modify_profile_argv():
+ if cfg.profile:
+ enabled_argv.update({self.train_iters_config: str(cfg.train_iters)})
+ enabled_argv.update({self.profile_config: None})
+ enabled_argv.update({self.profile_step_start_config: str(cfg.profile_step_start)})
+ enabled_argv.update({self.profile_step_end_config: str(cfg.profile_step_end)})
+ enabled_argv.update({self.profile_ranks_config: str(cfg.profile_ranks)})
+ enabled_argv.update({self.profile_level_config: cfg.profile_level})
+ if cfg.profile_with_cpu:
+ enabled_argv.update({self.profile_with_cpu_config: None})
+ else:
+ disabled_argv.update({self.profile_with_cpu_config: None})
+ if cfg.profile_with_stack:
+ enabled_argv.update({self.profile_with_stack_config: None})
+ else:
+ disabled_argv.update({self.profile_with_stack_config: None})
+ if cfg.profile_with_memory:
+ enabled_argv.update({self.profile_with_memory_config: None})
+ else:
+ enabled_argv.update({self.profile_with_memory_config: None})
+ if cfg.profile_record_shapes:
+ enabled_argv.update({self.profile_record_shapes_config: None})
+ else:
+ disabled_argv.update({self.profile_record_shapes_config: None})
+
+ _modify_model_argv()
+ _modify_parallel_argv()
+ _modify_profile_argv()
+
+ if flag == ExecutorFlag.PARSE_ARGS:
+ enabled_argv.update({self.profile_save_path_config: working_dir})
+ elif flag == ExecutorFlag.PARSE_MODEL or flag == ExecutorFlag.PROFILE:
+ enabled_argv.update({self.profile_save_path_config: os.path.join(working_dir, output_filename)})
+
+ return enabled_argv, disabled_argv
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/utils/runner/torchrun_runner.py b/model/train/yoco_moe/mindspeed/auto_tuning/utils/runner/torchrun_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac9db7d2a46a4d9249753b2c6529c58ce5f72e78
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/utils/runner/torchrun_runner.py
@@ -0,0 +1,71 @@
+import os
+import sys
+import subprocess
+
+from megatron.training import get_args
+
+from mindspeed.auto_tuning.utils.logger import get_logger
+from mindspeed.auto_tuning.utils.runner.irunner import _Env, IRunner
+
+_AUTO_TUNING_ARGS = "--auto-tuning"
+_logger = get_logger("runner")
+
+
+class TorchRunRunner(IRunner):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def get_base_env(self) -> _Env:
+ return os.environ.copy()
+
+ def run(self, env: _Env) -> int:
+
+ args = get_args()
+ argv: list = sys.argv[1:]
+ auto_tuning_filter_args_switch = ["--use-ascend-mc2", "--swap-attention",
+ "--ampipe-tp-sp-comm-overlap",
+ "--use-pipe-experts", "--pipe-experts-multi-stream",
+ "--recompute-in-advance", "--recompute-in-bubble", "--use-nanopipe"]
+ auto_tuning_filter_args_config = ["--ampipe-degree", "--pipe-experts-multi-data"]
+
+ if _AUTO_TUNING_ARGS in sys.argv:
+ argv[argv.index("--tensor-model-parallel-size") + 1] = '8'
+ argv[argv.index("--pipeline-model-parallel-size") + 1] = '1'
+ argv[argv.index("--context-parallel-size") + 1] = '1'
+ if "--num-layers-per-virtual-pipeline-stage" in argv:
+ vpp_index = argv.index("--num-layers-per-virtual-pipeline-stage")
+ argv.pop(vpp_index + 1)
+ argv.pop(vpp_index)
+ if "--expert-model-parallel-size" in argv:
+ argv[argv.index("--expert-model-parallel-size") + 1] = '1'
+ if "--use-ascend-mc2" in argv:
+ argv.pop(argv.index("--use-ascend-mc2"))
+ for feature_args in auto_tuning_filter_args_switch:
+ if feature_args in argv:
+ argv.pop(argv.index(feature_args))
+ for feature_args in auto_tuning_filter_args_config:
+ if feature_args in argv:
+ args_index = argv.index(feature_args)
+ argv.pop(args_index + 1)
+ argv.pop(args_index)
+
+ while _AUTO_TUNING_ARGS in argv:
+ pos = argv.index(_AUTO_TUNING_ARGS)
+ argv.pop(pos)
+
+ command = [
+ 'torchrun',
+ '--nproc_per_node', str(args.nproc_per_node),
+ '--nnodes', str(args.nnodes),
+ '--node-rank', str(args.node_rank),
+ '--master_addr', str(args.master_addr),
+ '--master_port', str(args.master_port),
+ str(sys.argv[0])
+ ] + argv
+ process = subprocess.Popen(command, shell=False, preexec_fn=lambda: os.setpgrp(), env=env)
+ process.wait()
+ returncode = process.returncode
+ _logger.info(returncode)
+
+ return returncode
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/utils/singleton.py b/model/train/yoco_moe/mindspeed/auto_tuning/utils/singleton.py
new file mode 100644
index 0000000000000000000000000000000000000000..a368474085e0d21bd9ce2c12651449c653110599
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/utils/singleton.py
@@ -0,0 +1,14 @@
+import threading
+
+
+class Singleton(type):
+ _instances = dict()
+ _lock = threading.Lock()
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ with cls._lock:
+ if cls not in cls._instances:
+ instance = super(Singleton, cls).__call__(*args, **kwargs)
+ cls._instances[cls] = instance
+ return cls._instances[cls]
diff --git a/model/train/yoco_moe/mindspeed/auto_tuning/utils/utils.py b/model/train/yoco_moe/mindspeed/auto_tuning/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ba294d981cd3410b896ea5f143a4e86a4a6003e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/auto_tuning/utils/utils.py
@@ -0,0 +1,34 @@
+from mindspeed.auto_tuning.module.hardware import Hardware
+from mindspeed.auto_tuning.config.model_config import ModelConfig
+from mindspeed.auto_tuning.config.search_config import SearchConfig
+
+
+def get_tp_for_profiling() -> int:
+ tp = Hardware().num_devices // 4
+ if "910B" in Hardware().device_type:
+ tp = min(tp, 8)
+ return tp
+
+
+def get_seq_length_for_profiling(model_cfg: ModelConfig) -> int:
+ if model_cfg.disable_cp_flag:
+ return model_cfg.seq_length
+ return min(model_cfg.seq_length, 32 * 1024)
+
+
+def get_prof_dir(cfg: SearchConfig, re_profile=False) -> str:
+ prof_dir = "auto_tuning_profiling"
+ prof_dir += f"_{cfg.tp}tp"
+ prof_dir += f"_{cfg.dp}dp"
+ prof_dir += f"_{cfg.pp}pp"
+ prof_dir += f"_{cfg.cp}cp"
+ prof_dir += f"_{cfg.mbs}mbs"
+ if cfg.is_moe():
+ prof_dir += f"_{cfg.ep}ep"
+ prof_dir += f"_{cfg.num_experts}experts"
+ if cfg.use_ascend_mc2:
+ prof_dir += f"_mc2"
+ prof_dir += f"_{cfg.seq_length}seq"
+ if re_profile:
+ prof_dir += f"_re_profile"
+ return prof_dir
diff --git a/model/train/yoco_moe/mindspeed/checkpointing.py b/model/train/yoco_moe/mindspeed/checkpointing.py
new file mode 100644
index 0000000000000000000000000000000000000000..268ae2b9ec3f8950cc1dc9de4742bdd3fbe44264
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/checkpointing.py
@@ -0,0 +1,412 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Input/output checkpointing."""
+
+import os
+import sys
+from functools import wraps
+import torch
+
+from megatron.core import mpu, tensor_parallel, dist_checkpointing
+from megatron.core.optimizer.optimizer import ChainedOptimizer
+from megatron.training import get_args
+from megatron.training.utils import (unwrap_model, print_rank_0)
+
+from megatron.training.checkpointing import (
+ get_rng_state,
+ get_checkpoint_name,
+ get_distributed_optimizer_checkpoint_name,
+ ensure_directory_exists,
+ get_checkpoint_tracker_filename,
+ read_metadata,
+ find_checkpoint_rank_0
+)
+
+
+def save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
+ num_floating_point_operations_so_far, checkpointing_context=None):
+ args = get_args()
+
+ # Only rank zero of the data parallel writes to the disk.
+ model = unwrap_model(model)
+
+ ckpt_format = args.dist_ckpt_format if args.use_dist_ckpt else 'torch'
+ print_rank_0('saving checkpoint at iteration {:7d} to {} in {} format'.format(
+ iteration, args.save, ckpt_format))
+
+ # Collect rng state across data parallel ranks.
+ rng_state = get_rng_state(args.use_dist_ckpt)
+
+ # Checkpoint name.
+ checkpoint_name = get_checkpoint_name(args.save, iteration, return_base_dir=args.use_dist_ckpt)
+
+ # Save distributed optimizer's custom parameter state.
+ if args.use_distributed_optimizer and not args.no_save_optim and optimizer is not None and not args.use_dist_ckpt:
+ optim_checkpoint_name = \
+ get_distributed_optimizer_checkpoint_name(checkpoint_name)
+ ensure_directory_exists(optim_checkpoint_name)
+ optimizer.save_parameter_state(optim_checkpoint_name)
+
+ async_save_request = None
+ if args.async_save:
+ if not args.use_dist_ckpt:
+ raise NotImplementedError('Async checkpoint save not implemented for legacy checkpoints')
+ elif args.dist_ckpt_format != 'torch_dist':
+ raise NotImplementedError(
+ f'Async checkpoint save not implemented for {args.dist_ckpt_format} distributed checkpoint format')
+
+ # Collect args, model, RNG.
+ if not torch.distributed.is_initialized() \
+ or mpu.get_data_modulo_expert_parallel_rank() == 0 \
+ or args.use_dist_ckpt:
+
+ optim_sd_kwargs = {}
+ if args.use_dist_ckpt and args.use_distributed_optimizer:
+ optim_sd_kwargs['sharding_type'] = ('fully_sharded_bucket_space'
+ if args.ckpt_fully_parallel_save
+ else 'dp_zero_gather_scatter')
+ print_rank_0(f'Storing distributed optimizer sharded state of type {optim_sd_kwargs["sharding_type"]}')
+ state_dict = generate_state_dict(args, model, optimizer, opt_param_scheduler, rng_state,
+ args.use_dist_ckpt, iteration, optim_sd_kwargs=optim_sd_kwargs)
+
+ state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far
+ if args.use_dist_ckpt:
+ if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+ ensure_directory_exists(checkpoint_name, check_parent=False)
+ validate_sharding_integrity = True
+ save_strategy = (checkpointing_context or {}).get('save_strategy',
+ get_default_save_sharded_strategy(args.dist_ckpt_format))
+ if args.ckpt_fully_parallel_save:
+ if checkpointing_context is not None and 'save_strategy' in checkpointing_context:
+ # Already saved once before - don't need to rerun sharding validation
+ validate_sharding_integrity = not args.ckpt_assume_constant_structure
+ else:
+ save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, mpu.get_data_parallel_group(
+ with_context_parallel=True),
+ args.ckpt_assume_constant_structure)
+ # Store save strategy for future checkpoint saves
+ if checkpointing_context is not None:
+ checkpointing_context['save_strategy'] = save_strategy
+ async_save_request = dist_checkpointing.save(state_dict, checkpoint_name, save_strategy,
+ async_sharded_save=args.async_save)
+ else:
+ # Save.
+ if args.use_ema:
+ ema_state_dict = {k: v for k, v in state_dict.items() if k.startswith('ema')}
+ state_dict = {k: v for k, v in state_dict.items() if not k.startswith('ema')}
+
+ ensure_directory_exists(checkpoint_name)
+ torch.save(state_dict, checkpoint_name)
+
+ if args.use_ema:
+ ema_state_dict = {k.replace('ema', 'model'): v for k, v in ema_state_dict.items()}
+ torch.save(ema_state_dict, checkpoint_name + ".ema")
+
+ if not args.async_save:
+ assert async_save_request is None
+ # Wait so everyone is done (necessary)
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier()
+
+ # And update the latest iteration
+ if not torch.distributed.is_initialized() \
+ or torch.distributed.get_rank() == 0:
+ tracker_filename = get_checkpoint_tracker_filename(args.save)
+
+ def iter_finalize_fn():
+ with open(tracker_filename, 'w') as f:
+ f.write(str(iteration))
+ print_rank_0(' successfully saved checkpoint from iteration {:7d} to {}'
+ .format(iteration, args.save))
+ if args.log_progress and args.async_save:
+ append_to_progress_log(f'Saved async checkpoint\tIteration: {iteration}',
+ barrier=False)
+
+ if args.async_save:
+ assert async_save_request is not None
+ async_save_request.add_finalize_fn(iter_finalize_fn)
+ else:
+ iter_finalize_fn()
+
+ if args.async_save:
+ schedule_async_save(async_save_request)
+ print_rank_0(' scheduled an async checkpoint save at iteration {:7d} to {}' \
+ .format(iteration, args.save))
+
+ # Wait so everyone is done (not necessary)
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier()
+
+
+def generate_state_dict(args, model, optimizer, opt_param_scheduler,
+ rng_state, use_dist_ckpt=False, iteration=None,
+ optim_sd_kwargs=None):
+ # Arguments, iteration, and model.
+ state_dict = {}
+ ema_state_dict = {}
+ state_dict['args'] = args
+ state_dict['checkpoint_version'] = 3.0
+ if iteration is not None:
+ state_dict['iteration'] = iteration
+
+ if len(model) == 1:
+ state_dict['model'] = (model[0].sharded_state_dict()
+ if use_dist_ckpt else
+ model[0].state_dict_for_save_checkpoint())
+ else:
+ for i in range(len(model)):
+ mpu.set_virtual_pipeline_model_parallel_rank(i)
+ state_dict['model%d' % i] = (
+ model[i].sharded_state_dict()
+ if use_dist_ckpt else
+ model[i].state_dict_for_save_checkpoint())
+
+ if args.use_ema:
+ if len(model) == 1:
+ state_dict['ema'] = {k: v for k, v in state_dict['model'].items() if k.startswith('ema')}
+ state_dict['model'] = {k: v for k, v in state_dict['model'].items() if not k.startswith('ema')}
+ else:
+ for i in range(len(model)):
+ mpu.set_virtual_pipeline_model_parallel_rank(i)
+ state_dict['ema%d' % i] = {k.replace('ema.', ''): v for k, v in state_dict['model%d' % i].items() if
+ k.startswith('ema')}
+ state_dict['model%d' % i] = {k: v for k, v in state_dict['model%d' % i].items() if
+ not k.startswith('ema')}
+
+ # Optimizer stuff.
+ if not args.no_save_optim:
+ if optimizer is not None:
+ state_dict['optimizer'] = (optimizer.sharded_state_dict(state_dict, **(optim_sd_kwargs or {}))
+ if use_dist_ckpt else
+ optimizer.state_dict())
+ if opt_param_scheduler is not None:
+ state_dict['opt_param_scheduler'] = \
+ opt_param_scheduler.state_dict()
+ # RNG states.
+ if not args.no_save_rng:
+ state_dict["rng_state"] = rng_state
+ return state_dict
+
+
+def _load_base_checkpoint(load_dir, rank0=False, sharded_state_dict=None,
+ exit_on_missing_checkpoint=False, checkpoint_step=None):
+ """ Load the base state_dict from the given directory
+
+ If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
+ """
+ args = get_args()
+
+ # Read the tracker file and set the iteration.
+ tracker_filename = get_checkpoint_tracker_filename(load_dir)
+
+ # If no tracker file, return nothing
+ if not os.path.isfile(tracker_filename):
+ if not rank0:
+ print_rank_0('WARNING: could not find the metadata file {} '.format(
+ tracker_filename))
+ print_rank_0(' will not load any checkpoints and will start from '
+ 'random')
+
+ # Conditionally exit if checkpoint not found.
+ if exit_on_missing_checkpoint:
+ print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<")
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier()
+ sys.exit()
+
+ return None, "", False
+
+ # Otherwise, read the tracker file and either set the iteration or
+ # mark it as a release checkpoint.
+ if checkpoint_step is not None:
+ iteration = checkpoint_step
+ release = False
+ else:
+ iteration, release = read_metadata(tracker_filename)
+
+ # Checkpoint.
+ if rank0:
+ checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release)
+ is_dist_ckpt = checkpoint_name is not None and dist_checkpointing.check_is_distributed_checkpoint(
+ checkpoint_name)
+ else:
+ checkpoint_name = get_checkpoint_name(load_dir, iteration, release,
+ return_base_dir=True)
+ is_dist_ckpt = dist_checkpointing.check_is_distributed_checkpoint(checkpoint_name)
+ if not is_dist_ckpt:
+ checkpoint_name = get_checkpoint_name(load_dir, iteration, release,
+ return_base_dir=False)
+ dist_infix = "distributed " if is_dist_ckpt else ""
+ if release:
+ print_rank_0(f' loading release {dist_infix}checkpoint from {load_dir}')
+ else:
+ print_rank_0(f' loading {dist_infix}checkpoint from {load_dir} at iteration {iteration}')
+
+ # Load the checkpoint.
+ if is_dist_ckpt:
+ if rank0:
+ state_dict = dist_checkpointing.load_common_state_dict(checkpoint_name)
+ return state_dict, checkpoint_name, release
+
+ # at this point args are available
+ args = get_args()
+ if sharded_state_dict is None:
+ assert not args.auto_detect_ckpt_format and not args.use_dist_ckpt, (
+ args.auto_detect_ckpt_format, args.use_dist_ckpt)
+ raise RuntimeError(
+ 'Detected load from a distributed checkpoint, but neither --use-dist-ckpt nor --auto-detect-ckpt-format is set.')
+
+ load_strategy = get_default_load_sharded_strategy(checkpoint_name)
+ if args.ckpt_fully_parallel_load:
+ load_strategy = FullyParallelLoadStrategyWrapper(load_strategy,
+ mpu.get_data_parallel_group(with_context_parallel=True))
+ state_dict = dist_checkpointing.load(sharded_state_dict, checkpoint_name, load_strategy)
+ return state_dict, checkpoint_name, release
+
+ try:
+ state_dict = torch.load(checkpoint_name, map_location='cpu')
+ try:
+ args = get_args()
+ if not args.use_ema:
+ return state_dict, checkpoint_name, release
+
+ len_model = sum(1 for key in state_dict if key.startswith('model'))
+ ema_state_dict = torch.load(checkpoint_name + ".ema", map_location='cpu')
+
+ if len(ema_state_dict) == 0:
+ return state_dict, checkpoint_name, release
+
+ if len_model == 1:
+ ema_state_dict['model'] = {f'ema.{k}': v for k, v in ema_state_dict['model'].items()}
+ state_dict['model'].update(ema_state_dict['ema'])
+ else:
+ for i in range(len_model):
+ ema_state_dict['model%d' % i] = {f'ema.{k}': v for k, v in ema_state_dict['model%d' % i].items()}
+ state_dict['model%d' % i].update(ema_state_dict['model%d' % i])
+ except BaseException as e:
+ print_rank_0('could not load the ema checkpoint, continue without ema checkpoint')
+ print_rank_0(e)
+ ema_state_dict = {}
+ except ModuleNotFoundError:
+ from megatron.legacy.fp16_deprecated import loss_scaler
+ # For backward compatibility.
+ if not rank0:
+ print_rank_0(' > deserializing using the old code structure ...')
+ sys.modules['fp16.loss_scaler'] = sys.modules[
+ 'megatron.legacy.fp16_deprecated.loss_scaler']
+ sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
+ 'megatron.legacy.fp16_deprecated.loss_scaler']
+ sys.modules['megatron.model'] = sys.modules['megatron.legacy.model']
+ state_dict = torch.load(checkpoint_name, map_location='cpu')
+ sys.modules.pop('fp16.loss_scaler', None)
+ sys.modules.pop('megatron.fp16.loss_scaler', None)
+ sys.modules.pop('megatron.model', None)
+ except BaseException as e:
+ print_rank_0('could not load the checkpoint')
+ print_rank_0(e)
+ sys.exit()
+
+ return state_dict, checkpoint_name, release
+
+
+def save_checkpoint_ema_wrapper(func):
+ @wraps(func)
+ def save_checkpoint_ema(*args, **kwargs):
+ model, optimizer, opt_param_scheduler = args[1:4]
+ state_dict = get_ema_model(model, optimizer)
+ setattr(opt_param_scheduler, 'ema_model_state_dict', state_dict)
+ func(*args[:3], opt_param_scheduler, *args[4:], **kwargs)
+ setattr(opt_param_scheduler, 'ema_model_state_dict', None)
+
+ return save_checkpoint_ema
+
+
+def generate_state_dict_ema_wrapper(func):
+ @wraps(func)
+ def generate_state_dict_ema(*args, **kwargs):
+ opt_param_scheduler = args[3]
+ state_dict = func(*args, **kwargs)
+ if hasattr(opt_param_scheduler, 'ema_model_state_dict'):
+ ema_model_state_dict = getattr(opt_param_scheduler, 'ema_model_state_dict')
+ state_dict.update(ema_model_state_dict)
+ return state_dict
+
+ return generate_state_dict_ema
+
+
+def get_ema_model(model, optimizer):
+ state_dict = dict()
+ global_args = get_args()
+ use_dist_ckpt = global_args.use_dist_ckpt
+ unwrapped_model = unwrap_model(model)
+ unchained_optimizer = unchain_optimizer(optimizer)
+ ema_optimizer_applier(unchained_optimizer)
+ if len(unwrapped_model) == 1:
+ state_dict['ema_model'] = (unwrapped_model[0].shared_state_dict()
+ if use_dist_ckpt else
+ unwrapped_model[0].state_dict_for_save_checkpoint())
+ state_dict = ema_state_dict_to_cpu(state_dict, 'ema_model')
+ ema_optimizer_restore(unchained_optimizer)
+ return state_dict
+ for sub_model in unwrapped_model:
+ sub_model_idx = unwrapped_model.index(sub_model)
+ mpu.set_virtual_pipeline_model_parallel_rank(sub_model_idx)
+ state_dict['ema_model%d' % sub_model_idx] = (
+ sub_model.sharded_state_dict()
+ if use_dist_ckpt else
+ sub_model.state_dict_for_save_checkpoint())
+ state_dict = ema_state_dict_to_cpu(state_dict, 'ema_model%d' % sub_model_idx)
+ ema_optimizer_restore(unchained_optimizer)
+ return state_dict
+
+
+def unchain_optimizer(chained_optimizer):
+ if isinstance(chained_optimizer, ChainedOptimizer):
+ return chained_optimizer.chained_optimizers
+ return [chained_optimizer]
+
+
+def ema_optimizer_applier(unchained_optimizer):
+ for optim in unchained_optimizer:
+ optim.optimizer.store(optim.optimizer.param_groups)
+ optim.optimizer.copy_to()
+ param_sync(optim)
+
+
+def ema_optimizer_restore(unchained_optimizer):
+ for optim in unchained_optimizer:
+ optim.optimizer.restore(optim.optimizer.param_groups)
+ param_sync(optim)
+ torch.distributed.barrier()
+ for optim in unchained_optimizer:
+ optim.update_successful = False
+
+
+def param_sync(optim):
+ if hasattr(optim, "_copy_main_params_to_model_params"):
+ optim._copy_main_params_to_model_params()
+ if hasattr(optim, "_reset_metadata_and_sync_gather_all_model_params"):
+ optim.update_successful = True
+ optim._reset_metadata_and_sync_gather_all_model_params(force_sync=True)
+
+
+def ema_state_dict_to_cpu(state_dict, ema_key):
+ for k, v in state_dict[ema_key].items():
+ if not torch.is_tensor(v):
+ continue
+ new_v = v.detach().cpu().clone()
+ state_dict[ema_key][k] = new_v
+ return state_dict
diff --git a/model/train/yoco_moe/mindspeed/core/__init__.py b/model/train/yoco_moe/mindspeed/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/__init__.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9745d7d51f6c37e7a01a6cc0500a89fc86eb2e7a
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/__init__.py
@@ -0,0 +1,245 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import json
+import operator
+from functools import reduce
+import datetime
+import threading
+
+import torch
+import numpy as np
+import pandas as pd
+
+KV_STORE = None
+ITERATION_LOOP_TIME = 5
+BAND_WIDTH_UNIDIRECTIONAL = 25 # GB/s
+SEARCH_CACHE_PATH = None
+MODULE_PATTERN = 'PP{}_TP{}_DP{}_CP{}_UP{}_MBS{}_MODULE.json'
+OPERATOR_PATTERN = 'PP{}_TP{}_DP{}_CP{}_UP{}_MBS{}_OPERATOR'
+
+
+# Operator dims after merging
+ARD_NUM_DIMS = {
+ 'MatMul': 3,
+ 'BatchMatMul': 4,
+ 'Softmax': 4,
+ 'SoftmaxGrad': 4,
+ 'RmsNorm': 3,
+ 'RmsNormGrad': 3,
+ 'LayerNorm': 3,
+ 'LayerNormGrad': 3,
+ 'FlashAttentionScore': 3,
+ 'FlashAttentionScoreGrad': 3
+}
+
+
+# profiling data filed
+class KeyField:
+ OpType = 'Type'
+ InputShapes = 'Input Shapes'
+ OutputShapes = 'Output Shapes'
+ Duration = 'Duration(us)'
+ FwdTime = 'fwd_time'
+ BwdTime = 'bwd_time'
+
+
+class GlobalMemoryBuffer:
+ buffers_length = [0, 0, 0]
+ buffers = [None, None, None]
+
+ @staticmethod
+ def get_tensor(shape: list, index):
+ if index not in (0, 1, 2):
+ raise AssertionError('index must be 0, 1, 2')
+ data_type = torch.float16
+ required_len = reduce(operator.mul, shape, 1)
+ if GlobalMemoryBuffer.buffers_length[index] < required_len:
+ GlobalMemoryBuffer.buffers[index] = torch.empty(
+ required_len, dtype=data_type, requires_grad=False, device=torch.cuda.current_device()
+ )
+ GlobalMemoryBuffer.buffers_length[index] = required_len
+ return GlobalMemoryBuffer.buffers[index][0:required_len].view(*shape).uniform_()
+
+
+class SingletonType(type):
+ single_lock = threading.RLock()
+
+ def __call__(cls, *args, **kwargs):
+ with SingletonType.single_lock:
+ if not hasattr(cls, "_instance"):
+ cls._instance = super(SingletonType, cls).__call__(*args, **kwargs)
+ return cls._instance
+
+
+class SampleCache:
+ def __init__(self):
+ self.MatMul = {}
+ self.RmsNorm = {}
+ self.RmsNormGrad = {}
+ self.BatchMatMul = {}
+ self.Add = {}
+ self.LayerNorm = {}
+ self.LayerNormGrad = {}
+ self.ScaledMaskedSoftmax = {}
+ self.ScaledMaskedSoftmaxGrad = {}
+ self.FastGeluGrad = {}
+ self.FastGelu = {}
+ self.Mul = {}
+ self.Softmax = {}
+ self.SoftmaxGrad = {}
+ self.FlashAttentionScore = {}
+ self.FlashAttentionScoreGrad = {}
+
+ def clear_cache(self):
+ for attr in self.__dict__:
+ setattr(self, attr, {})
+
+
+class ModelManager:
+ def __init__(self, npu_type='910B'):
+ self.models = {}
+ self.npu_type = npu_type
+
+ def cache_model(self, model, op):
+ self.models[op] = model
+
+ def get_cached_model(self, model_name: str):
+ return self.models.get(model_name, None)
+
+ def load_model(self, model, op, model_dir):
+ if not os.path.exists(model_dir):
+ raise FileNotFoundError(f"Can't find '{model_dir}'.")
+ path = os.path.join(model_dir, f"{op}_{self.npu_type}.pth")
+ weight = torch.load(path)
+ model.set_model_info(weight.popitem()[1])
+ model.load_state_dict(weight)
+ # if use model to predict,need to set training=False,otherwise require inputs dims==model_train_inputs dims
+ # during fit,after clear model cache(self.train()),training's value will be reset True
+ model.training = False
+ self.models[op] = model
+
+ def save_model(self, model, op, model_dir):
+ if not os.path.exists(model_dir):
+ os.makedirs(model_dir, exist_ok=False)
+ weight = model.state_dict()
+ weight['model_info'] = model.get_model_info()
+ torch.save(weight, f'{model_dir}/{op}_{self.npu_type}.pth')
+
+ def save_models(self, model_dir):
+ for op, op_model in self.models.items():
+ self.save_model(op_model, op, model_dir)
+
+
+class OperateProfileCache(metaclass=SingletonType):
+ def __init__(self):
+ self.data_frame = pd.DataFrame(
+ columns=[KeyField.OpType, KeyField.InputShapes, KeyField.OutputShapes, KeyField.FwdTime, KeyField.BwdTime]
+ )
+
+ def record(self, op_type: str, input_shapes: list, output_shapes: list, fwd_time: float, bwd_time: float):
+ _, _, exist = self.find(op_type, input_shapes)
+ if not exist:
+ input_shapes_str = OperateProfileCache.shapes_to_str(input_shapes)
+ output_shape_str = OperateProfileCache.shapes_to_str(output_shapes)
+ self.data_frame.loc[len(self.data_frame.index)] = [
+ op_type, input_shapes_str, output_shape_str, fwd_time, bwd_time
+ ]
+
+ def find(self, op_type: str, input_shapes: list):
+ input_shapes_str = OperateProfileCache.shapes_to_str(input_shapes)
+ data = self.data_frame[
+ (self.data_frame[KeyField.OpType] == op_type) &
+ (self.data_frame[KeyField.InputShapes] == input_shapes_str)
+ ]
+ fwd_time = data[KeyField.FwdTime].mean()
+ bwd_time = data[KeyField.BwdTime].mean()
+ from_cache = False if np.isnan(fwd_time) and np.isnan(bwd_time) else True
+ return fwd_time, bwd_time, from_cache
+
+ @staticmethod
+ def shapes_to_str(shapes):
+ result = ''
+ index = 0
+ for shape in shapes:
+ result += ','.join(map(lambda x: str(x), shape)) if isinstance(shape, list) else str(shape)
+ if index < len(shapes) - 1:
+ result += ';' if isinstance(shape, list) else ','
+ index += 1
+ result = '"' + result
+ result = result + '"'
+ return result
+
+
+def get_cache_path():
+ global SEARCH_CACHE_PATH
+ if SEARCH_CACHE_PATH is None:
+ SEARCH_CACHE_PATH = os.getcwd() + os.sep + 'autoparallel_temp_cache' + os.sep
+ try:
+ os.makedirs(SEARCH_CACHE_PATH, exist_ok=True)
+ print(f"Create cache: {SEARCH_CACHE_PATH}")
+ except Exception:
+ print(f'Create cache directory failed')
+ SEARCH_CACHE_PATH = os.getcwd()
+ return SEARCH_CACHE_PATH
+
+
+def analyse_module_profile(profile_file, key):
+ if key not in ('step_time', 'transformer_act_mem'):
+ raise AssertionError('key[{}] error'.format(key))
+
+ if not os.path.exists(path=profile_file):
+ return float('inf')
+
+ with open(profile_file, 'r') as file:
+ try:
+ content = file.read()
+ content = json.loads(content)
+ return float(content.get(key))
+ except Exception:
+ return float('inf')
+
+
+def set_kv_store(args):
+ global KV_STORE
+ if args.node_rank == 0:
+ KV_STORE = torch.distributed.TCPStore(
+ host_name=args.master_addr,
+ port=int(args.master_port) + 2,
+ world_size=args.nnodes,
+ is_master=True,
+ timeout=datetime.timedelta(seconds=30)
+ )
+ else:
+ KV_STORE = torch.distributed.TCPStore(
+ host_name=args.master_addr,
+ port=int(args.master_port) + 2,
+ world_size=args.nnodes,
+ is_master=False
+ )
+
+
+def get_kv_store():
+ global KV_STORE
+ if KV_STORE is None:
+ raise AssertionError('KV_STORE must be initialized')
+ return KV_STORE
+
+
+# init SingletonType class
+model_manager = ModelManager()
+sample_cache = SampleCache()
+operator_cache = OperateProfileCache()
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_apply.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_apply.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f777a07751a50820d1e29e9d9d3e3e067e31836
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_apply.py
@@ -0,0 +1,156 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import time
+import math
+
+import torch
+from megatron.training.global_vars import get_args
+
+from mindspeed.core.auto_parallel import set_kv_store
+from mindspeed.core.auto_parallel.auto_parallel_optimizer import SearchByGreyBox
+from mindspeed.core.auto_parallel.auto_parallel_memory import MemoryCostModel
+from mindspeed.core.auto_parallel.auto_parallel_profiling import (
+ DistributedMemoryProfiler,
+ DistributedOperateProfiler,
+ DistributedPerformanceProfiler
+)
+
+
+def filter_unvalid_configs(search_spaces):
+ memory_model = MemoryCostModel()
+ fitting_configs = memory_model.get_fitting_configurations(search_spaces)
+ for config in fitting_configs:
+ mem = DistributedMemoryProfiler().launch(config)
+ if not math.isinf(mem):
+ memory_model.profiled_configs.append(config)
+ memory_model.profiled_configs_memory.append(mem)
+
+ print(f"profiled_configs: {memory_model.profiled_configs}")
+ print(f"profiled_configs_mem: {memory_model.profiled_configs_memory}")
+
+ memory_model.fit_model()
+ valid_configs, valid_configs_memory = [], []
+ for config in search_spaces:
+ cost_memory = memory_model.get_peak_memory(config)
+ if not memory_model.is_oom(cost_memory):
+ valid_configs.append(config)
+ valid_configs_memory.append(cost_memory)
+ return valid_configs
+
+
+def build_initial_spaces(args):
+ world_size = args.nproc_per_node * args.nnodes
+ device_count = args.nproc_per_node
+
+ solutions = []
+ for pp in range(1, world_size + 1):
+ if world_size % pp != 0 or args.num_layers % pp != 0:
+ continue
+
+ for i in range(device_count):
+ tp = 2 ** i
+ if tp > device_count or tp > (world_size // pp):
+ break
+ if (args.num_query_groups > 1 and args.num_query_groups % tp != 0) \
+ or (args.num_attention_heads % tp != 0):
+ break
+
+ max_cp_size = world_size // (pp * tp)
+ for cp_size in range(1, max_cp_size + 1):
+ if world_size % (pp * tp * cp_size) != 0 or \
+ args.global_batch_size % (world_size // (pp * tp * cp_size)) != 0:
+ continue
+
+ for up in range(1, cp_size + 1):
+ if cp_size % up != 0:
+ continue
+ cp = cp_size // up
+ head, remainder = divmod(args.num_attention_heads, up * tp)
+ if (head < 1 or remainder != 0) or (args.seq_length % (2 * cp) != 0):
+ continue
+
+ dp = world_size // (pp * tp * cp_size)
+ dp_group_batch_size = args.global_batch_size // dp
+ for num_mb in range(1, dp_group_batch_size + 1):
+ if dp_group_batch_size % num_mb != 0:
+ continue
+ mbs = dp_group_batch_size // num_mb
+ solutions.append([pp, tp, dp, cp, up, mbs])
+ return solutions
+
+
+def monitor_train_task():
+ while True:
+ message = torch.tensor([0 for _ in range(7)], dtype=torch.int)
+ torch.distributed.broadcast(message, 0)
+ task_type = message[-1].item()
+ config = [m.item() for m in message[:-1]]
+ if task_type == -1:
+ break
+ elif task_type == 0:
+ DistributedMemoryProfiler().launch(config)
+ elif task_type == 1:
+ DistributedOperateProfiler().launch(config)
+ elif task_type == 2:
+ DistributedPerformanceProfiler().launch(config)
+
+
+def export_results(config):
+ results = {}
+ results['optimal_parallel_strategy'] = {}
+ results['optimal_parallel_strategy']['pipeline-model-parallel-size'] = config[0]
+ results['optimal_parallel_strategy']['tensor-model-parallel-size'] = config[1]
+ results['optimal_parallel_strategy']['data-parallel-size'] = config[2]
+ results['optimal_parallel_strategy']['micro-batch-size'] = config[-1]
+ if config[3] > 1 and config[4] > 1:
+ results['optimal_parallel_strategy']['context-parallel-algo'] = 'hybrid_cp_algo'
+ results['optimal_parallel_strategy']['context-parallel-size'] = config[3] * config[4]
+ results['optimal_parallel_strategy']['ulysses-degree-in-cp'] = config[4]
+ elif config[3] > 1 and config[4] == 1:
+ results['optimal_parallel_strategy']['context-parallel-algo'] = 'megatron_cp_algo'
+ results['optimal_parallel_strategy']['context-parallel-size'] = config[3]
+ elif config[3] == 1 and config[4] > 1:
+ results['optimal_parallel_strategy']['context-parallel-algo'] = 'ulysses_cp_algo'
+ results['optimal_parallel_strategy']['context-parallel-size'] = config[4]
+ return json.dumps(results)
+
+
+def search_optimal_configuration(args):
+ set_kv_store(args)
+
+ init_method = 'tcp://{}:{}'.format(args.master_addr, int(args.master_port) + 1)
+ torch.distributed.init_process_group(
+ backend=torch.distributed.Backend.GLOO,
+ init_method=init_method,
+ rank=args.node_rank,
+ world_size=args.nnodes
+ )
+
+ if args.node_rank == 0:
+ start_time = time.time()
+ search_space = build_initial_spaces(args)
+ search_space = filter_unvalid_configs(search_space)
+ print(f"filter search_space: {len(search_space)}")
+ print("\n".join(str(item) for item in search_space), flush=True)
+
+ config, _ = SearchByGreyBox().search(get_args(), search_space)
+ torch.distributed.broadcast(torch.tensor([-1 for _ in range(7)], dtype=torch.int), 0)
+
+ results = export_results(config)
+ print(f"find optimal configuration: {results}, cost_time: {time.time() - start_time}")
+ else:
+ monitor_train_task()
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_memory.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7a7edd312b6d6c844e2f1b0f975b59b5c202061
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_memory.py
@@ -0,0 +1,168 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from itertools import product
+
+import numpy as np
+import torch
+from megatron.training.global_vars import get_args
+
+from mindspeed.core.auto_parallel import SingletonType
+
+
+class MemoryCostModel(metaclass=SingletonType):
+ def __init__(self):
+ args = get_args()
+ self.num_layers = args.num_layers
+ self.num_attn_heads = args.num_attention_heads
+ self.hidden_size = args.hidden_size
+ self.seq_length = args.seq_length
+ self.ffn_hidden_size = args.ffn_hidden_size
+ if not self.ffn_hidden_size:
+ self.ffn_hidden_size = 4 * self.hidden_size
+
+ self.model = None
+ self.profiled_configs = []
+ self.profiled_configs_memory = []
+ self.max_available_memory = None
+
+ @staticmethod
+ def cal_coeff(config):
+ _, tp, _, cp, up, b = config
+ coeff = [
+ 1,
+ b * (1 / tp) * (1 / cp) * (1 / up),
+ b * (1 / tp) * (1 / cp) * (1 / cp) * (1 / up),
+ b * (1 / cp) * (1 / up)
+ ]
+ return np.array(coeff)
+
+ @staticmethod
+ def cal_coeff_matrix(configs):
+ coeff_matrix = []
+ for config in configs:
+ _, tp, _, cp, up, b = config
+ coeff_matrix.append([
+ 1,
+ b * (1 / tp) * (1 / cp) * (1 / up),
+ b * (1 / tp) * (1 / cp) * (1 / cp) * (1 / up),
+ b * (1 / cp) * (1 / up)
+ ])
+ return np.array(coeff_matrix)
+
+ def is_oom(self, cost_memory):
+ if self.max_available_memory is None:
+ properties = torch.npu.get_device_properties(0)
+ self.max_available_memory = properties.total_memory / (1024 ** 3)
+ # Եڴ1.2ΪOOMֵ
+ return cost_memory > (self.max_available_memory * 1.2)
+
+ def get_fitting_configurations(self, search_spaces):
+ search_spaces_matrix = np.array(search_spaces)
+ temp_search_spaces = [config for config in search_spaces if config[-1] < 8]
+
+ tp_group = []
+ max_tp = search_spaces_matrix[:, 1].max()
+ for config in temp_search_spaces:
+ _, tp, _, cp, up, _ = config
+ if cp == 1 and up == 1 and tp == max_tp:
+ tp_group.append(config)
+
+ cp_group = []
+ min_cp = search_spaces_matrix[:, 3].min()
+ for config in temp_search_spaces:
+ pp, tp, _, cp, up, _ = config
+ if tp > 1 or up > 1:
+ continue
+ if pp > 1 and cp > min_cp:
+ cp_group.append(config)
+
+ up_group = []
+ min_up = search_spaces_matrix[:, 4].min()
+ for config in temp_search_spaces:
+ pp, tp, _, cp, up, _ = config
+ if tp > 1 or cp > 1:
+ continue
+ if pp > 1 and up > min_up:
+ up_group.append(config)
+
+ cp_up_group = []
+ for config in temp_search_spaces:
+ _, tp, _, cp, up, _ = config
+ if tp == 1 and cp > 1 and up > 1:
+ cp_up_group.append(config)
+
+ tp_cp_up_group = []
+ for config in temp_search_spaces:
+ _, tp, _, cp, up, _ = config
+ if tp > 1 and cp > 1 and up > 1:
+ tp_cp_up_group.append(config)
+
+ product_iter = product(*[tp_group, cp_group, up_group, cp_up_group, tp_cp_up_group])
+ fitting_group, cur_condition_number = None, float('inf')
+
+ for group in product_iter:
+ # С100ľֵȶЧ
+ if cur_condition_number < 100:
+ break
+
+ empty_set = set([row[-1] for row in group])
+ if len(empty_set) < 2:
+ continue
+
+ coeff_matrix = MemoryCostModel.cal_coeff_matrix(group)
+ coeff_matrix = coeff_matrix.transpose() @ coeff_matrix
+ if np.linalg.matrix_rank(coeff_matrix) == coeff_matrix.shape[0]:
+ con_num = np.linalg.cond(coeff_matrix)
+ if con_num < cur_condition_number:
+ fitting_group = group
+ cur_condition_number = con_num
+
+ print(f"fitting_group: {fitting_group} condition_number: {cur_condition_number}", flush=True)
+ return fitting_group
+
+
+ def fit_model(self):
+ coeff_matrix = MemoryCostModel.cal_coeff_matrix(self.profiled_configs)
+ profiled_configs_memory = np.array(self.profiled_configs_memory)
+ self.model = np.linalg.inv(coeff_matrix.transpose() @ coeff_matrix) \
+ @ coeff_matrix.transpose() \
+ @ profiled_configs_memory
+
+ def predict(self, config):
+ config_matrix = MemoryCostModel.cal_coeff(config)
+ pred_memory = config_matrix @ self.model
+ return pred_memory
+
+ def get_peak_memory(self, config):
+ args = get_args()
+ pp, tp, _ = config[0], config[1], config[-1]
+ hidden_size = self.hidden_size
+ ffn_hidden_size = self.ffn_hidden_size
+ if args.swiglu:
+ ffn_hidden_size *= 2
+ transformer_params_count = (4 * hidden_size * hidden_size + 2 * hidden_size * ffn_hidden_size) / tp
+ total_params_count = transformer_params_count * (self.num_layers // pp)
+
+ mem_para = 2 * total_params_count
+ mem_grad = 2 * total_params_count
+ mem_optimizer = 12 * total_params_count if args.reuse_fp32_param else 16 * total_params_count
+ mem_activation_layer = abs(self.predict(config)) * (1024 ** 3)
+ mem_activation_batch = mem_activation_layer * (self.num_layers // pp)
+ mem_activation = mem_activation_batch * pp
+ m1 = mem_para + mem_optimizer + mem_activation
+ m2 = mem_para + mem_optimizer + mem_activation + mem_grad - mem_activation_batch
+ peak_memory = max(m1, m2)
+ return peak_memory / (1024 ** 3) + 4
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_model.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a13c8ab02fd89d978f83210eaf08e4b3f79ea0f8
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_model.py
@@ -0,0 +1,462 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+import math
+from functools import reduce
+
+import numpy as np
+import torch
+import torch_npu
+from megatron.training.global_vars import get_args
+
+from mindspeed.core.auto_parallel import (
+ ITERATION_LOOP_TIME,
+ BAND_WIDTH_UNIDIRECTIONAL,
+ operator_cache,
+ GlobalMemoryBuffer
+)
+from mindspeed.core.auto_parallel.auto_parallel_rectify import Sampler
+from mindspeed.core.auto_parallel.auto_parallel_profiling import CommProfiling
+from mindspeed.model.transformer import (
+ get_attention_mask,
+ generate_attention_mask
+)
+
+
+class Linear(torch.nn.Module):
+ def __init__(self):
+ super(Linear, self).__init__()
+
+ def forward(self, inputs):
+ x, y = inputs
+ return torch.matmul(x, y.t())
+
+
+class LayerNorm(torch.nn.Module):
+ def __init__(self, hidden_size, eps=1e-5):
+ super(LayerNorm, self).__init__()
+ self.layer_norm = torch.nn.LayerNorm(normalized_shape=hidden_size, eps=eps)
+
+ def forward(self, x):
+ return self.layer_norm(*x)
+
+
+class FusedRmsNorm(torch.nn.Module):
+ def __init__(self, hidden_size, eps=1e-6) -> None:
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=torch.float16)).npu()
+ self.eps = eps
+
+ def forward(self, x):
+ return torch_npu.npu_rms_norm(x[0], self.weight, epsilon=self.eps)[0]
+
+
+class BatchMatMul(torch.nn.Module):
+ def __init__(self):
+ super(BatchMatMul, self).__init__()
+
+ def forward(self, inputs):
+ x, y = inputs
+ return torch.bmm(x, y)
+
+
+class FlashAttention(torch.nn.Module):
+ def __init__(self, head_dim):
+ super().__init__()
+ self.head_dim = head_dim
+ self.scale = 1.0 / math.sqrt(self.head_dim)
+ self.pre_tockens = 65536
+ self.next_tockens = 0
+
+ generate_attention_mask()
+ self.attention_mask = get_attention_mask()
+
+ def forward(self, x):
+ q, k, v = x
+ seq_length, _, hd = q.shape[0], q.shape[1], q.shape[2]
+ head_num = hd // self.head_dim
+ output = torch_npu.npu_fusion_attention(
+ q, k, v, head_num, 'SBH',
+ pse=None,
+ padding_mask=None,
+ atten_mask=self.attention_mask,
+ scale=self.scale,
+ pre_tockens=self.pre_tockens,
+ next_tockens=self.next_tockens,
+ keep_prob=1.0,
+ inner_precise=0,
+ sparse_mode=get_args().sparse_mode
+ )[0]
+ return output
+
+
+class TransformerBlock:
+ def __init__(self):
+ self.number_sample = 100
+ self.noise_model = OperatorNoiseSampler(self.number_sample)
+
+ def norm(self):
+ args = get_args()
+ tp = args.tensor_model_parallel_size
+ cp = args.context_parallel_size // args.ulysses_degree_in_cp
+ up = args.ulysses_degree_in_cp
+ input_shape = [args.seq_length // cp // tp // up, args.micro_batch_size, args.hidden_size]
+ if args.normalization == 'RMSNorm':
+ ftime, btime = self.noise_model.fused_rms_norm(input_shape, input_shape, args.hidden_size)
+ else:
+ ftime, btime = self.noise_model.layernorm(input_shape, input_shape, args.hidden_size)
+ return ftime, btime
+
+ def self_attention_with_fa(self):
+ args = get_args()
+ tp = args.tensor_model_parallel_size
+ cp = args.context_parallel_size // args.ulysses_degree_in_cp
+ up = args.ulysses_degree_in_cp
+ ftime, btime = self.noise_model.flash_attention(
+ [args.seq_length // cp, args.micro_batch_size, args.hidden_size // tp // up],
+ [args.seq_length // cp, args.micro_batch_size, args.hidden_size // tp // up],
+ [args.seq_length // cp, args.micro_batch_size, args.hidden_size // tp // up],
+ [args.seq_length // cp, args.micro_batch_size, args.hidden_size // tp // up],
+ args.hidden_size // args.num_attention_heads,
+ )
+ return ftime, btime
+
+ def get_block_time(self):
+ args = get_args()
+ s = args.seq_length
+ a = args.num_attention_heads
+ h = args.hidden_size
+ ffn = args.ffn_hidden_size if args.ffn_hidden_size is not None else 4 * args.hidden_size
+ d = args.hidden_size // args.num_attention_heads
+ b = args.micro_batch_size
+ tp = args.tensor_model_parallel_size
+ cp = args.context_parallel_size // args.ulysses_degree_in_cp
+ up = args.ulysses_degree_in_cp
+
+ fwd_time = np.array([0 for _ in range(self.number_sample)]).astype(np.float64)
+ bwd_time = np.array([0 for _ in range(self.number_sample)]).astype(np.float64)
+
+ ftime, btime = self.norm()
+ fwd_time += ftime
+ bwd_time += btime
+
+ all_gather_time = CommProfiling.get_comm_time([s // cp // up // tp, b, h], tp, 'all_gather')
+ reduce_scatter_time = CommProfiling.get_comm_time([s // cp // up, b, h], tp, 'reduce_scatter')
+ fwd_time += all_gather_time
+ bwd_time += reduce_scatter_time
+
+ ftime, btime = self.noise_model.matmul(
+ [s // cp // up * b, h],
+ [3 * h // tp, h],
+ [s // cp // up * b, 3 * h // tp]
+ )
+ fwd_time += ftime
+ bwd_time += btime
+
+ if not args.use_flash_attn:
+ raise AssertionError('the auto-parallel only support FA')
+ else:
+ alltoall_time = CommProfiling.get_comm_time([s // cp // up, b, a // tp, d], up, 'alltoall')
+ fwd_time += (3 * alltoall_time)
+ bwd_time += (3 * alltoall_time)
+
+ send_recv_time = CommProfiling.get_send_recv_time([2, 2, s // cp // 2, b, a // tp // up * d])
+ ftime, btime = self.self_attention_with_fa()
+ for _ in range(cp - 1):
+ fwd_time += max([ftime.max(), send_recv_time])
+ bwd_time += max([btime.max(), send_recv_time])
+ fwd_time += ftime
+ bwd_time += btime
+
+ alltoall_time = CommProfiling.get_comm_time([s // cp, b, a // tp // up, d], up, 'alltoall')
+ fwd_time += alltoall_time
+ bwd_time += alltoall_time
+
+ ftime, btime = self.noise_model.matmul([s // cp // up * b, h // tp], [h, h // tp], [s // cp // up * b, h])
+ fwd_time += ftime
+ bwd_time += btime
+
+ reduce_scatter_time = CommProfiling.get_comm_time([s // cp // up, b, h], tp, 'reduce_scatter')
+ all_gather_time = CommProfiling.get_comm_time([s // cp // up // tp, b, h], tp, 'all_gather')
+ fwd_time += reduce_scatter_time
+ bwd_time += all_gather_time
+
+ ftime, btime = self.norm()
+ fwd_time += ftime
+ bwd_time += btime
+
+ all_gather_time = CommProfiling.get_comm_time([s // cp // up // tp, b, h], tp, 'all_gather')
+ reduce_scatter_time = CommProfiling.get_comm_time([s // cp // up, b, h], tp, 'reduce_scatter')
+ fwd_time += all_gather_time
+ bwd_time += reduce_scatter_time
+
+ ftime, btime = self.noise_model.matmul([s // cp // up * b, h], [ffn // tp, h], [s // cp // up * b, ffn // tp])
+ fwd_time += ftime
+ bwd_time += btime
+
+ # 4h->h
+ ftime, btime = self.noise_model.matmul([s // cp // up * b, ffn // tp], [h, ffn // tp], [s // cp // up * b, h])
+ fwd_time += ftime
+ bwd_time += btime
+
+ reduce_scatter_time = CommProfiling.get_comm_time([s // cp // up, b, h], tp, 'reduce_scatter')
+ all_gather_time = CommProfiling.get_comm_time([s // cp // up // tp, b, h], tp, 'all_gather')
+ fwd_time += reduce_scatter_time
+ bwd_time += all_gather_time
+
+ return fwd_time, bwd_time
+
+
+class OperatorNoiseSampler:
+ def __init__(self, num_sample=100):
+ self.sampling = Sampler(num_sample=num_sample)
+
+ @staticmethod
+ def measure_matmul_time(left_shape, left_transpose, right_shape, right_transpose):
+ left_matrix = GlobalMemoryBuffer.get_tensor(left_shape, 0)
+ left_matrix = left_matrix if not left_transpose else left_matrix.t()
+ right_matrix = GlobalMemoryBuffer.get_tensor(right_shape, 1)
+ right_matrix = right_matrix if not right_transpose else right_matrix.t()
+
+ for _ in range(ITERATION_LOOP_TIME):
+ torch.matmul(left_matrix, right_matrix)
+
+ torch.npu.synchronize()
+ start_time = time.time()
+ for _ in range(ITERATION_LOOP_TIME):
+ torch.matmul(left_matrix, right_matrix)
+ torch.npu.synchronize()
+ return (time.time() - start_time) * 1e6 / ITERATION_LOOP_TIME
+
+ @staticmethod
+ def measure_batchmatmul_time(left_shape, left_transpose, right_shape, right_transpose):
+ left_matrix = GlobalMemoryBuffer.get_tensor(left_shape, 0)
+ left_matrix = left_matrix if not left_transpose else left_matrix.permute(0, 2, 1)
+ right_matrix = GlobalMemoryBuffer.get_tensor(right_shape, 0)
+ right_matrix = right_matrix if not right_transpose else right_matrix.permute(0, 2, 1)
+
+ for _ in range(ITERATION_LOOP_TIME):
+ torch.bmm(left_matrix, right_matrix)
+
+ torch.npu.synchronize()
+ start_time = time.time()
+ for _ in range(ITERATION_LOOP_TIME):
+ torch.bmm(left_matrix, right_matrix)
+ torch.npu.synchronize()
+ return (time.time() - start_time) * 1e6 / ITERATION_LOOP_TIME
+
+ def matmul(self, input_shape1, input_shape2, output_shape):
+ ftime, _, from_cache = operator_cache.find('MatMul', [input_shape1, input_shape2])
+ if not from_cache:
+ ftime = self.measure_matmul_time(input_shape1, False, input_shape2, True)
+ ftime_uncertainty = self.sampling.run('MatMul', ftime, output_shape, input_shape1, input_shape2)
+ operator_cache.record('MatMul', [input_shape1, input_shape2], output_shape, ftime, 0)
+
+ btime1, _, from_cache = operator_cache.find('MatMul', [output_shape, input_shape2])
+ if not from_cache:
+ btime1 = self.measure_matmul_time(output_shape, False, input_shape2, False)
+ btime1_uncertainty = self.sampling.run('MatMul', btime1, input_shape1, output_shape, input_shape2)
+ operator_cache.record('MatMul', [output_shape, input_shape2], input_shape1, btime1, 0)
+
+ btime2, _, from_cache = operator_cache.find('MatMul', [output_shape, input_shape1])
+ if not from_cache:
+ btime2 = self.measure_matmul_time(output_shape, True, input_shape1, False)
+ btime2_uncertainty = self.sampling.run('MatMul', btime2, input_shape2, output_shape, input_shape1)
+ operator_cache.record('MatMul', [output_shape, input_shape1], input_shape2, btime2, 0)
+ return ftime_uncertainty, btime1_uncertainty + btime2_uncertainty
+
+ def batch_matmul(self, input_shape1, input_shape2, output_shape):
+ ftime, _, from_cache = operator_cache.find('BatchMatMul', [input_shape1, input_shape2])
+ if not from_cache:
+ ftime = self.measure_batchmatmul_time(input_shape1, False, input_shape2, False)
+ ftime_uncertainty = self.sampling.run('BatchMatMul', ftime, output_shape, input_shape1, input_shape2)
+ operator_cache.record('BatchMatMul', [input_shape1, input_shape2], output_shape, ftime, 0)
+
+ btime1, _, from_cache = operator_cache.find('BatchMatMul', [input_shape1, output_shape])
+ if not from_cache:
+ btime1 = self.measure_batchmatmul_time(input_shape1, True, output_shape, False)
+ btime1_uncertainty = self.sampling.run('BatchMatMul', btime1, input_shape2, input_shape1, output_shape)
+ operator_cache.record('BatchMatMul', [input_shape1, output_shape], input_shape2, btime1, 0)
+
+ btime2, _, from_cache = operator_cache.find('BatchMatMul', [output_shape, input_shape2])
+ if not from_cache:
+ btime2 = self.measure_batchmatmul_time(output_shape, False, input_shape2, True)
+ btime2_uncertainty = self.sampling.run('BatchMatMul', btime2, input_shape1, output_shape, input_shape2)
+ operator_cache.record('BatchMatMul', [output_shape, input_shape2], input_shape1, btime2, 0)
+ return ftime_uncertainty, btime1_uncertainty + btime2_uncertainty
+
+ def layernorm(self, input_shape, output_shape, hidden_size, eps=1e-5):
+ layernorm = LayerNorm(hidden_size, eps)
+ ftime, btime, from_cache = operator_cache.find('LayerNorm', input_shape)
+ if not from_cache:
+ ftime, btime = TimeCostModel.profile(layernorm, [input_shape])
+ ftime_uncertainty = self.sampling.run('LayerNorm', ftime, output_shape, input_shape)
+ btime_uncertainty = self.sampling.run('LayerNormGrad', btime, input_shape, output_shape)
+ operator_cache.record('LayerNorm', input_shape, output_shape, ftime, btime)
+ return ftime_uncertainty, btime_uncertainty
+
+ def fused_rms_norm(self, input_shape, output_shape, hidden_size, eps=1e-6):
+ fused_rms_norm = FusedRmsNorm(hidden_size, eps)
+ ftime, btime, from_cache = operator_cache.find('RmsNorm', input_shape)
+ if not from_cache:
+ ftime, btime = TimeCostModel.profile(fused_rms_norm, [input_shape])
+ ftime_uncertainty = self.sampling.run('RmsNorm', ftime, output_shape, input_shape)
+ btime_uncertainty = self.sampling.run('RmsNormGrad', btime, output_shape, input_shape)
+ operator_cache.record('RmsNorm', input_shape, output_shape, ftime, btime)
+ return ftime_uncertainty, btime_uncertainty
+
+ def flash_attention(self, q, k, v, output_shape, head_dim):
+ flash_attn = FlashAttention(head_dim)
+ ftime, btime, from_cache = operator_cache.find('FlashAttentionScore', [q, k, v])
+ if not from_cache:
+ ftime, btime = TimeCostModel.profile(flash_attn, [q, k, v])
+ ftime_uncertainty = self.sampling.run('FlashAttentionScore', ftime, output_shape, q, k, v)
+ btime_uncertainty = self.sampling.run('FlashAttentionScoreGrad', btime, output_shape, q, k, v)
+ operator_cache.record('FlashAttentionScore', [q, k, v], q, ftime, btime)
+ return ftime_uncertainty, btime_uncertainty
+
+
+class TimeCostModel(object):
+ def __init__(self):
+ args = get_args()
+ self.seq_length = args.seq_length
+ self.hidden_size = args.hidden_size
+ self.pp_size = args.pipeline_model_parallel_size
+ self.dp_size = args.data_parallel_size
+ self.micro_batch_size = args.micro_batch_size
+ self.num_layers_per_stage = args.num_layers // args.pipeline_model_parallel_size
+ self.num_micro_batch = args.global_batch_size // args.micro_batch_size // args.data_parallel_size
+
+ def get_iteration_time(self):
+ transformer_block = TransformerBlock()
+ fwd_time, bwd_time = transformer_block.get_block_time()
+ fwd_time *= self.num_layers_per_stage
+ bwd_time *= self.num_layers_per_stage
+ iteration_times = np.array([0 for _ in range(fwd_time.shape[0])]).astype(np.float64)
+ for i in range(fwd_time.shape[0]):
+ iteration_times[i] = self.pipeline_costmodel(fwd_time[i], bwd_time[i])
+ return iteration_times
+
+ def pipeline_costmodel(self, fwd_time, bwd_time):
+ if self.pp_size == 1:
+ return (fwd_time + bwd_time) * self.num_micro_batch
+
+ send_recv_time = CommProfiling.get_send_recv_time(
+ [self.seq_length, self.micro_batch_size, self.hidden_size]
+ )
+ # p and m start with 1
+ SF = np.zeros((self.pp_size + 1, self.num_micro_batch + 1), np.float64)
+ SB = np.zeros((self.pp_size + 1, self.num_micro_batch + 1), np.float64)
+ EF = np.zeros((self.pp_size + 1, self.num_micro_batch + 1), np.float64)
+ EB = np.zeros((self.pp_size + 1, self.num_micro_batch + 1), np.float64)
+
+ warmup = [self.pp_size - p - 1 for p in range(self.pp_size)]
+ remaining = [self.num_micro_batch - warmup[p] for p in range(self.pp_size)]
+
+ # warmup
+ for p in range(1, self.pp_size + 1):
+ for m in range(1, warmup[p - 1] + 1):
+ if p == 1:
+ SF[p][m] = (m - 1) * fwd_time
+ EF[p][m] = m * fwd_time
+ else:
+ SF[p][m] = max(EF[p][m - 1], EF[p - 1][m] + send_recv_time)
+ EF[p][m] = SF[p][m] + fwd_time
+
+ # 1f1b
+ for num_1f1b in range(1, self.num_micro_batch + 1):
+ # forward of 1f1b
+ for p in range(1, self.pp_size + 1):
+ if num_1f1b > remaining[p - 1]:
+ # cool down phase
+ continue
+ m = warmup[p - 1] + num_1f1b
+ if p == 1:
+ SF[p][m] = EB[p][m + p - self.pp_size - 1]
+ EF[p][m] = SF[p][m] + fwd_time
+ else:
+ SF[p][m] = max(EB[p][m + p - self.pp_size - 1], EF[p - 1][m] + send_recv_time)
+ EF[p][m] = SF[p][m] + fwd_time
+
+ # backward of 1f1b
+ for p in range(self.pp_size, 0, -1):
+ m = num_1f1b
+ if num_1f1b > remaining[p - 1]:
+ # cool down phase
+ continue
+ if p == self.pp_size:
+ SB[p][m] = EF[p][m]
+ else:
+ SB[p][m] = max(EF[p][m + self.pp_size - p], EB[p + 1][m] + send_recv_time)
+ EB[p][m] = SB[p][m] + bwd_time
+
+ # cool down phase
+ for p in range(self.pp_size, 0, -1):
+ m = num_1f1b
+ if num_1f1b <= remaining[p - 1]:
+ continue
+ SB[p][m] = max(EB[p][m - 1], EB[p + 1][m] + send_recv_time)
+ EB[p][m] = SB[p][m] + bwd_time
+
+ e2e_time = max([max(EB[p]) for p in range(self.pp_size)])
+ # allreduce_gradients
+ e2e_time += 0.0
+ return e2e_time
+
+ @staticmethod
+ def profile(model, shapes):
+ model.to(torch.cuda.current_device())
+
+ input_tensors = []
+ index = 0
+ for shape in shapes:
+ tensor = GlobalMemoryBuffer.get_tensor(shape, index).requires_grad_()
+ input_tensors.append(tensor)
+ index += 1
+
+ sum_z = None
+ for _ in range(3):
+ sum_z = model(input_tensors)
+
+ # forward_time
+ torch.npu.synchronize()
+ start_time = time.time()
+ for _ in range(ITERATION_LOOP_TIME):
+ model(input_tensors)
+ torch.npu.synchronize()
+ fwd_time = (time.time() - start_time) * 1e6 / ITERATION_LOOP_TIME
+
+ for _ in range(3):
+ z = model(input_tensors)
+ loss = torch.sum(z)
+ loss.backward()
+
+ torch.npu.synchronize()
+ start_time = time.time()
+ for _ in range(ITERATION_LOOP_TIME):
+ torch.sum(sum_z)
+ torch.npu.synchronize()
+ loss_time = (time.time() - start_time) * 1e6 / ITERATION_LOOP_TIME
+
+ torch.npu.synchronize()
+ start_time = time.time()
+ for i in range(ITERATION_LOOP_TIME):
+ z = model(input_tensors)
+ loss = torch.sum(z)
+ loss.backward()
+ torch.npu.synchronize()
+ bwd_time = (time.time() - start_time) * 1e6 / ITERATION_LOOP_TIME - fwd_time - loss_time
+ return fwd_time, bwd_time
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_optimizer.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eab1cf8662d85fb0001aff6c05a9ce8b9ebdce4
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_optimizer.py
@@ -0,0 +1,151 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import time
+import math
+import random
+import ast
+from pathlib import Path
+
+import pandas as pd
+import gpytorch
+from megatron.training.global_vars import get_args
+
+from mindspeed.core.auto_parallel import (
+ model_manager,
+ sample_cache,
+ operator_cache,
+)
+from mindspeed.core.auto_parallel.auto_parallel_rectify import ExactGPModel
+from mindspeed.core.auto_parallel.auto_parallel_model import TimeCostModel
+from mindspeed.core.auto_parallel.auto_parallel_profiling import (
+ BaseLaunch,
+ DistributedOperateProfiler,
+ DistributedPerformanceProfiler
+)
+
+
+class SearchByGreyBox:
+ def __init__(self, stop_threshold=0.05):
+ self.operators = [
+ 'MatMul',
+ 'RmsNorm',
+ 'RmsNormGrad',
+ 'LayerNorm',
+ 'LayerNormGrad',
+ 'FlashAttentionScore',
+ 'FlashAttentionScoreGrad'
+ ]
+
+ args = get_args()
+ if args.normalization == 'RMSNorm':
+ self.operators.remove('LayerNorm')
+ self.operators.remove('LayerNormGrad')
+ else:
+ self.operators.remove('RmsNorm')
+ self.operators.remove('RmsNormGrad')
+
+ self.stop_threshold = stop_threshold
+ self.config_performances = {}
+ self.exist_config = []
+ self.e2e_log = pd.DataFrame()
+
+ @staticmethod
+ def find_csv(operator_profile, key='kernel_details'):
+ csv_files = []
+ for cf in list(Path(operator_profile).rglob('*.csv')):
+ if key in str(cf):
+ csv_files.append(os.path.abspath(str(cf)))
+ if len(csv_files) <= 0:
+ print(f"not find kernel_details.csv")
+ return None
+ return sorted(csv_files)[0]
+
+ @staticmethod
+ def theory_modeling(config):
+ base_launch = BaseLaunch()
+ base_launch.update_args(config)
+ cost_time = TimeCostModel().get_iteration_time()
+ base_launch.recover_args()
+ return cost_time
+
+ def save(self, config, cost_time):
+ self.e2e_log[str(config)] = cost_time
+
+ def generate_config(self):
+ best_config = self.e2e_log.apply(lambda col: col.idxmin(), axis=1).values
+ rest_config = [i for i in best_config if str(i) not in self.exist_config]
+ prop = len(rest_config) / len(best_config)
+ if prop > self.stop_threshold:
+ sample = random.choice(rest_config)
+ self.exist_config.append(sample)
+ return ast.literal_eval(sample)
+ print(f'Unexplored proportion: {prop} < stop_thd :{self.stop_threshold}, early stop triggered.')
+ return None
+
+ def train(self, train_profiling_file, train_operator_data):
+ for operator in self.operators:
+ model = model_manager.get_cached_model(operator)
+ if model is None:
+ likelihood = gpytorch.likelihoods.GaussianLikelihood(
+ gpytorch.priors.NormalPrior(1e-3, 0.02)
+ )
+ model = ExactGPModel(operator=operator, likelihood=likelihood)
+ model_manager.cache_model(model, operator)
+ model.fit(train_profiling_file, train_operator_data)
+
+ def load_base_model(self, model_dir):
+ for operator in self.operators:
+ likelihood = gpytorch.likelihoods.GaussianLikelihood(gpytorch.priors.NormalPrior(1e-3, 0.02))
+ model = ExactGPModel(operator=operator, likelihood=likelihood)
+ try:
+ model_manager.load_model(model, operator, model_dir)
+ except Exception:
+ print(f"{operator} load error")
+
+ def search(self, args, search_spaces):
+ start_time = time.time()
+ self.load_base_model(os.path.dirname(os.path.abspath(__file__)) + os.sep + 'noise_predict_ckpt')
+ while ((time.time() - start_time) / 3600) < 8 \
+ and len(self.config_performances) < len(search_spaces):
+ for config in search_spaces:
+ cost_time = SearchByGreyBox.theory_modeling(config)
+ self.save(config, cost_time)
+ print(f"complete model config: {config}", flush=True)
+
+ next_config = self.generate_config()
+ if next_config is None:
+ break
+ print(f"next_config={next_config}", flush=True)
+
+ operator_profile_path, analyse_thread = DistributedOperateProfiler().launch(next_config)
+ duration_time = DistributedPerformanceProfiler().launch(next_config)
+ self.config_performances[duration_time] = str(next_config)
+ if math.isinf(duration_time):
+ search_spaces.remove(next_config)
+ if analyse_thread is not None:
+ analyse_thread.join()
+
+ operator_data = operator_cache.data_frame
+ operator_profile = SearchByGreyBox.find_csv(operator_profile_path)
+ if operator_profile is not None:
+ print(f"operator_data: {operator_data}\noperator_profile: {operator_profile}")
+ self.train(operator_profile, operator_data)
+ sample_cache.clear_cache()
+
+ model_manager.save_models('final_model')
+ min_key = min(self.config_performances.keys())
+ return ast.literal_eval(self.config_performances.get(min_key)), min_key
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_profiling.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_profiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ffcd64afc198ab9231e8eea63a24a14ca900d07
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_profiling.py
@@ -0,0 +1,399 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import stat
+import sys
+import time
+import json
+import copy
+import re
+import operator
+import functools
+import subprocess
+import signal
+import threading
+
+import pandas as pd
+import torch
+import torch_npu
+from torch_npu.profiler.profiler import analyse
+from megatron.training.global_vars import set_args, get_args
+
+from mindspeed.core.auto_parallel import (
+ SingletonType,
+ get_cache_path,
+ get_kv_store,
+ analyse_module_profile,
+ MODULE_PATTERN,
+ OPERATOR_PATTERN,
+ BAND_WIDTH_UNIDIRECTIONAL
+)
+
+
+class BaseLaunch:
+ def __init__(self):
+ self.old_args = None
+
+ def launch(self, config):
+ def update_or_append_param(argv: list, key, value=None):
+ if not value:
+ argv.append(key)
+ return
+
+ if key in argv:
+ argv[argv.index(key) + 1] = value
+ else:
+ argv.extend([key, value])
+
+ def remove_param(argv: list, key, has_value=False):
+ if key in argv:
+ pos = argv.index(key)
+ argv.pop(pos)
+ if has_value:
+ argv.pop(pos)
+
+ def monitor_exit(process):
+ while True:
+ exit_flag = get_kv_store().get("exit_flag")
+ if int(exit_flag) == 1:
+ try:
+ process_group_id = os.getpgid(process.pid)
+ os.killpg(process_group_id, signal.SIGKILL)
+ break
+ except ProcessLookupError:
+ break
+ time.sleep(60)
+
+ args = get_args()
+ argv: list = sys.argv[1:]
+ update_or_append_param(argv, '--eval-iters', '0')
+ update_or_append_param(argv, '--train-iters', '5')
+ update_or_append_param(argv, '--global-batch-size', str(args.global_batch_size))
+ update_or_append_param(argv, '--num-layers', str(args.num_layers))
+ update_or_append_param(argv, '--pipeline-model-parallel-size', str(args.pipeline_model_parallel_size))
+ update_or_append_param(argv, '--tensor-model-parallel-size', str(args.tensor_model_parallel_size))
+ update_or_append_param(argv, '--micro-batch-size', str(args.micro_batch_size))
+ update_or_append_param(argv, '--sequence-parallel')
+ if args.profile_operator:
+ update_or_append_param(argv, '--profile-operator')
+ if args.profile_memory:
+ update_or_append_param(argv, '--profile-memory')
+ if args.module_profile_path:
+ update_or_append_param(argv, '--prof-file', str(args.module_profile_path))
+ if args.context_parallel_algo == 'hybrid_cp_algo':
+ update_or_append_param(argv, '--context-parallel-algo', 'hybrid_cp_algo')
+ update_or_append_param(argv, '--context-parallel-size', str(args.context_parallel_size))
+ update_or_append_param(argv, '--ulysses-degree-in-cp', str(args.ulysses_degree_in_cp))
+ if args.context_parallel_algo == 'megatron_cp_algo':
+ update_or_append_param(argv, '--context-parallel-algo', 'megatron_cp_algo')
+ update_or_append_param(argv, '--context-parallel-size', str(args.context_parallel_size))
+ if args.context_parallel_algo == 'ulysses_cp_algo':
+ update_or_append_param(argv, '--context-parallel-algo', 'ulysses_cp_algo')
+ update_or_append_param(argv, '--context-parallel-size', str(args.context_parallel_size))
+ remove_param(argv, '--auto-parallel')
+
+ command = [
+ 'torchrun',
+ '--nproc_per_node', str(args.nproc_per_node),
+ '--nnodes', str(args.nnodes),
+ '--node-rank', str(args.node_rank),
+ '--master_addr', str(args.master_addr),
+ '--master_port', str(args.master_port),
+ str(sys.argv[0])
+ ] + argv
+
+ get_kv_store().set("exit_flag", "0")
+ process = subprocess.Popen(command, shell=False, preexec_fn=lambda: os.setpgrp())
+ monitor_thread = threading.Thread(target=monitor_exit, args=(process,))
+ monitor_thread.start()
+ process.wait()
+ get_kv_store().set("exit_flag", "1")
+ torch.distributed.barrier()
+
+ def update_args(self, config):
+ args = get_args()
+ self.old_args = copy.deepcopy(args)
+
+ args.pipeline_model_parallel_size = config[0]
+ args.tensor_model_parallel_size = config[1]
+ args.data_parallel_size = config[2]
+ args.context_parallel_size = config[3] * config[4]
+ args.ulysses_degree_in_cp = config[4]
+ args.micro_batch_size = config[5]
+ if config[3] > 1 and config[4] > 1:
+ args.context_parallel_algo = 'hybrid_cp_algo'
+ args.use_cp_send_recv_overlap = True
+ elif config[3] > 1 and config[4] == 1:
+ args.context_parallel_algo = 'megatron_cp_algo'
+ args.use_cp_send_recv_overlap = True
+ elif config[3] == 1 and config[4] > 1:
+ args.context_parallel_algo = 'ulysses_cp_algo'
+
+ def recover_args(self):
+ set_args(self.old_args)
+
+
+class DistributedMemoryProfiler(BaseLaunch):
+ def update_args(self, config):
+ super().update_args(config)
+ args = get_args()
+ args.module_profile_path = (get_cache_path() + MODULE_PATTERN).format(*config)
+ args.global_batch_size = args.pipeline_model_parallel_size * args.data_parallel_size * args.micro_batch_size
+ args.num_layers = args.pipeline_model_parallel_size
+ args.profile_memory = True
+
+ def launch(self, config):
+ args = get_args()
+ if args.node_rank != 0:
+ self.update_args(config)
+ super().launch(config)
+ super().recover_args()
+ return None
+
+ self.update_args(config)
+ module_profile_path = get_args().module_profile_path
+ if os.path.exists(module_profile_path):
+ super().recover_args()
+ return analyse_module_profile(module_profile_path, key='transformer_act_mem')
+
+ buffer = config + [0]
+ torch.distributed.broadcast(torch.tensor(buffer, dtype=torch.int), 0)
+
+ super().launch(config)
+ super().recover_args()
+ return analyse_module_profile(module_profile_path, key='transformer_act_mem')
+
+
+class DistributedOperateProfiler(BaseLaunch):
+ def update_args(self, config):
+ super().update_args(config)
+ args = get_args()
+ args.module_profile_path = None
+ args.operator_profile_path = (get_cache_path() + OPERATOR_PATTERN).format(*config)
+ args.global_batch_size = 4 * args.pipeline_model_parallel_size * args.data_parallel_size * args.micro_batch_size
+ args.num_layers = 2 * args.pipeline_model_parallel_size
+ args.profile_operator = True
+
+ def launch(self, config):
+ self.update_args(config)
+ args = get_args()
+ if args.node_rank != 0:
+ super().launch(config)
+ super().recover_args()
+ return None
+
+ operator_profile_path = args.operator_profile_path
+ if os.path.exists(operator_profile_path):
+ super().recover_args()
+ return operator_profile_path, None
+
+ buffer = config + [1]
+ torch.distributed.broadcast(torch.tensor(buffer, dtype=torch.int), 0)
+
+ os.environ['ASCEND_WORK_PATH'] = operator_profile_path
+ os.makedirs(operator_profile_path)
+ super().launch(config)
+ super().recover_args()
+
+ analyse_thread = threading.Thread(
+ target=analyse, args=(operator_profile_path + os.sep + 'profiling_data', 32)
+ )
+ analyse_thread.daemon = True
+ analyse_thread.start()
+ return operator_profile_path, analyse_thread
+
+
+class DistributedPerformanceProfiler(BaseLaunch):
+ def update_args(self, config):
+ super().update_args(config)
+ args = get_args()
+ args.module_profile_path = (get_cache_path() + MODULE_PATTERN).format(*config)
+
+ def launch(self, config):
+ self.update_args(config)
+ args = get_args()
+ if args.node_rank != 0:
+ super().launch(config)
+ super().recover_args()
+ return None
+
+ module_profile_path = get_args().module_profile_path
+ if os.path.exists(module_profile_path):
+ super().recover_args()
+ return analyse_module_profile(module_profile_path, key='step_time')
+
+ buffer = config + [2]
+ torch.distributed.broadcast(torch.tensor(buffer, dtype=torch.int), 0)
+ super().launch(config)
+ super().recover_args()
+ return analyse_module_profile(module_profile_path, key='step_time')
+
+
+class OperateProfile(metaclass=SingletonType):
+ def __init__(self, args):
+ experimental_config = torch_npu.profiler._ExperimentalConfig(
+ profiler_level=torch_npu.profiler.ProfilerLevel.Level2,
+ data_simplification=False
+ )
+ activities = [torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU]
+ self.op_profiler = torch_npu.profiler.profile(
+ activities=activities,
+ record_shapes=True,
+ schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=1, repeat=1, skip_first=2),
+ experimental_config=experimental_config,
+ )
+ self.op_profiler.start()
+
+ def step(self):
+ if torch.distributed.get_rank() in (0,):
+ self.op_profiler.step()
+
+ def stop(self):
+ if torch.distributed.get_rank() in (0,):
+ self.op_profiler.stop()
+
+
+class Profiling(metaclass=SingletonType):
+ MEMORY_UNIT = 1024 ** 3
+
+ def __init__(self, args, warmup_step=3, stop_step=5):
+ self.args = args
+ self.warmup_step = warmup_step
+ self.stop_step = stop_step
+ self.curr_step = 0
+ self.pattern = r'^module.module.language_model.encoder.layers.\d+$'
+ self.context = {
+ 'step_time': 0,
+ 'transformer_act_mem': 0
+ }
+
+ def should_profiling(self):
+ rank = torch.distributed.get_rank()
+ if rank in self.args.profile_ranks and \
+ self.warmup_step <= self.curr_step < self.stop_step:
+ return True
+ return False
+
+ def forward_pre_hook(self):
+ def hook(module, *args, **kwargs):
+ if torch.distributed.get_rank() in self.args.profile_ranks:
+ torch.npu.synchronize()
+ self.start_memory = torch.npu.memory_allocated()
+ torch.npu.reset_max_memory_allocated()
+ return hook
+
+ def forward_post_hook(self):
+ def hook(module, *args, **kwargs):
+ if torch.distributed.get_rank() in self.args.profile_ranks:
+ torch.npu.synchronize()
+ self.end_memory = torch.npu.max_memory_allocated()
+ transformer_act_mem = (self.end_memory - self.start_memory) / Profiling.MEMORY_UNIT
+ self.context['transformer_act_mem'] = transformer_act_mem
+ return hook
+
+ def register_recursive_hook(self, prefix_name, model):
+ model = model[0] if isinstance(model, list) else model
+ for name, module in model.named_children():
+ next_name = prefix_name + "." + name if prefix_name != "" else name
+ if re.fullmatch(self.pattern, next_name):
+ module.register_forward_pre_hook(self.forward_pre_hook())
+ module.register_forward_hook(self.forward_post_hook())
+ break
+ self.register_recursive_hook(next_name, module)
+
+ def hook_train_step(self, train_step):
+ def custom_train_step(*args, **kwargs):
+ start_time = time.time()
+ result = train_step(*args, **kwargs)
+ torch.cuda.synchronize()
+ step_time = time.time() - start_time
+ if self.should_profiling():
+ cur_step_time = self.context.get('step_time')
+ cur_step_time += (step_time - cur_step_time) / (self.curr_step - self.warmup_step + 1)
+ self.context['step_time'] = cur_step_time
+ self.export_to_file()
+ self.curr_step += 1
+ return result
+ return custom_train_step
+
+ def export_to_file(self):
+ if torch.distributed.get_rank() in self.args.profile_ranks:
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ modes = stat.S_IWUSR | stat.S_IRUSR
+ with os.fdopen(os.open(self.args.prof_file, flags, modes), 'w') as fout:
+ fout.write(json.dumps(self.context))
+
+
+class CommProfiling:
+ @staticmethod
+ def get_comm_time(shape, domains, op):
+ if domains == 1:
+ return 0
+
+ if op == 'all_reduce':
+ return CommProfiling.cal_all_reduce(shape, domains)
+ if op == 'all_gather':
+ return CommProfiling.cal_all_gather(shape, domains)
+ if op == 'alltoall':
+ return CommProfiling.cal_alltoall(shape, domains)
+ if op == 'reduce_scatter':
+ return CommProfiling.cal_reduce_scatter(shape, domains)
+ raise AssertionError('communicate operator type error')
+
+ @staticmethod
+ def cal_all_reduce(shape, domains):
+ data_size = CommProfiling.get_data_size(shape)
+ data_size = data_size / domains * (domains - 1) * domains * 2
+ band_width = domains * (domains - 1) / 2 * BAND_WIDTH_UNIDIRECTIONAL
+ return CommProfiling.div(data_size, band_width)
+
+ @staticmethod
+ def cal_all_gather(shape, domains):
+ data_size = CommProfiling.get_data_size(shape)
+ data_size = data_size / domains * (domains - 1) * domains
+ band_width = domains * (domains - 1) / 2 * BAND_WIDTH_UNIDIRECTIONAL
+ return CommProfiling.div(data_size, band_width)
+
+ @staticmethod
+ def cal_alltoall(shape, domains):
+ data_size = CommProfiling.get_data_size(shape)
+ data_size = data_size / domains * (domains - 1) * domains
+ band_width = domains * (domains - 1) / 2 * BAND_WIDTH_UNIDIRECTIONAL
+ return CommProfiling.div(data_size, band_width)
+
+ @staticmethod
+ def cal_reduce_scatter(shape, domains):
+ data_size = CommProfiling.get_data_size(shape)
+ data_size = data_size / domains * (domains - 1) * domains
+ band_width = domains * (domains - 1) / 2 * BAND_WIDTH_UNIDIRECTIONAL
+ return CommProfiling.div(data_size, band_width)
+
+ @staticmethod
+ def get_send_recv_time(shape):
+ data_size = CommProfiling.get_data_size(shape)
+ return (data_size / BAND_WIDTH_UNIDIRECTIONAL) * 1e6
+
+ @staticmethod
+ def get_data_size(shape):
+ return functools.reduce(operator.mul, shape) * 2 // 1024**3
+
+ @staticmethod
+ def div(data_size, band_width):
+ try:
+ return data_size / band_width * 1e6
+ except ZeroDivisionError:
+ print(f"band_width is zero")
+ return 0
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_rectify.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_rectify.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e09eba32a9499e5cd55ab468ae71439d018e4e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/auto_parallel_rectify.py
@@ -0,0 +1,424 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import glob
+import copy
+import warnings
+import ast
+from typing import Optional
+
+import pandas as pd
+import numpy as np
+import gpytorch
+import torch
+
+from mindspeed.core.auto_parallel import (
+ ARD_NUM_DIMS,
+ KeyField,
+ sample_cache,
+ model_manager
+)
+
+
+class ExactGPModel(gpytorch.models.ExactGP):
+ def __init__(self, operator, train_inputs=None,
+ train_targets=None, raw_lengthscale=None,
+ likelihood=None, dtype=torch.float64):
+ super(ExactGPModel, self).__init__(train_inputs, train_targets, likelihood=likelihood)
+ self.operator = operator
+ self.dtype = dtype
+
+ self.mean_module = gpytorch.means.ConstantMean()
+ self.covar_module = gpytorch.kernels.ScaleKernel(
+ gpytorch.kernels.MaternKernel(nu=0.5, ard_num_dims=ARD_NUM_DIMS[operator],
+ lengthscale_constraint=gpytorch.constraints.GreaterThan(3e-2)))
+ if raw_lengthscale is not None:
+ self.covar_module.base_kernel.raw_lengthscale.data \
+ = self.raw_lengthscale * torch.ones_like(self.covar_module.base_kernel.raw_lengthscale.data)
+
+ self.train_round = 0
+ self.train_data = pd.DataFrame()
+
+ self.y_train_mean: Optional[torch.Tensor] = None
+ self.y_train_std: Optional[torch.Tensor] = None
+ self.x_train_std: Optional[torch.Tensor] = None
+
+ def get_model_info(self):
+ return self.train_data, self.train_round
+
+ def set_model_info(self, values):
+ self.train_data, self.train_round = values
+ # set model info by train_data
+ self.data_standardize()
+
+ def forward(self, x):
+ mean = self.mean_module(x)
+ covar = self.covar_module(x)
+ return gpytorch.distributions.MultivariateNormal(mean, covar)
+
+ def fit(self, profiling_file, multi_operator_data, num_iter=3000, lr=0.03):
+ hd = DataHandler(profiling_file, multi_operator_data)
+ data = hd.generate_data(self.operator)
+ # merge self.train_data with new train_data
+ self.update_data(data)
+ # set model train_inputs and target_inputs
+ self.data_standardize()
+ # clear cache
+ self.train()
+ self.likelihood.train()
+ optimizer = torch.optim.Adam(self.parameters(), lr=lr)
+ mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self)
+ for i in range(num_iter):
+ optimizer.zero_grad()
+ output = self(self.train_inputs[0])
+ loss = -mll(output, self.train_targets)
+ loss.backward()
+ if i % 100 == 0:
+ logs = 'Iter %d/%d - Loss: %.5f outputscale: %.5f noise: %.5f' % (
+ i + 1, num_iter, loss.item(),
+ self.covar_module.outputscale.item(),
+ self.likelihood.noise.item()
+ ) + ' lengthscale: ' + str(
+ np.round(self.covar_module.base_kernel.lengthscale.detach().cpu().numpy()[0], 5))
+ print(logs)
+ optimizer.step()
+ self.eval()
+ self.likelihood.eval()
+ self.train_round += 1
+
+ def update_data(self, data: pd.DataFrame):
+ """
+ :param data columns = [shape error count]
+ """
+ if not self.train_data.empty:
+ exits_shapes = self.train_data.loc[:, KeyField.InputShapes].values.tolist()
+ for index, rows in data.iterrows():
+ shape = getattr(rows, KeyField.InputShapes)
+ # update existent input_shape
+ if shape in exits_shapes:
+ error, number = data[data[KeyField.InputShapes] == shape].iloc[:, 1:3].values.flatten()
+ current_train_data = self.train_data[self.train_data[KeyField.InputShapes] == shape]
+ train_error, train_number = current_train_data.iloc[:, 1:3].values.flatten()
+ count = int(number + train_number)
+ new_error = (error * number + train_error * train_number) / count
+ self.train_data[self.train_data[KeyField.InputShapes] == shape] = [shape, new_error, count]
+ else:
+ # save new input_shape
+ self.train_data = pd.concat([self.train_data, rows.to_frame().T], ignore_index=True)
+ else:
+ self.train_data = data
+
+ def data_standardize(self):
+ y_train = torch.tensor(self.train_data['error'], dtype=self.dtype)
+ x_train = self.train_data[KeyField.InputShapes].str.split(',', expand=True).values.astype(int)
+ x_train = torch.tensor(x_train, dtype=self.dtype).log()
+ if x_train.shape[0] == 1:
+ self.x_train_std = torch.tensor(np.ones(x_train.shape), dtype=self.dtype)
+ self.y_train_std = torch.tensor(1, dtype=self.dtype)
+ else:
+ self.x_train_std, self.y_train_std = torch.std(x_train, dim=0), torch.std(y_train, dim=0)
+ self.x_train_std[self.x_train_std == 0] = 1.
+ self.y_train_std[self.y_train_std == 0] = 1.
+ x_train /= self.x_train_std
+ self.y_train_mean = torch.mean(y_train, dim=0)
+ y_train = (y_train - self.y_train_mean) / self.y_train_std
+ self.set_train_data(x_train, y_train, strict=False)
+
+
+class Sampler:
+ def __init__(self, num_sample=10, pre_thd=0):
+ self.pre_thd = pre_thd
+ self.num_sample = torch.Size([num_sample])
+
+ def run(self, operator, direct_time, output_shape: list, *input_shape):
+ input_shape = copy.deepcopy(input_shape)
+ output_shape = copy.deepcopy(output_shape)
+ # modify input_shape
+ input_shape = Sampler.reduce_dim(operator, output_shape, input_shape)
+ # check cache
+ cached_samples = getattr(sample_cache, operator)
+ sample = cached_samples.get(input_shape, None)
+ if sample is not None:
+ return sample
+ # load model
+ model = model_manager.get_cached_model(operator)
+ # predict
+ input_shape_np = np.array(input_shape).reshape(1, -1)
+ fixed_shape = np.concatenate([input_shape_np, input_shape_np], axis=0)
+ x = torch.tensor(fixed_shape, dtype=torch.float64).log()
+ if model is None:
+ relative_error = np.zeros(self.num_sample)
+ else:
+ with torch.no_grad(), gpytorch.settings.fast_pred_var():
+ pred = model(x / model.x_train_std)
+ pred = pred * model.y_train_std.item() + model.y_train_mean.item()
+ relative_error = pred.sample(self.num_sample).cpu().numpy()[:, 0]
+ sample = direct_time * (relative_error + 1.).flatten()
+ negative_indices = np.where(sample <= self.pre_thd)[0]
+ if negative_indices.size > 0:
+ sample[negative_indices] = 0
+ warnings.warn(f'Uncertainty of {operator} is too large, input shape: {input_shape}', Warning)
+ # save prediction data
+ cached_samples[input_shape] = sample
+ return sample
+
+ @staticmethod
+ def reduce_dim(operator, output_shape, input_shapes):
+ input_shapes = copy.deepcopy(input_shapes)
+ output_shape = copy.deepcopy(output_shape)
+ if operator in ['LayerNorm', 'LayerNormGrad']:
+ input_shape = input_shapes[0]
+ elif operator in ['FastGelu', 'FastGeluGrad']:
+ input_shape = output_shape
+ elif operator in ['Softmax', 'SoftmaxGrad']:
+ input_shape = output_shape
+ elif operator == 'Add' or operator == 'Mul':
+ if len(input_shapes[0]) >= len(input_shapes[1]):
+ max_dims, min_dims = input_shapes
+ else:
+ min_dims, max_dims = input_shapes
+ if len(max_dims) == 2:
+ max_dims.insert(0, 1)
+ if len(max_dims) == 1:
+ max_dims = [1, 1, max_dims[0]]
+ if len(min_dims) == 3:
+ min_dims = [1, 1, 1]
+ elif len(min_dims) == 2:
+ min_dims = [2, 1, 1]
+ else:
+ min_dims = [2, 2, 1]
+ max_dims.extend(min_dims)
+ input_shape = max_dims
+ elif operator == 'BatchMatMul':
+ if len(input_shapes) != 2:
+ raise AssertionError(f"Dim of BatchMatMul is {len(input_shapes)}")
+ b, k, m = output_shape[0], output_shape[2], output_shape[1]
+ n = input_shapes[0][1:] + input_shapes[1][1:]
+ for shape in output_shape[1:]:
+ n.remove(shape)
+ input_shape = [b, m, n[0], k]
+ elif operator == 'MatMul':
+ if len(input_shapes) != 2:
+ raise AssertionError(f"Dim of MatMul is {len(input_shapes)}")
+ input_shape = input_shapes[0]
+ input_shape.extend(input_shapes[1])
+ for shape in output_shape:
+ input_shape.remove(shape)
+ output_shape.insert(1, input_shape[0])
+ input_shape = output_shape
+ elif operator == 'RmsNorm' or operator == 'RmsNormGrad':
+ input_shape = input_shapes[0]
+ elif operator == 'FlashAttentionScore' or operator == 'FlashAttentionScoreGrad':
+ input_shape = input_shapes[0]
+ else:
+ raise ValueError(f"{operator} not supported.")
+
+ return tuple(input_shape)
+
+
+class DataHandler:
+ def __init__(self, profiling_file, multi_operator_data: pd.DataFrame):
+ self.sample_data = multi_operator_data
+ self.profiling = self.extract_target_data(profiling_file)
+ self.current_profiling_operator = None
+ self.current_sample_operator = None
+ self.backward_flag = False
+
+ @staticmethod
+ def extract_target_data(file):
+ if os.path.isdir(file):
+ file = glob.glob(os.path.join(file, "*.csv"))
+ data = pd.concat((pd.read_csv(f) for f in file), ignore_index=True).loc[:,
+ [KeyField.OpType, KeyField.InputShapes, KeyField.OutputShapes, KeyField.Duration]]
+ else:
+ data = pd.read_csv(file).loc[:,
+ [KeyField.OpType, KeyField.InputShapes, KeyField.OutputShapes, KeyField.Duration]]
+ data.loc[data['Type'].str.startswith('MatMul'), 'Type'] = 'MatMul'
+ data.loc[data['Type'].str.startswith('BatchMatMul'), 'Type'] = 'BatchMatMul'
+ data.loc[
+ (data['Type'].str.startswith('LayerNorm') &
+ ~(data['Type'].str.contains('Back') | data['Type'].str.contains('Grad'))), 'Type'
+ ] = 'LayerNorm'
+ data.loc[
+ (data['Type'].str.startswith('LayerNorm') &
+ (data['Type'].str.contains('Back') | data['Type'].str.contains('Grad'))), 'Type'
+ ] = 'LayerNormGrad'
+ # filter
+ data = data[(data[KeyField.Duration] > 5) & (data[KeyField.InputShapes].str.len() > 4)].reset_index(drop=True)
+ return data
+
+ @staticmethod
+ def convert_dim(data):
+ new_input_shape = []
+ for index, tmp_data in data[[KeyField.OpType, KeyField.InputShapes, KeyField.OutputShapes]].iterrows():
+ op, input_shape, output_shape = tmp_data.tolist()
+ input_shape, output_shape = ast.literal_eval(input_shape), ast.literal_eval(output_shape)
+ if op == 'LayerNorm' or op == 'LayerNormGrad':
+ input_shape = input_shape.split(';')[0]
+ elif op == 'Add' or op == 'Mul':
+ dims = input_shape.split(';')
+ d0_l, d1_l = dims[0].split(','), dims[1].split(',')
+ if len(d0_l) >= len(d1_l):
+ max_length_dim = d0_l
+ min_length_dim = d1_l
+ else:
+ max_length_dim = d1_l
+ min_length_dim = d0_l
+ if len(max_length_dim) == 2:
+ max_length_dim = ['1', '1', max_length_dim[0], max_length_dim[1]]
+ elif len(max_length_dim) == 1:
+ max_length_dim = ['1', '1', '1', max_length_dim[0]]
+ elif len(max_length_dim) == 3:
+ max_length_dim.insert(0, '1')
+ if len(min_length_dim) == 3:
+ min_length_dim = ['2', '1', '1', '1']
+ elif len(min_length_dim) == 2:
+ min_length_dim = ['2', '2', '1', '1']
+ elif len(min_length_dim) == 1:
+ min_length_dim = ['2', '2', '2', '1']
+ elif len(min_length_dim) == 4:
+ min_length_dim = ['1', '1', '1', '1']
+ max_length_dim.extend(min_length_dim)
+ input_shape = ','.join(max_length_dim)
+ elif op == 'BatchMatMul':
+ output_shape = output_shape.split(',')
+ b, k, m = output_shape[0], output_shape[2], output_shape[1]
+ input_shapes = input_shape.split(';')
+ n = input_shapes[0].split(',')[1:] + input_shapes[1].split(',')[1:]
+ for shape in output_shape[1:]:
+ n.remove(shape)
+ input_shape = ','.join([b, m, n[0], k])
+ elif op == 'MatMul':
+ input_shape = input_shape.replace(';', ',').split(',')
+ output_shape = output_shape.split(',')
+ for shape in output_shape:
+ input_shape.remove(shape)
+ output_shape.insert(1, input_shape[0])
+ input_shape = ','.join(output_shape)
+ elif op == 'Softmax' or op.startswith('SoftmaxGrad'):
+ input_shape = input_shape.split(';')[0]
+ elif op == 'RmsNorm' or op == 'RmsNormGrad':
+ input_shape = input_shape.split(';')[0]
+ elif op == 'FlashAttentionScore' or op == 'FlashAttentionScoreGrad':
+ input_shape = input_shape.split(';')[0]
+ else:
+ raise TypeError(f"{op} don't support")
+ new_input_shape.append(input_shape)
+ return new_input_shape
+
+ def handle_transpose(self):
+ input_shapes = []
+ for index, sample in self.current_profiling_operator.iterrows():
+ input_shape = sample[KeyField.InputShapes]
+ input_shape = ast.literal_eval(input_shape).split(';')
+ input_shape = [list(map(lambda x: int(x), s.split(','))) for s in input_shape]
+ output_shape = ast.literal_eval(sample[KeyField.OutputShapes]).split(',')
+ output_shape = [int(s) for s in output_shape]
+ if sample[KeyField.OpType] == 'BatchMatMul':
+ if output_shape[1] != input_shape[0][1]:
+ input_shape[0][1], input_shape[0][2] = input_shape[0][2], input_shape[0][1]
+ if output_shape[-1] != input_shape[1][-1]:
+ input_shape[1][1], input_shape[1][2] = input_shape[1][2], input_shape[1][1]
+ elif sample[KeyField.OpType] == 'MatMul':
+ if output_shape[0] != input_shape[0][0]:
+ input_shape[0][0], input_shape[0][1] = input_shape[0][1], input_shape[0][0]
+ if output_shape[-1] != input_shape[1][-1]:
+ input_shape[1][0], input_shape[1][1] = input_shape[1][1], input_shape[1][0]
+ input_shape1 = ','.join([str(i) for i in input_shape[0]])
+ input_shape2 = ','.join([str(i) for i in input_shape[1]])
+ input_shape_sum = input_shape1 + ';' + input_shape2
+ input_shapes.append(f'"{input_shape_sum}"')
+ self.current_profiling_operator.loc[:, KeyField.InputShapes] = input_shapes
+
+ def handle_layer_norm_backward(self, operator):
+ profiling = self.profiling[self.profiling[KeyField.OpType] == operator].reset_index(drop=True)
+ back_grad_data = pd.DataFrame()
+ for index in range(0, profiling.shape[0], 2):
+ sum_duration = profiling.loc[index, KeyField.Duration] + profiling.loc[
+ index + 1, KeyField.Duration]
+ input_shape = profiling.loc[index, KeyField.InputShapes].split(';')[0] + '"'
+ back_grad_data.loc[index, KeyField.OpType] = 'LayerNormGrad'
+ back_grad_data.loc[index, KeyField.InputShapes] = input_shape
+ back_grad_data.loc[index, KeyField.OutputShapes] = input_shape
+ back_grad_data.loc[index, KeyField.Duration] = sum_duration
+ return back_grad_data.reset_index(drop=True)
+
+ def handle_fv(self):
+ condition = self.current_profiling_operator[KeyField.InputShapes].str.replace('"', '').str.split(';').map(
+ lambda x: x[:3]).map(lambda x: x[0] == x[1] == x[2])
+ self.current_profiling_operator = self.current_profiling_operator[condition]
+ # 对FV_grad的input_shape可能出现的异常情况容错处理
+ target_shape = self.current_sample_operator[KeyField.InputShapes].values[0]
+ current_shape = self.current_profiling_operator[KeyField.InputShapes].values[0]
+ if target_shape.split(';')[1] != current_shape.split(';')[1]:
+ self.current_profiling_operator[KeyField.InputShapes] = target_shape
+
+ def generate_data(self, operator):
+ # 串行处理各个算子
+ if len(operator) == 2:
+ # layer_norm反向特殊处理
+ self.current_profiling_operator = self.handle_layer_norm_backward(operator)
+ operator = self.current_profiling_operator.loc[0][KeyField.OpType]
+ else:
+ self.current_profiling_operator = self.profiling[self.profiling[KeyField.OpType] == operator]
+ self.backward_flag = False
+ if operator.endswith('Grad'):
+ self.backward_flag = True
+ operator = operator.split('Grad')[0]
+ # matmul和batch_matmul需要考虑转置情况
+ if operator in ['MatMul', 'BatchMatMul']:
+ self.handle_transpose()
+ # convert sample input_shape
+ self.current_sample_operator = self.sample_data[
+ self.sample_data[KeyField.OpType].str.startswith(operator)].reset_index(
+ drop=True)
+ # 删除负载均衡产生的shape和对FVGrad可能出现的异常Input_shape容错处理.
+ if operator.startswith('FlashAttention'):
+ self.handle_fv()
+ # convert profiling input_shape
+ self.current_profiling_operator.loc[:, KeyField.InputShapes] = self.convert_dim(
+ self.current_profiling_operator
+ )
+ self.current_sample_operator[KeyField.InputShapes] = self.convert_dim(self.current_sample_operator)
+ # 获取当前算子的所有input_shape
+ set_operator = self.current_sample_operator[KeyField.InputShapes].drop_duplicates().tolist()
+ errors_df = pd.DataFrame()
+ # 计算每个input_shape的相对误差
+ for shape in set_operator:
+ # 获取profiling数据当前input_shape的所有样本
+ tmp_data = self.current_profiling_operator[
+ self.current_profiling_operator[KeyField.InputShapes] == shape].copy()
+ if self.backward_flag:
+ direct_mean = self.current_sample_operator[
+ self.current_sample_operator[KeyField.InputShapes] == shape
+ ]['bwd_time'].values[0]
+ else:
+ direct_mean = self.current_sample_operator[
+ self.current_sample_operator[KeyField.InputShapes] == shape
+ ]['fwd_time'].values[0]
+ # 计算相对误差
+ tmp_data['error'] = (tmp_data[KeyField.Duration] - direct_mean) / direct_mean
+ tmp_data['direct_mean'] = direct_mean
+ errors_df = pd.concat([errors_df, tmp_data], axis=0)
+ if errors_df.empty:
+ raise AssertionError('profiling_shape mismatch operator_shape')
+
+ # 分组平均和计数
+ train_data = errors_df.groupby(KeyField.InputShapes).agg(
+ {'error': 'mean', KeyField.InputShapes: 'count'})
+ train_data.rename(columns={KeyField.InputShapes: 'sample_number'}, inplace=True)
+ train_data.reset_index(inplace=True)
+ return train_data
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/help.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/help.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c0bfaa09767de9760d9360ff6c55d42a2d142cc
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/help.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import datetime
+import json
+import math
+
+import torch
+
+
+SEARCH_CACHE_PATH = None
+KV_STORE = None
+PROFILE_CONTENT = {"fwd_time": [], "bwd_time": [], "act_mem": [], "module_param": []}
+INITIAL_CONFIG = {}
+GPT_ARGS_PATH = "gpt_args.json"
+STAGE_PROFILE_PATH = 'stage_1_profile.json'
+
+
+def broadcast_communicate(commum_data, source_rank):
+ temp_data = torch.cuda.FloatTensor([commum_data])
+ torch.distributed.broadcast(temp_data, src=source_rank)
+ return temp_data.item()
+
+
+def broadcast_communicate_list(commum_data, source_rank):
+ temp_data = torch.cuda.FloatTensor(commum_data)
+ torch.distributed.broadcast(temp_data, src=source_rank)
+ return temp_data.tolist()
+
+
+def cal_throughput(run_time, profile_data, parallel_cfg):
+ sum_token = profile_data["text_decoder.seq_length"] * profile_data['grad_acc_step'] * profile_data['micro_batch_size']
+ PP = parallel_cfg[0]
+ TP = parallel_cfg[1]
+ per_npu_throughput = sum_token / (run_time / 1000) / (PP * TP)
+ return per_npu_throughput
+
+
+def get_json(json_path):
+ with open(json_path, 'r', encoding='utf-8') as f:
+ json_data = json.load(f)
+ return json_data
+
+
+def save_json(json_path, json_data):
+ json_data_json = json.dumps(json_data)
+ with open(json_path, 'w') as f:
+ f.write(json_data_json)
+
+
+def precise_round(num, ndigits=0):
+ multiplier = 10 ** ndigits
+ return math.floor(num * multiplier + 0.5) / multiplier
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/memory_modeling.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/memory_modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0cd16ff2a071ab65f2f3bd7e4f9988ce2a15cef
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/memory_modeling.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import json
+
+import torch
+
+from mindspeed.core.auto_parallel.mm_search.help import get_json
+
+
+def get_model_parameters(model_config):
+ transformer_params_count = 12 * model_config["hidden_size"] ** 2
+ total_params_count = transformer_params_count * model_config["num_layers"]
+ return total_params_count
+
+
+def get_model_total_static_memory(args, parallel_config):
+ model_config = get_json(args.mm_model)
+ DP = parallel_config[2]
+
+ if model_config.get("image_encoder"):
+ vit_model_cfg = {"hidden_size": model_config["image_encoder"]["vision_encoder"]["hidden_size"],
+ "num_layers": model_config["image_encoder"]["vision_encoder"]["num_layers"]}
+ vit_model_params_count = get_model_parameters(vit_model_cfg)
+ if model_config.get("text_decoder"):
+ llm_model_cfg = {"hidden_size": model_config["text_decoder"]["hidden_size"],
+ "num_layers": model_config["text_decoder"]["num_layers"]}
+ llm_model_params_count = get_model_parameters(llm_model_cfg)
+
+ mem_para, mem_grad, mem_optimizer = 0, 0, 0
+ if model_config["image_encoder"]["vision_encoder"]["params_dtype"] == "bf16":
+ if not model_config["image_encoder"]["vision_encoder"].get("freeze", False):
+ mem_para += 2 * vit_model_params_count
+ mem_grad += 4 * vit_model_params_count
+ mem_optimizer += 4 * vit_model_params_count + 8 * vit_model_params_count / DP
+ else:
+ mem_para += 2 * vit_model_params_count
+ if model_config["text_decoder"]["params_dtype"] == "bf16":
+ if not model_config["text_decoder"].get("freeze", False):
+ mem_para += 2 * llm_model_params_count
+ mem_grad += 4 * llm_model_params_count
+ mem_optimizer += 4 * llm_model_params_count + 8 * llm_model_params_count / DP
+ else:
+ mem_para += 2 * llm_model_params_count
+
+ model_total_static_memory = mem_para + mem_grad + mem_optimizer
+ return model_total_static_memory / (1024 ** 2)
+
+
+def parallel_cluster_is_oom(args, parallel_config, static_mem):
+ PP, TP = parallel_config[0], parallel_config[1]
+
+ max_available_memory = torch.npu.get_device_properties(0).total_memory * 0.95 / 1024**2
+
+ if PP * TP * max_available_memory < static_mem:
+ return True
+ else:
+ return False
+
+
+def count_module_param(model):
+ for mod in model:
+ precision_placeholder = {torch.float32: 4, torch.float16: 2, torch.bfloat16: 2}
+ module_param_property = {name: [param.numel(), precision_placeholder.get(param.dtype, 0), param.requires_grad] for name, param in mod.named_parameters()}
+ # model_para, optimizer, grad
+ module_param_dict = [0, 0, 0]
+ for module_param in module_param_property:
+ module_param_dict[0] += module_param_property[module_param][0] * \
+ module_param_property[module_param][1] / 1024 ** 2
+ if module_param_property[module_param][2]:
+ module_param_dict[1] += (module_param_property[module_param][0] * 4 + \
+ module_param_property[module_param][0] * 8) / 1024 ** 2
+ module_param_dict[2] += module_param_property[module_param][0] * 4 / 1024 ** 2
+ module_param_property_json = json.dumps(module_param_property)
+ with open(f'raw_profile_{torch.distributed.get_rank()}.json', 'w') as f:
+ f.write(module_param_property_json)
+ return module_param_dict
+
+
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/optimizer.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cc012a3f41f8df7087e545190694b981afcc522
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/optimizer.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import time
+import copy
+import sys
+import json
+
+import torch
+
+from megatron.training import get_args
+from mindspeed.core.auto_parallel import set_kv_store
+from mindspeed.core.auto_parallel.mm_search.help import get_json, save_json, GPT_ARGS_PATH
+from mindspeed.core.auto_parallel.mm_search.profiling import DistributedPerformanceProfiler
+from mindspeed.core.auto_parallel.mm_search.solver import solve_auto_parallel_mm
+from mindspeed.core.auto_parallel.mm_search.memory_modeling import get_model_total_static_memory, parallel_cluster_is_oom
+
+
+class SearchByProfile:
+ def __init__(self):
+ self.merge_config_list = {}
+
+
+ def get_gpt_args(self, args):
+ gpt_args = {}
+ world_size = args.nproc_per_node * args.nnodes
+ tp = getattr(args, "tensor_model_parallel_size", 1)
+ pp = getattr(args, "pipeline_model_parallel_size", 1)
+ cp = getattr(args, "context_parallel_size", 1)
+ dp = world_size / tp / pp / cp
+ grad_acc_step = int(args.global_batch_size / args.micro_batch_size / dp)
+ gpt_args['grad_acc_step'] = grad_acc_step
+ save_json(GPT_ARGS_PATH, gpt_args)
+
+
+ def merge_config(self, args, search_spaces):
+ search_spaces_backup = copy.deepcopy(search_spaces)
+ world_size = args.nproc_per_node * args.nnodes
+ configs = []
+ for ind, cfg in enumerate(search_spaces_backup):
+ cfg[0] = 4 # pp
+ cfg[2] = world_size // (cfg[0] * cfg[1]) # dp
+ if cfg[2] < 1:
+ continue
+ if cfg not in configs:
+ configs.append(cfg)
+ self.merge_config_list[tuple(cfg)] = [search_spaces[ind], ]
+ else:
+ self.merge_config_list[tuple(cfg)].append(search_spaces[ind])
+ print("[INFO] merge config list", self.merge_config_list)
+
+ return configs
+
+
+ def search(self, args, search_spaces):
+ self.get_gpt_args(args)
+ merge_cfg = self.merge_config(args, search_spaces)
+
+ opt_config = []
+ run_throughput = 0
+ for config in merge_cfg:
+ print(f"[INFO] now profile config: {config}")
+
+ status_code = 0
+ status_code += DistributedPerformanceProfiler().launch(config, 'profiling_stage_1')
+ status_code += DistributedPerformanceProfiler().launch(config, 'profiling_stage_2')
+
+ if status_code == 0:
+ parallel_split_config = self.merge_config_list[tuple(config)]
+ print(f"[INFO] now solve cfg: {parallel_split_config}")
+
+ optimal_config = solve_auto_parallel_mm(args, parallel_split_config)
+ if optimal_config and optimal_config['throughput'] > run_throughput:
+ run_throughput = optimal_config['throughput']
+ opt_config = optimal_config
+
+ opt_config_json = json.dumps(opt_config)
+ with open(f'auto_parallel_search_optimal_config.json', 'w') as f:
+ f.write(opt_config_json)
+ print(f"[INFO] finally opt config: {opt_config}")
+
+
+ @staticmethod
+ def build_initial_spaces(args):
+ world_size = args.simulated_nproc_per_node * args.simulated_nnodes
+ device_count = args.simulated_nproc_per_node
+
+ solutions = []
+ for pp in range(1, world_size + 1):
+ if world_size % pp != 0:
+ continue
+
+ for i in range(device_count):
+ tp = 2 ** i
+ if tp > device_count or tp > (world_size // pp):
+ break
+ if (args.num_query_groups > 1 and args.num_query_groups % tp != 0) \
+ or (args.num_attention_heads % tp != 0):
+ break
+
+ dp = world_size // (pp * tp)
+ dp_group_batch_size = args.global_batch_size // dp
+ for num_mb in range(1, dp_group_batch_size + 1):
+ if dp_group_batch_size % num_mb != 0:
+ continue
+
+ mbs = dp_group_batch_size // num_mb
+ if mbs > 2:
+ continue
+
+ solutions.append([pp, tp, dp, mbs])
+
+ return solutions
+
+
+ @staticmethod
+ def filter_invalid_configs(args, search_spaces):
+ rough_filter_configs = []
+ for config in search_spaces:
+ static_mem = get_model_total_static_memory(args, config)
+ print(f"config: {config} static_mem: {static_mem}", flush=True)
+ # PPģֳ̬4
+ if not parallel_cluster_is_oom(args, config, static_mem) and config[0] <= 16 and config[1] <= args.nproc_per_node / 4:
+ rough_filter_configs.append(config)
+ print(f"[INFO] finish static memory filter config {rough_filter_configs}")
+
+ return rough_filter_configs
+
+
+def monitor_train_task():
+ while True:
+ print(f"monitor next task...", flush=True)
+ message = torch.tensor([0 for _ in range(5)], dtype=torch.int)
+ torch.distributed.broadcast(message, src=0)
+ task_type = message[-1].item()
+ config = [m.item() for m in message[:-1]]
+ if task_type == -1:
+ break
+ elif task_type == 0:
+ DistributedPerformanceProfiler().launch(config)
+
+
+def auto_parallel_mm_search_optimal_config(args):
+ set_kv_store(args)
+ # set cluster communication
+ init_method = 'tcp://{}:{}'.format(args.master_addr, int(args.master_port) + 1)
+ torch.distributed.init_process_group(
+ backend=torch.distributed.Backend.GLOO,
+ init_method=init_method,
+ rank=args.node_rank,
+ world_size=args.nnodes
+ )
+
+ if args.node_rank == 0:
+ search_space = SearchByProfile().build_initial_spaces(args)
+ print(f"[INFO] len(init_search_space): {len(search_space)}, {search_space}")
+
+ search_space = SearchByProfile().filter_invalid_configs(args, search_space)
+ print(f"[INFO] filter search_space: {len(search_space)}")
+
+ SearchByProfile().search(get_args(), search_space)
+ else:
+ monitor_train_task()
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/pp_layer_search.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/pp_layer_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac359926e30fb35f099732c3c0986445ad3555a1
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/pp_layer_search.py
@@ -0,0 +1,218 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import time
+import functools
+import operator
+
+import numpy as np
+import pulp
+from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum
+import highspy
+
+from mindspeed.core.auto_parallel.mm_search.help import precise_round
+from mindspeed.core.auto_parallel import BAND_WIDTH_UNIDIRECTIONAL
+
+
+def get_send_recv_time(shape):
+ data_size = functools.reduce(operator.mul, shape) * 2 / (1024 ** 3)
+ return (data_size / BAND_WIDTH_UNIDIRECTIONAL) * 1e3
+
+
+def pp_layer_search(parallel_cfg, profile_data, npu_memory_limit, last_stage_max_layer):
+ print(f"[INFO] start pp layer search {time.ctime()}")
+ print(f"[INFO] profile: {profile_data}")
+
+ PP = parallel_cfg[0]
+ DP = parallel_cfg[2]
+
+ num_vit = profile_data["image_encoder.vision_encoder.num_layers"]
+ num_llm = profile_data["text_decoder.num_layers"]
+
+ model_structure = [1, num_vit - 2, 1, 1, num_llm - 2, 1]
+ recomputing_fwd = [profile_data['vit']['fwd_time'], profile_data['llm']['fwd_time']]
+ recomputing_act = [profile_data['vit']['act_mem'], profile_data['llm']['act_mem']]
+ layer_name = ['vit_pre', 'vit', 'vit_post', 'llm_pre', 'llm', 'llm_post']
+ model_num_layers = num_vit + num_llm
+ print(f"PP:{PP}, DP:{DP}, num_vit, {num_vit}, num_llm, {num_llm}, model_num_layers, {model_num_layers}, \
+ model_structure, {model_structure}")
+
+ fwd_time, bwd_time, act_memory, static_memory = [], [], [], []
+ for key in layer_name:
+ fwd_time.append(int(profile_data[key]['fwd_time']))
+ bwd_time.append(int(profile_data[key]['bwd_time']))
+ act_memory.append(int(profile_data[key]['act_mem']))
+ static_memory.append(int(sum(profile_data[key]['module_param'])))
+ print(f"fwd_time, {fwd_time}, bwd_time, {bwd_time}, act_memory, {act_memory}, static_memory, {static_memory}")
+
+ fwd_duration_layers, bwd_duration_layers, act_memory_layers, static_memory_layers = [], [], [], []
+ for ind, num in enumerate(model_structure):
+ fwd_duration_layers += num * [fwd_time[ind]]
+ bwd_duration_layers += num * [bwd_time[ind]]
+ act_memory_layers += num * [act_memory[ind]]
+ static_memory_layers += num * [static_memory[ind]]
+
+ memory_reserved = [npu_memory_limit] * PP
+ num_micro_batches = profile_data["grad_acc_step"]
+ if num_micro_batches < PP:
+ return None, None, None
+
+ send_recv_time = get_send_recv_time(
+ [profile_data["text_decoder.seq_length"], profile_data["micro_batch_size"], profile_data["text_decoder.hidden_size"]]
+ )
+ comm_matrix = [[send_recv_time] * PP for _ in range(PP)]
+ for i in range(PP):
+ comm_matrix[i][i] = 0
+
+ prob = LpProblem("Min_duration_time", LpMinimize)
+
+ layer_placement = [LpVariable.matrix(f"X_{i}", range(model_num_layers), cat="Binary") for i in range(PP - 1)]
+
+ # variable: forward/backward stage start time
+ bwd_start, fwd_start = [], []
+ for j in range(PP):
+ fwd_start.append(LpVariable.matrix(f"fs_{j}", range(num_micro_batches), lowBound=1e-4, cat="Continuous"))
+ bwd_start.append(LpVariable.matrix(f"bs_{j}", range(num_micro_batches), lowBound=1e-4, cat="Continuous"))
+ recomputing_layers = [LpVariable.matrix("vit_r", range(PP - 1), lowBound=0, cat='Integer'),
+ LpVariable.matrix("llm_r", range(PP - 1), lowBound=0, cat='Integer')]
+
+ layers_per_stage = [lpSum(layer_placement[s][i] for i in range(model_num_layers)) for s in range(PP - 1)]
+ layers_per_stage.append(model_num_layers)
+
+ Const1 = 0.0001
+ Const2 = 10000
+ Z = [LpVariable.matrix(f"Z_{i}", range(PP), cat="Binary") for i in range(2)]
+
+ prob += recomputing_layers[0][0] + recomputing_layers[1][0] <= layers_per_stage[0]
+ for s in range(1, PP - 1):
+ prob += recomputing_layers[0][s] + recomputing_layers[1][s] <= layers_per_stage[s] - layers_per_stage[s - 1]
+ for s in range(PP - 1):
+ # constraint: llm recompute
+ prob += Z[1][s] <= 1 - (layers_per_stage[s] - num_vit) * Const1
+ prob += Z[1][s] >= Const1 * (num_vit - layers_per_stage[s])
+ prob += recomputing_layers[1][s] <= layers_per_stage[s] - num_vit + Const2 * Z[1][s]
+ prob += recomputing_layers[1][s] <= Const2 * (1 - Z[1][s])
+ prob += recomputing_layers[0][0] <= num_vit
+ for s in range(1, PP - 1):
+ # constraint: vit recompute
+ prob += Z[0][s] <= 1 - (layers_per_stage[s - 1] - num_vit) * Const1
+ prob += Z[0][s] >= Const1 * (num_vit - layers_per_stage[s - 1])
+ prob += recomputing_layers[0][s] <= num_vit - layers_per_stage[s - 1] + Const2 * (1 - Z[0][s])
+ prob += recomputing_layers[0][s] <= Const2 * Z[0][s]
+
+ # variable: pp stage forward/backward time
+ fwd_duration_each_stage = []
+ bwd_duration_each_stage = []
+ fwd_duration_each_stage.append(lpSum(lpDot(fwd_duration_layers, layer_placement[0])))
+ bwd_duration_each_stage.append(lpSum(lpDot(bwd_duration_layers, layer_placement[0]))
+ + recomputing_layers[0][0] * recomputing_fwd[0]
+ + recomputing_layers[1][0] * recomputing_fwd[1])
+ for s in range(1, PP - 1):
+ fwd_duration_each_stage.append(lpSum(lpDot(fwd_duration_layers, layer_placement[s])) -
+ lpSum(lpDot(fwd_duration_layers, layer_placement[s - 1])))
+ bwd_duration_each_stage.append(lpSum(lpDot(bwd_duration_layers, layer_placement[s]))
+ - lpSum(lpDot(bwd_duration_layers, layer_placement[s - 1]))
+ + recomputing_layers[0][s] * recomputing_fwd[0]
+ + recomputing_layers[1][s] * recomputing_fwd[1])
+ fwd_duration_each_stage.append(sum(fwd_duration_layers) - lpSum(lpDot(fwd_duration_layers, layer_placement[-1])))
+ bwd_duration_each_stage.append(sum(bwd_duration_layers) - lpSum(lpDot(bwd_duration_layers, layer_placement[-1])))
+
+ prob += bwd_duration_each_stage[0] >= 1e-4
+
+ # constraint: pp schedules constraints
+ # warm up
+ for s in range(PP):
+ for j in range(PP - s - 1):
+ prob += fwd_start[s][j] + fwd_duration_each_stage[s] <= fwd_start[s][j + 1]
+ # cool down
+ for s in range(PP):
+ for j in range(num_micro_batches + s - PP, num_micro_batches - 1):
+ prob += bwd_start[s][j] + bwd_duration_each_stage[s] <= bwd_start[s][j + 1]
+
+ for s in range(PP):
+ for j in range(num_micro_batches - PP + s + 1):
+ prob += fwd_start[s][j + PP - s - 1] + fwd_duration_each_stage[s] <= bwd_start[s][j]
+
+ for s in range(PP):
+ for j in range(num_micro_batches - PP + s):
+ prob += bwd_start[s][j] + bwd_duration_each_stage[s] <= fwd_start[s][j + PP - s]
+
+ for s in range(PP - 1):
+ for j in range(num_micro_batches):
+ prob += fwd_start[s + 1][j] >= fwd_start[s][j] + fwd_duration_each_stage[s] + comm_matrix[s][s + 1]
+ prob += bwd_start[s + 1][j] + bwd_duration_each_stage[s + 1] + comm_matrix[s + 1][s] <= bwd_start[s][j]
+
+ # constraint: model layer placement
+ for s in range(PP - 1):
+ for i in range(model_num_layers - 1):
+ prob += layer_placement[s][i] >= layer_placement[s][i + 1]
+
+ for s in range(PP - 2):
+ prob += (lpSum(layer_placement[s + 1][j] for j in range(model_num_layers)) >=
+ lpSum(layer_placement[s][j] for j in range(model_num_layers)) + 1)
+
+ # constraint: model memory
+ prob += ((lpSum(lpDot(layer_placement[0], act_memory_layers)) -
+ recomputing_layers[0][0] * recomputing_act[0]
+ - recomputing_layers[1][0] * recomputing_act[1]) * (PP - 1) +
+ lpSum(lpDot(layer_placement[0], act_memory_layers)) +
+ lpSum(lpDot(layer_placement[0], static_memory_layers)) <= memory_reserved[0])
+ for s in range(1, PP - 1):
+ prob += ((lpSum(lpDot(layer_placement[s], act_memory_layers))
+ - lpSum(lpDot(layer_placement[s - 1], act_memory_layers))
+ - recomputing_layers[0][s] * recomputing_act[0]
+ - recomputing_layers[1][s] * recomputing_act[1]) * (PP - s - 1) +
+ lpSum(lpDot(layer_placement[s], act_memory_layers))
+ - lpSum(lpDot(layer_placement[s - 1], act_memory_layers)) +
+ lpSum(lpDot(layer_placement[s], static_memory_layers))
+ - lpSum(lpDot(layer_placement[s - 1], static_memory_layers)) <= memory_reserved[s])
+
+ prob += layer_placement[0][0] == 1
+
+ prob += lpSum(layer_placement[-1][i] for i in range(model_num_layers)) >= model_num_layers - last_stage_max_layer
+
+ # object function
+ obj = bwd_start[0][num_micro_batches - 1] + bwd_duration_each_stage[0]
+ prob += obj
+ prob.writeLP("pp_layers_prob.lp")
+
+ print(f"[INFO] start solve {time.ctime()}")
+ h = highspy.Highs()
+ filename = 'pp_layers_prob.lp'
+ h.readModel(filename)
+ h.run()
+ print(f"[INFO] finish solve {time.ctime()}, solve state {h.modelStatusToString(h.getModelStatus())}")
+
+ if h.modelStatusToString(h.getModelStatus()) != "Optimal":
+ return None, None, None
+
+ layer_placement_values = [[0 for t in range(model_num_layers)] for s in range(PP - 1)]
+ recompute_values = [[0 for z in range(PP - 1)] for j in range(2)]
+ e2e_time = 0
+ for i, val in enumerate(h.getSolution().col_value):
+ for s in range(PP - 1):
+ for t in range(model_num_layers):
+ if h.getColByName(str(layer_placement[s][t]))[1] == i:
+ layer_placement_values[s][t] = precise_round(val)
+ break
+ for j in range(2):
+ for z in range(PP - 1):
+ if h.getColByName(str(recomputing_layers[j][z]))[1] == i:
+ recompute_values[j][z] = precise_round(val)
+ break
+ if h.getColByName(str(bwd_start[0][num_micro_batches - 1]))[1] == i:
+ e2e_time += int(val)
+ for m in range(model_num_layers):
+ if h.getColByName(str(layer_placement[0][m]))[1] == i:
+ e2e_time += val * bwd_duration_layers[m]
+ break
+ for time_id in range(2):
+ if h.getColByName(str(recomputing_layers[j][0]))[1] == i:
+ e2e_time += val * recomputing_fwd[time_id]
+ break
+
+ layer_placement_result = np.array(layer_placement_values).sum(axis=1)
+ print(f"[INFO] result: layer recompute: {recompute_values}")
+ print(f"[INFO] the layer placement: {layer_placement_result}")
+ print(f"[INFO] e2e time: {e2e_time}")
+
+ return layer_placement_result, recompute_values, e2e_time
+
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/profiling.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/profiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f5182bd4431495cf9d333422d8dd860d6aeb29d
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/profiling.py
@@ -0,0 +1,270 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import os
+import sys
+import time
+import copy
+import operator
+import subprocess
+import signal
+import threading
+import json
+
+import torch
+import torch_npu
+
+from megatron.training.global_vars import set_args, get_args
+from megatron.core import parallel_state
+from mindspeed.core.auto_parallel import get_kv_store
+from mindspeed.core.auto_parallel.mm_search.help import (
+ broadcast_communicate_list,
+ get_json,
+ save_json,
+ INITIAL_CONFIG,
+ PROFILE_CONTENT,
+ STAGE_PROFILE_PATH)
+from mindspeed.core.auto_parallel.mm_search.solver import record_train_config
+from mindspeed.core.auto_parallel.auto_parallel_profiling import BaseLaunch
+
+
+class DistributedPerformanceProfiler(BaseLaunch):
+ def update_args(self, config):
+ args = get_args()
+ self.old_args = copy.deepcopy(args)
+
+ args.pipeline_model_parallel_size = config[0]
+ args.tensor_model_parallel_size = config[1]
+ args.data_parallel_size = config[2]
+ args.micro_batch_size = config[3]
+
+
+ def launch_model(self, config, profile_module):
+ def update_or_append_param(argv: list, key, value=None):
+ if not value:
+ argv.append(key)
+ return
+
+ if key in argv:
+ argv[argv.index(key) + 1] = value
+ else:
+ argv.extend([key, value])
+
+ def remove_param(argv: list, key, has_value=False):
+ if key in argv:
+ pos = argv.index(key)
+ argv.pop(pos)
+ if has_value:
+ argv.pop(pos)
+
+ def monitor_exit(process):
+ while True:
+ exit_flag = get_kv_store().get("exit_flag")
+ if int(exit_flag) == 1:
+ try:
+ process_group_id = os.getpgid(process.pid)
+ os.killpg(process_group_id, signal.SIGKILL)
+ break
+ except ProcessLookupError:
+ break
+ time.sleep(60)
+
+ args = get_args()
+ argv: list = sys.argv[1:]
+ update_or_append_param(argv, '--eval-iters', '0')
+ update_or_append_param(argv, '--train-iters', '5')
+ update_or_append_param(argv, '--pipeline-model-parallel-size', str(args.pipeline_model_parallel_size))
+ update_or_append_param(argv, '--tensor-model-parallel-size', str(args.tensor_model_parallel_size))
+ update_or_append_param(argv, '--micro-batch-size', str(args.micro_batch_size))
+ update_or_append_param(argv, '--auto-parallel-profile')
+ update_or_append_param(argv, '--profile-subgraph-seg')
+ update_or_append_param(argv, '--enable-dummy-optimizer')
+ remove_param(argv, '--auto-parallel-mm')
+ if profile_module == 'profiling_stage_1':
+ update_or_append_param(argv, '--profile-stage', '1')
+ elif profile_module == 'profiling_stage_2':
+ update_or_append_param(argv, '--profile-stage', '2')
+
+ command = [
+ 'torchrun',
+ '--nproc_per_node', str(args.nproc_per_node),
+ '--nnodes', str(args.nnodes),
+ '--node-rank', str(args.node_rank),
+ '--master_addr', str(args.master_addr),
+ '--master_port', str(args.master_port),
+ str(sys.argv[0])
+ ] + argv
+ print(' '.join(map(str, command)), flush=True)
+
+ get_kv_store().set("exit_flag", "0")
+ process = subprocess.Popen(command, shell=False, preexec_fn=lambda: os.setpgrp())
+ monitor_thread = threading.Thread(target=monitor_exit, args=(process,))
+ monitor_thread.start()
+ status_code = process.wait()
+ get_kv_store().set("exit_flag", "1")
+ torch.distributed.barrier()
+ return status_code
+
+
+ def launch(self, config, profile_module):
+ self.update_args(config)
+ args = get_args()
+ if args.node_rank != 0:
+ self.launch_model(config, profile_module)
+ super().recover_args()
+ return None
+
+ buffer = config + [0]
+ torch.distributed.broadcast(torch.tensor(buffer, dtype=torch.int), 0)
+ status_code = self.launch_model(config, profile_module)
+ super().recover_args()
+
+ return status_code
+
+
+def save_profile_data(args):
+ global PROFILE_CONTENT
+ profile_content_json = json.dumps(PROFILE_CONTENT)
+ with open(f'model_profile_{torch.distributed.get_rank()}.json', 'w') as f:
+ f.write(profile_content_json)
+ if args.profile_subgraph_seg:
+ PROFILE_CONTENT = get_profile_from_rank(args)
+ PROFILE_CONTENT = record_train_config(PROFILE_CONTENT)
+
+ if torch.distributed.get_rank() == 0:
+ profile_content_json = json.dumps(PROFILE_CONTENT)
+ with open(f'model_profile.json', 'w') as f:
+ f.write(profile_content_json)
+ print(PROFILE_CONTENT)
+
+
+def set_profile_model_config(args):
+ vit_model_args = ["num_layers"]
+ llm_model_args = ["num_layers", "seq_length", "hidden_size"]
+ train_args = ["micro_batch_size", "use_distributed_optimizer", "simulated_nproc_per_node", "simulated_nnodes"]
+ for arg in vit_model_args:
+ if hasattr(args.mm.model.image_encoder.vision_encoder, arg):
+ INITIAL_CONFIG[f"image_encoder.vision_encoder.{arg}"] = getattr(args.mm.model.image_encoder.vision_encoder, arg)
+ for arg in llm_model_args:
+ if hasattr(args.mm.model.text_decoder, arg):
+ INITIAL_CONFIG[f"text_decoder.{arg}"] = getattr(args.mm.model.text_decoder, arg)
+ for arg in train_args:
+ if hasattr(args, arg):
+ INITIAL_CONFIG[arg] = getattr(args, arg)
+
+ if args.profile_stage == 1:
+ args.mm.model.image_encoder.vision_encoder.num_layers = 2
+ args.mm.model.image_encoder.vision_encoder.pipeline_num_layers = [1, ] * 2 + [0, ] * 2
+ args.mm.model.text_decoder.num_layers = 2
+ args.mm.model.text_decoder.pipeline_num_layers = [0, ] * 2 + [1, ] * 2
+ elif args.profile_stage == 2:
+ args.mm.model.image_encoder.vision_encoder.num_layers = 4
+ args.mm.model.image_encoder.vision_encoder.pipeline_num_layers = [2, ] * 2 + [0, ] * 2
+ args.mm.model.text_decoder.num_layers = 4
+ args.mm.model.text_decoder.pipeline_num_layers = [0, ] * 2 + [2, ] * 2
+
+ recompute_args = ["recompute_granularity", "recompute_method", "recompute_num_layers"]
+ for arg in recompute_args:
+ if hasattr(args.mm.model.image_encoder.vision_encoder, arg):
+ setattr(args.mm.model.image_encoder.vision_encoder, arg, None)
+ if hasattr(args.mm.model.image_encoder.vision_projector, arg):
+ setattr(args.mm.model.image_encoder.vision_projector, arg, None)
+ if hasattr(args.mm.model.text_decoder, arg):
+ setattr(args.mm.model.text_decoder, arg, None)
+
+ print(f"[INFO] initial_config:", INITIAL_CONFIG)
+ print(f"[INFO] finish: vit pp layer: {args.mm.model.image_encoder.vision_encoder.pipeline_num_layers}, \
+ vit num layer: {args.mm.model.image_encoder.vision_encoder.num_layers}, \
+ llm pp layer: {args.mm.model.text_decoder.pipeline_num_layers}, \
+ llm num layer: {args.mm.model.text_decoder.num_layers}, \
+ PP: {args.pipeline_model_parallel_size}, \
+ TP: {args.tensor_model_parallel_size}")
+
+
+def get_profile_from_rank(args):
+ global PROFILE_CONTENT
+
+ def get_average_time(data, m=2):
+ data = sorted(data)
+ median = data[len(data) // 2]
+ normal = [x for x in data if median - m * median < x < median + m * median]
+ try:
+ average = sum(normal) / len(normal)
+ return average
+ except ZeroDivisionError:
+ print("[Error] Divided by zero.")
+ return None
+
+ def get_computer_time():
+ if "fwd_time" in PROFILE_CONTENT:
+ PROFILE_CONTENT["fwd_time"] = get_average_time(PROFILE_CONTENT["fwd_time"])
+ else:
+ PROFILE_CONTENT["fwd_time"] = 0
+ if "bwd_time" in PROFILE_CONTENT:
+ PROFILE_CONTENT["bwd_time"] = get_average_time(PROFILE_CONTENT["bwd_time"])
+ else:
+ PROFILE_CONTENT["bwd_time"] = 0
+ if "act_mem" in PROFILE_CONTENT:
+ PROFILE_CONTENT["act_mem"] = get_average_time(PROFILE_CONTENT["act_mem"])
+ else:
+ PROFILE_CONTENT["act_mem"] = 0
+
+ get_computer_time()
+
+ tp_device_num = args.tensor_model_parallel_size
+ pp_device_num = args.pipeline_model_parallel_size
+ dp_device_num = int(args.nnodes * args.nproc_per_node / tp_device_num / pp_device_num)
+
+ profile_data_list = []
+ for rank_id in range(args.pipeline_model_parallel_size):
+ fwd_time, bwd_time, model_mem, act_mem = 0, 0, [0, 0, 0], 0
+ if parallel_state.get_pipeline_model_parallel_rank() == rank_id:
+ fwd_time = PROFILE_CONTENT['fwd_time']
+ bwd_time = PROFILE_CONTENT['bwd_time']
+ model_mem = PROFILE_CONTENT['module_param']
+ act_mem = PROFILE_CONTENT['act_mem']
+ profile_rank_data = [fwd_time, bwd_time, act_mem] + model_mem
+ profile_rank_data = broadcast_communicate_list(profile_rank_data, rank_id * tp_device_num * dp_device_num)
+ profile_data_list.append(profile_rank_data)
+
+ if args.profile_stage == 1:
+ PROFILE_CONTENT = {}
+ PROFILE_CONTENT['vit_pre'] = {"fwd_time": profile_data_list[0][0],
+ "bwd_time": profile_data_list[0][1],
+ "module_param": profile_data_list[0][-3:],
+ "act_mem": profile_data_list[0][2]}
+ PROFILE_CONTENT['vit_post'] = {"fwd_time": profile_data_list[1][0],
+ "bwd_time": profile_data_list[1][1],
+ "module_param": profile_data_list[1][-3:],
+ "act_mem": profile_data_list[1][2]}
+ PROFILE_CONTENT['llm_pre'] = {"fwd_time": profile_data_list[2][0],
+ "bwd_time": profile_data_list[2][1],
+ "module_param": profile_data_list[2][-3:],
+ "act_mem": profile_data_list[2][2]}
+ PROFILE_CONTENT['llm_post'] = {"fwd_time": profile_data_list[3][0],
+ "bwd_time": profile_data_list[3][1],
+ "module_param": profile_data_list[3][-3:],
+ "act_mem": profile_data_list[3][2]}
+
+ save_json(STAGE_PROFILE_PATH, PROFILE_CONTENT)
+
+ elif args.profile_stage == 2:
+ profile_data = get_json(STAGE_PROFILE_PATH)
+
+ PROFILE_CONTENT = copy.deepcopy(profile_data)
+ model_mem_vit = copy.deepcopy(profile_data_list[0][-3:])
+ model_mem_llm = copy.deepcopy(profile_data_list[2][-3:])
+ for i, v in enumerate(profile_data['vit_pre']['module_param']):
+ model_mem_vit[i] = (profile_data_list[0][-3:][i] - profile_data['vit_pre']['module_param'][i])
+ model_mem_llm[i] = (profile_data_list[2][-3:][i] - profile_data['llm_pre']['module_param'][i])
+
+ PROFILE_CONTENT['vit'] = {"fwd_time": (profile_data_list[0][0] - profile_data['vit_pre']['fwd_time']),
+ "bwd_time": (profile_data_list[0][1] - profile_data['vit_pre']['bwd_time']),
+ "module_param": model_mem_vit,
+ "act_mem": (profile_data_list[0][2] - profile_data['vit_pre']['act_mem'])}
+ PROFILE_CONTENT['llm'] = {"fwd_time": (profile_data_list[2][0] - profile_data['llm_pre']['fwd_time']),
+ "bwd_time": (profile_data_list[2][1] - profile_data['llm_pre']['bwd_time']),
+ "module_param": model_mem_llm,
+ "act_mem": (profile_data_list[2][2] - profile_data['llm_pre']['act_mem']),
+ "embed_time": profile_data['llm_pre']['fwd_time'] - (profile_data_list[2][0] - profile_data['llm_pre']['fwd_time']) if profile_data['llm_pre']['fwd_time'] - (profile_data_list[2][0] - profile_data['llm_pre']['fwd_time']) > 0 else 0}
+
+ return PROFILE_CONTENT
+
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/schedules.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/schedules.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e1c3c4a5ae76276300029c68f555b4f2f49cfdd
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/schedules.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import time
+from functools import wraps
+
+import torch
+
+from megatron.training import get_args
+from mindspeed.core.auto_parallel.mm_search.help import PROFILE_CONTENT
+
+
+def backward_step_decorator(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ global_args = get_args()
+ if global_args.auto_parallel_profile:
+ # get model backward time
+ torch.npu.synchronize()
+ st_time = time.time()
+ grad = fn(*args, **kwargs)
+ torch.npu.synchronize()
+ PROFILE_CONTENT["bwd_time"].append((time.time() - st_time) * 1000)
+ else:
+ grad = fn(*args, **kwargs)
+ return grad
+ return wrapper
+
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/solver.py b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bfec8bea593b2c63eee2d2b135620aa49eee5ee
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/auto_parallel/mm_search/solver.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import json
+import itertools
+import copy
+import sys
+import time
+import math
+
+import torch
+import numpy as np
+
+from mindspeed.core.auto_parallel.mm_search.help import (
+ broadcast_communicate_list,
+ cal_throughput,
+ get_json,
+ INITIAL_CONFIG,
+ GPT_ARGS_PATH)
+from mindspeed.core.auto_parallel.mm_search.pp_layer_search import pp_layer_search
+
+
+def record_train_config(profile):
+ for key in INITIAL_CONFIG:
+ profile[key] = INITIAL_CONFIG[key]
+ gpt_args = get_json(GPT_ARGS_PATH)
+ for key in gpt_args:
+ profile[key] = gpt_args[key]
+ return profile
+
+
+class AutoParallelSolver():
+ def __init__(self, profile_data):
+ if torch.cuda.is_available():
+ self.max_available_memory = torch.cuda.get_device_properties(0).total_memory / 1024**2
+ else:
+ self.max_available_memory = 62000
+ self.layer_name = ['vit_pre', 'vit', 'vit_post', 'llm_pre', 'llm', 'llm_post']
+ print(f"[INFO] NPU available memory: {self.max_available_memory}")
+
+
+ def update_profile(self, args, parallel_cfg, profile_data):
+ update_profile_data = copy.deepcopy(profile_data)
+
+ if args.use_distributed_optimizer:
+ DP = parallel_cfg[2]
+
+ for key in profile_data:
+ if key in self.layer_name:
+ update_profile_data[key]['module_param'][1] = profile_data[key]['module_param'][1] / 12 * (4 + 8 / DP)
+
+ return update_profile_data
+
+
+ def cal_max_layer(self, profile_data):
+ llm_available_memory = self.max_available_memory - sum(profile_data['llm_post']['module_param']) - profile_data['llm_post']['act_mem']
+ last_stage_max_layer = math.floor(llm_available_memory / (sum(profile_data['llm']['module_param']))) + 1
+ return last_stage_max_layer
+
+
+ def trans_optimal_config(self, optimal_config, profile_data):
+ parallel_config = optimal_config['parallel_config']
+ optimal_config['parallel_config'] = {'PP': parallel_config[0],
+ 'TP': parallel_config[1],
+ 'DP': parallel_config[2],
+ 'MBS': parallel_config[3]}
+
+ layer_placement = optimal_config['layer_placement']
+ sum_model_layer = profile_data['image_encoder.vision_encoder.num_layers'] + profile_data['text_decoder.num_layers']
+ layer_placement.append(sum_model_layer)
+ merge_layer_place = []
+ merge_layer_place.append(int(layer_placement[0]))
+ for i in range(1, len(layer_placement)):
+ layer_num = int(layer_placement[i] - layer_placement[i - 1])
+ merge_layer_place.append(layer_num)
+
+ vit_layer_placement = [0] * optimal_config['parallel_config']['PP']
+ llm_layer_placement = [0] * optimal_config['parallel_config']['PP']
+ vit_layer_num = profile_data['image_encoder.vision_encoder.num_layers']
+ llm_layer_num = profile_data['text_decoder.num_layers']
+ for i, capacity in enumerate(merge_layer_place):
+ a_count = min(vit_layer_num, capacity)
+ vit_layer_placement[i] = a_count
+ vit_layer_num -= a_count
+ b_count = min(llm_layer_num, capacity - a_count)
+ llm_layer_placement[i] = b_count
+ llm_layer_num -= b_count
+ optimal_config['layer_placement'] = {'vit_layer_placement': vit_layer_placement,
+ 'llm_layer_placement': llm_layer_placement}
+
+ layer_recompute = optimal_config['layer_recompute']
+ optimal_config['layer_recompute'] = {'vit_layer_recompute': layer_recompute[0],
+ 'llm_layer_recompute': layer_recompute[1]}
+ return optimal_config
+
+
+def solve_auto_parallel_mm(args, parallel_cfgs):
+ if torch.distributed.get_rank() == 0:
+ with open(f'model_profile.json', 'r', encoding='utf-8') as f:
+ profile_data = json.load(f)
+
+ solver = AutoParallelSolver(profile_data)
+
+ optimal_config = {}
+ optimal_throughput = 0
+ for parallel_cfg in parallel_cfgs:
+ print(f"[INFO] now to solve config {parallel_cfg}")
+
+ cfg_profile_data = solver.update_profile(args, parallel_cfg, profile_data)
+
+ last_stage_max_layer = solver.cal_max_layer(cfg_profile_data)
+ print(f"[INFO] last stage max layer {last_stage_max_layer}")
+
+ layer_placement, layer_recompute, e2e_time = pp_layer_search(parallel_cfg, cfg_profile_data, solver.max_available_memory, last_stage_max_layer)
+
+ if e2e_time is None:
+ continue
+
+ per_npu_throughput = cal_throughput(e2e_time, cfg_profile_data, parallel_cfg)
+ print(f"[INFO] per_npu throughput {per_npu_throughput}")
+
+ if per_npu_throughput > optimal_throughput:
+ optimal_config = {"parallel_config": parallel_cfg,
+ "layer_placement": layer_placement.tolist(),
+ "layer_recompute": layer_recompute,
+ "e2e_time": e2e_time,
+ "throughput": per_npu_throughput}
+ optimal_config = solver.trans_optimal_config(optimal_config, profile_data)
+ optimal_throughput = per_npu_throughput
+
+ print(f"optimal_config: {optimal_config}")
+ return optimal_config
+ return None
+
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/FlashAttentionScoreGrad_910B.pth b/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/FlashAttentionScoreGrad_910B.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f742d0624c77f54018ed8e13d5d54f54c1b0a57a
Binary files /dev/null and b/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/FlashAttentionScoreGrad_910B.pth differ
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/FlashAttentionScore_910B.pth b/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/FlashAttentionScore_910B.pth
new file mode 100644
index 0000000000000000000000000000000000000000..cc9533ea3e004a22a77f0218ba825dc92af6bcf8
Binary files /dev/null and b/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/FlashAttentionScore_910B.pth differ
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/MatMul_910B.pth b/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/MatMul_910B.pth
new file mode 100644
index 0000000000000000000000000000000000000000..1b2f42f9486c9bd6a14aae62d4fbc3c1a1635db3
Binary files /dev/null and b/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/MatMul_910B.pth differ
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/RmsNormGrad_910B.pth b/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/RmsNormGrad_910B.pth
new file mode 100644
index 0000000000000000000000000000000000000000..3881b4048d6e5b78b44e61f6998a47f0ca597f58
Binary files /dev/null and b/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/RmsNormGrad_910B.pth differ
diff --git a/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/RmsNorm_910B.pth b/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/RmsNorm_910B.pth
new file mode 100644
index 0000000000000000000000000000000000000000..3f08eebb1f60d6d83d8648fc710d3c6614f82ab6
Binary files /dev/null and b/model/train/yoco_moe/mindspeed/core/auto_parallel/noise_predict_ckpt/RmsNorm_910B.pth differ
diff --git a/model/train/yoco_moe/mindspeed/core/context_parallel/__init__.py b/model/train/yoco_moe/mindspeed/core/context_parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/context_parallel/adaptive_context_parallel.py b/model/train/yoco_moe/mindspeed/core/context_parallel/adaptive_context_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..da8bd4f4103824f67cf7d9bdcdcdf248a256bbbe
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/context_parallel/adaptive_context_parallel.py
@@ -0,0 +1,371 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+import torch_npu
+from mindspeed.ops.npu_ring_attention_update import npu_ring_attention_update
+
+
+def flash_attn_p2p_communicate(scheduling_info, send_q_dst, recv_q_src, send_kv_dst, recv_kv_src, cp_group, rank):
+ send_recv_ops = []
+ is_recv_q = False
+ is_recv_kv = False
+ for send_dst in scheduling_info.send_q_dst:
+ send_op = torch.distributed.P2POp(torch.distributed.isend, send_q_dst, send_dst, group=cp_group, tag=send_dst)
+ send_recv_ops.append(send_op)
+ for send_dst in scheduling_info.send_kv_dst:
+ send_op = torch.distributed.P2POp(torch.distributed.isend, send_kv_dst, send_dst, group=cp_group, tag=send_dst)
+ send_recv_ops.append(send_op)
+ if scheduling_info.recv_q_src > -1:
+ recv_src = scheduling_info.recv_q_src
+ recv_op = torch.distributed.P2POp(torch.distributed.irecv, recv_q_src, recv_src, group=cp_group, tag=rank)
+ send_recv_ops.append(recv_op)
+ is_recv_q = True
+ if scheduling_info.recv_kv_src > -1:
+ recv_src = scheduling_info.recv_kv_src
+ recv_op = torch.distributed.P2POp(torch.distributed.irecv, recv_kv_src, recv_src, group=cp_group, tag=rank)
+ send_recv_ops.append(recv_op)
+ is_recv_kv = True
+ send_recv_ops_qkv = []
+ if len(send_recv_ops) > 0:
+ send_recv_ops_qkv = torch.distributed.batch_isend_irecv(send_recv_ops)
+ return is_recv_q, is_recv_kv, send_recv_ops_qkv
+
+
+def flash_attn_p2p_communicate_o(scheduling_info, send_o_dst, recv_o_src, cp_group, rank):
+ send_recv_ops = []
+ is_recv_o = False
+ for recv_src in scheduling_info.recv_o_src:
+ recv_op = torch.distributed.P2POp(torch.distributed.irecv, recv_o_src, recv_src, group=cp_group, tag=100000 + rank)
+ send_recv_ops.append(recv_op)
+ is_recv_o = True
+ if scheduling_info.send_o_dst > -1:
+ send_dst = scheduling_info.send_o_dst
+ send_op = torch.distributed.P2POp(torch.distributed.isend, send_o_dst, send_dst, group=cp_group, tag=100000 + send_dst)
+ send_recv_ops.append(send_op)
+ send_recv_ops_o = []
+ if len(send_recv_ops) > 0:
+ send_recv_ops_o = torch.distributed.batch_isend_irecv(send_recv_ops)
+ return is_recv_o, send_recv_ops_o
+
+
+class AdaptiveAttention(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, q, k, v, n, cp_para, softmax_scale=None, attn_mask=None, dropout_p=0.):
+ keep_prob = 1. - dropout_p
+ cp_size = cp_para.get("cp_size")
+ rank = cp_para.get("rank")
+ scheduling_info = cp_para.get('scheduling_info')
+ cp_group = cp_para.get('cp_group')
+
+ seq_len = q.shape[0]
+ batch_size = q.shape[1]
+ head_dim = q.shape[-1] // n
+
+ if softmax_scale is None:
+ softmax_scale = head_dim ** (-0.5)
+ send_kv_dst = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) # [2, s, b, h]
+ recv_q_src, recv_kv_src = None, None
+ send_recv_ops_qkv = []
+ is_recv_q, is_recv_kv = False, False
+ send_o_dst, recv_o_src = None, None
+ send_recv_ops_o = []
+ is_recv_o, is_send_o = False, False
+ attn_out, softmax_max, softmax_sum = None, None, None
+
+ round_num = len(scheduling_info)
+ for i in range(round_num + 1):
+ is_activate = is_recv_q or is_recv_kv # receive q or kv last round means calculate this round
+ is_send_o = is_recv_q # receive q last round means send o this round
+
+ # wait until QKV is received
+ if len(send_recv_ops_qkv) > 0:
+ for send_recv_op in send_recv_ops_qkv:
+ send_recv_op.wait()
+
+ # determine QKV for this round
+ cur_q = recv_q_src if is_recv_q else q
+ cur_k = recv_kv_src[0] if is_recv_kv else k
+ cur_v = recv_kv_src[1] if is_recv_kv else v
+
+ # send QKV for next round
+ if i < round_num - 1:
+ recv_q_src = torch.empty_like(q)
+ recv_kv_src = torch.empty_like(send_kv_dst)
+ is_recv_q, is_recv_kv, send_recv_ops_qkv = flash_attn_p2p_communicate(scheduling_info[i],
+ q, recv_q_src,
+ send_kv_dst, recv_kv_src,
+ cp_group, rank)
+
+ # calculate QKV for this round
+ if i == 0 or (i < round_num and is_activate):
+ this_mask = attn_mask[i] if isinstance(attn_mask, list) else None
+ attn_outs = torch_npu.npu_fusion_attention(
+ cur_q, cur_k, cur_v, n, "SBH",
+ pse=None,
+ padding_mask=None,
+ atten_mask=this_mask,
+ scale=softmax_scale,
+ pre_tockens=cur_k.shape[0],
+ next_tockens=cur_k.shape[0],
+ keep_prob=keep_prob,
+ sparse_mode=0
+ )
+ cur_attn_out, cur_softmax_max, cur_softmax_sum = attn_outs[0], attn_outs[1], attn_outs[2] # [s, b, h], [b, n, s, 8], [b, n, s, 8]
+ if not is_send_o:
+ if i == 0:
+ softmax_sum = cur_softmax_sum
+ softmax_max = cur_softmax_max
+ attn_out = cur_attn_out
+ else:
+ attn_out_updated, softmax_max_updated, softmax_sum_updated = npu_ring_attention_update(
+ attn_out, softmax_max, softmax_sum, cur_attn_out, cur_softmax_max, cur_softmax_sum)
+ attn_out, softmax_max, softmax_sum = attn_out_updated, softmax_max_updated, softmax_sum_updated
+
+ # wait until O is received
+ if len(send_recv_ops_o) > 0:
+ for send_recv_op in send_recv_ops_o:
+ send_recv_op.wait()
+
+ # update O if receive O
+ if is_recv_o:
+ recv_attn_out = recv_o_src[:, :, :, :head_dim].permute(2, 0, 1, 3) # [b, n, s, d] -> [s, b, n, d]
+ recv_attn_out = recv_attn_out.view(seq_len, batch_size, -1).to(attn_out.dtype) # [s, b, n, d] -> [s, b, h]
+ recv_softmax_max = recv_o_src[:, :, :, head_dim:head_dim + 8]
+ recv_softmax_sum = recv_o_src[:, :, :, head_dim + 8:]
+ attn_out_updated, softmax_max_updated, softmax_sum_updated = npu_ring_attention_update(
+ attn_out, softmax_max, softmax_sum, recv_attn_out, recv_softmax_max, recv_softmax_sum)
+ attn_out, softmax_max, softmax_sum = attn_out_updated, softmax_max_updated, softmax_sum_updated
+
+ # send O for next round
+ if i < round_num:
+ cur_attn_out = cur_attn_out.view(seq_len, batch_size, n, -1).permute(1, 2, 0, 3) # [s, b, h] -> [s, b, n, d]
+ send_o_dst = torch.cat((cur_attn_out, cur_softmax_max), dim=-1) # [s, b, n, d+8]
+ send_o_dst = torch.cat((send_o_dst, cur_softmax_sum), dim=-1) # [s, b, n, d+16]
+ recv_o_src = torch.empty_like(send_o_dst)
+ is_recv_o, send_recv_ops_o = flash_attn_p2p_communicate_o(scheduling_info[i], send_o_dst, recv_o_src, cp_group, rank)
+
+ k, v = send_kv_dst[0], send_kv_dst[1]
+ attn_mask = attn_mask if isinstance(attn_mask, list) else [attn_mask]
+ ctx.save_for_backward(q, k, v, *attn_mask, attn_out, softmax_max, softmax_sum)
+ ctx.n = n
+ ctx.softmax_scale = softmax_scale
+ ctx.cp_group = cp_group
+ ctx.cp_size = cp_size
+ ctx.cp_rank = rank
+ ctx.scheduling_info = scheduling_info
+ return attn_out
+
+ @staticmethod
+ def backward(ctx, dout):
+ q, k, v, *attn_mask, attn_out, softmax_max, softmax_sum = ctx.saved_tensors
+ softmax_max = softmax_max.contiguous()
+ softmax_sum = softmax_sum.contiguous()
+
+ n = ctx.n
+ softmax_scale = ctx.softmax_scale
+ cp_group = ctx.cp_group
+ cp_size = ctx.cp_size
+ rank = ctx.cp_rank
+ dist_attn_scheduler = ctx.scheduling_info
+
+ send_recv_reqs_input = []
+ send_recv_reqs_dq = []
+ send_recv_reqs_dkv = []
+ num_received_dq, num_received_dkv = 0, 0
+
+ # 把m和l的1/8进行all-gather
+ softmax_max_all = torch.empty((cp_size, *(softmax_max.shape[:-1])), device=softmax_max.device,
+ dtype=softmax_max.dtype)
+ softmax_sum_all = torch.empty((cp_size, *(softmax_sum.shape[:-1])), device=softmax_sum.device,
+ dtype=softmax_sum.dtype)
+ softmax_max_local = softmax_max[:, :, :, 0].contiguous() # [b, n, s, 8] -> [b, n, s, 1]
+ softmax_sum_local = softmax_sum[:, :, :, 0].contiguous() # [b, n, s, 8] -> [b, n, s, 1]
+ # [b, n, s] -> [8, b, n, s]
+ handle_softmax_max = torch.distributed._all_gather_base(softmax_max_all, softmax_max_local,
+ group=cp_group, async_op=True)
+ handle_softmax_sum = torch.distributed._all_gather_base(softmax_sum_all, softmax_sum_local,
+ group=cp_group, async_op=True)
+
+ # 组合需要发送的tensors
+ kv = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) # [2, s, b, h]
+ qodo = torch.cat((q.unsqueeze(0), attn_out.unsqueeze(0), dout.unsqueeze(0)), dim=0) # [3, s, b, h]
+
+ # 创建接收tensors的buffer
+ kv_recv = torch.empty((2, *kv.shape), device=kv.device, dtype=kv.dtype) # [2, 2, s, b, h]
+ qodo_recv = torch.empty((2, 3, *q.shape), device=q.device, dtype=q.dtype) # [2, 3, s, b, h]
+ dq_recv = torch.empty((2, *q.shape), device=q.device, dtype=q.dtype) # [2, s, b, h]
+ dkv_recv = torch.empty((2, 2, *k.shape), device=k.device, dtype=k.dtype) # [2, 2, s, b, h]
+
+ # 初始化第0轮的cur_q, cur_k, cur_v, cur_o, cur_do, cur_m, cur_l
+ cur_q, cur_k, cur_v = q, k, v
+ cur_o, cur_do = attn_out, dout
+ cur_m, cur_l = softmax_max, softmax_sum
+
+ dq, dk, dv = None, None, None
+
+ handle_softmax_max.wait()
+ handle_softmax_sum.wait()
+
+ # 循环遍历每一个round
+ round_cnt = len(dist_attn_scheduler)
+ for rnd_idx in range(round_cnt):
+ is_active = True
+ if len(send_recv_reqs_input) > 0:
+ idx = 0
+ for send_recv_op in send_recv_reqs_input:
+ send_recv_op.wait()
+ idx += 1
+
+ cur_recv_buf_idx = rnd_idx % 2
+ prev_recv_buf_idx = 1 - cur_recv_buf_idx
+
+ # 确定本轮的cur_q, cur_k, cur_v, cur_o, cur_do, cur_m, cur_l
+ if rnd_idx > 0:
+ prev_scheduling = dist_attn_scheduler[rnd_idx - 1]
+ if prev_scheduling.recv_q_src > -1: # 这一轮计算自己出KV
+ cur_q, cur_o, cur_do = (qodo_recv[prev_recv_buf_idx][0], qodo_recv[prev_recv_buf_idx][1],
+ qodo_recv[prev_recv_buf_idx][2])
+ cur_k, cur_v = k, v
+
+ idx = torch.distributed.get_group_rank(cp_group, prev_scheduling.recv_q_src)
+ cur_m = softmax_max_all[idx, :, :, :].view(softmax_max_all.shape[1:] +
+ (1,)).repeat(1, 1, 1, 8)
+ cur_l = softmax_sum_all[idx, :, :, :].view(softmax_max_all.shape[1:] +
+ (1,)).repeat(1, 1, 1, 8)
+ elif prev_scheduling.recv_kv_src > -1: # 这一轮计算自己出Q
+ cur_q, cur_o, cur_do = q, attn_out, dout
+ cur_k, cur_v = kv_recv[prev_recv_buf_idx][0], kv_recv[prev_recv_buf_idx][1]
+ cur_m, cur_l = softmax_max, softmax_sum
+ else:
+ is_active = False
+
+ # 把本轮的input通信加入input通信队列(需要通信得到下一轮执行所需的q+o+do/k+v、发送下一轮别的device需要的q+o+do/k+v)
+ send_recv_ops_input, send_recv_reqs_input = [], []
+ cur_scheduling = dist_attn_scheduler[rnd_idx] # 本轮计算过程中需要并行执行的通信调度
+
+ if cur_scheduling.recv_q_src > -1:
+ # recv q + attn_out + dout from cur_scheduling.recv_q_src
+ recv_op = torch.distributed.P2POp(torch.distributed.irecv, qodo_recv[cur_recv_buf_idx],
+ cur_scheduling.recv_q_src, cp_group, tag=rank)
+ send_recv_ops_input.append(recv_op)
+ elif cur_scheduling.recv_kv_src > -1:
+ # recv kv from cur_scheduling.recv_kv_src
+ recv_op = torch.distributed.P2POp(torch.distributed.irecv, kv_recv[cur_recv_buf_idx],
+ cur_scheduling.recv_kv_src, cp_group, tag=rank)
+ send_recv_ops_input.append(recv_op)
+
+ if len(cur_scheduling.send_q_dst) > 0:
+ for send_q_dev in cur_scheduling.send_q_dst:
+ # send q + attn_out + dout to send_q_dev
+ send_op = torch.distributed.P2POp(torch.distributed.isend, qodo, send_q_dev, cp_group,
+ tag=send_q_dev)
+ send_recv_ops_input.append(send_op)
+ if len(cur_scheduling.send_kv_dst) > 0:
+ for send_kv_dev in cur_scheduling.send_kv_dst:
+ # send kv to send_kv_dev
+ send_op = torch.distributed.P2POp(torch.distributed.isend, kv, send_kv_dev, cp_group,
+ tag=send_kv_dev)
+ send_recv_ops_input.append(send_op)
+
+ # 发起本轮的input通信
+ if len(send_recv_ops_input) > 0:
+ send_recv_reqs_input = torch.distributed.batch_isend_irecv(send_recv_ops_input)
+
+ # 仍然按照前向的调度顺序来进行反向的计算,需要q k v do_q m_q l_q
+ if is_active:
+ this_mask = attn_mask[rnd_idx] if attn_mask is not None else None
+ attn_grad_outs = torch_npu.npu_fusion_attention_grad(
+ cur_q, cur_k, cur_v, cur_do, n,
+ "SBH",
+ pse=None,
+ padding_mask=None,
+ atten_mask=this_mask,
+ softmax_max=cur_m,
+ softmax_sum=cur_l,
+ attention_in=cur_o,
+ scale_value=softmax_scale,
+ sparse_mode=0,
+ keep_prob=1.,
+ )
+ cur_dq, cur_dk, cur_dv = attn_grad_outs[0], attn_grad_outs[1], attn_grad_outs[2]
+ else:
+ cur_dq, cur_dk, cur_dv = None, None, None
+
+ if rnd_idx == 0:
+ dq = cur_dq
+ dk = cur_dk
+ dv = cur_dv
+ else:
+ # 等待output send-recv结束,并用收到的dq/dkdv来更新结果
+ if num_received_dq > 0:
+ for send_recv_op in send_recv_reqs_dq:
+ send_recv_op.wait()
+ for i in range(num_received_dq):
+ dq.add_(dq_recv[i])
+
+ if num_received_dkv > 0:
+ for send_recv_op in send_recv_reqs_dkv:
+ send_recv_op.wait()
+ for i in range(num_received_dkv):
+ dk.add_(dkv_recv[i][0])
+ dv.add_(dkv_recv[i][1])
+ # 用cur_dq, cur_dk, cur_dv更新结果:检查当前轮的计算是否是帮别人算的,如果是/不是,则加上cur_dk, cur_dv/cur_dq
+ send_recv_reqs_dq, send_recv_reqs_dkv = [], []
+ send_recv_ops_dq, send_recv_ops_dkv = [], []
+ num_received_dq, num_received_dkv = 0, 0
+ prev_scheduling = dist_attn_scheduler[rnd_idx - 1]
+ if is_active:
+ if prev_scheduling.recv_q_src > -1: # 这一轮计算自己出KV,是帮别人算
+ dk.add_(cur_dk)
+ dv.add_(cur_dv)
+ send_dq = cur_dq
+ send_op = torch.distributed.P2POp(torch.distributed.isend, send_dq, prev_scheduling.recv_q_src,
+ cp_group, tag=rank * 10)
+ send_recv_ops_dq.append(send_op)
+ elif prev_scheduling.recv_kv_src > -1: # 这一轮计算自己出Q
+ dq.add_(cur_dq)
+ send_dkv = torch.cat((cur_dk.unsqueeze(0), cur_dv.unsqueeze(0)), dim=0) # [2, s, b, h]
+ send_op = torch.distributed.P2POp(torch.distributed.isend, send_dkv,
+ prev_scheduling.recv_kv_src, cp_group, tag=rank * 10)
+ send_recv_ops_dkv.append(send_op)
+
+ # 发起output recv:上一轮发Q的目的device,从那儿收dQ;上一轮发KV的目的device,从那儿收dKdV;
+ if len(prev_scheduling.send_q_dst) > 0:
+ for idx, send_q_dst in enumerate(prev_scheduling.send_q_dst):
+ recv_op = torch.distributed.P2POp(torch.distributed.irecv, dq_recv[idx],
+ send_q_dst, cp_group, tag=send_q_dst * 10)
+ send_recv_ops_dq.append(recv_op)
+ num_received_dq += 1
+ if len(prev_scheduling.send_kv_dst) > 0:
+ for idx, send_kv_dst in enumerate(prev_scheduling.send_kv_dst):
+ recv_op = torch.distributed.P2POp(torch.distributed.irecv, dkv_recv[idx],
+ send_kv_dst, cp_group, tag=send_kv_dst * 10)
+ send_recv_ops_dkv.append(recv_op)
+ num_received_dkv += 1
+
+ if len(send_recv_ops_dq) > 0:
+ send_recv_reqs_dq = torch.distributed.batch_isend_irecv(send_recv_ops_dq)
+ if len(send_recv_ops_dkv) > 0:
+ send_recv_reqs_dkv = torch.distributed.batch_isend_irecv(send_recv_ops_dkv)
+
+ if num_received_dq > 0:
+ for send_recv_op in send_recv_reqs_dq:
+ send_recv_op.wait()
+ for i in range(num_received_dq):
+ dq.add_(dq_recv[i])
+
+ if num_received_dkv > 0:
+ for send_recv_op in send_recv_reqs_dkv:
+ send_recv_op.wait()
+ for i in range(num_received_dkv):
+ dk.add_(dkv_recv[i][0])
+ dv.add_(dkv_recv[i][1])
+
+ return dq, dk, dv, None, None, None, None, None
+
+
+def adaptive_attn_context_parallel(q, k, v, n, cp_para, softmax_scale=None, attn_mask=None, dropout_p=0.):
+ out = AdaptiveAttention.apply(
+ q, k, v, n, cp_para, softmax_scale, attn_mask, dropout_p
+ )
+ return out
diff --git a/model/train/yoco_moe/mindspeed/core/context_parallel/context_parallel_kv_cache.py b/model/train/yoco_moe/mindspeed/core/context_parallel/context_parallel_kv_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd474fa6c81282cc0cc46d3207538b35cc41dd60
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/context_parallel/context_parallel_kv_cache.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+
+
+def get_cache_policy(layer_number, cache_policy_init, cache_interval):
+ cache_policy = cache_policy_init
+ if cache_interval != 0:
+ if layer_number % (cache_interval + 1) == 1:
+ cache_policy = cache_policy_init
+ else:
+ cache_policy = None
+
+ return cache_policy
+
+
+class ContextParallelKVCache:
+ """Context Parallelism KV Cache Implementation"""
+
+ def __init__(self, cache_policy, outer_data, inner_data, k, v) -> None:
+ self.outer_size, self.outer_ring_p2p = outer_data
+ self.inner_size, self.inner_ring_p2p = inner_data
+ self.cache_policy = cache_policy
+ self.k = k
+ self.v = v
+ self.cp_size = self.outer_size * self.inner_size
+ self.outer_index = 0
+
+ send_data = torch.zeros((2, *self.k[-1].shape), dtype=self.k[-1].dtype, device=self.k[-1].device)
+ send_data.copy_(torch.cat((self.k[-1].unsqueeze(0), self.v[-1].unsqueeze(0)), dim=0))
+ outer_recv_data = send_data.clone()
+ inner_recv_data = send_data.clone()
+ self.cur_kv, self.outer_next_kv, self.inner_next_kv = send_data, outer_recv_data, inner_recv_data
+
+ self.k_out, self.v_out = None, None
+
+ def communicate_outer_ring_kv(self, index) -> None:
+ """
+ Implements of kv communications in outer ring
+
+ Args:
+ index (int): the index of outer for loop
+ """
+ self.outer_index = index
+
+ # index > 0, using kv after communication
+ if index > 0:
+ if index == 1 and self.cache_policy == "half":
+ # special case: index=1, cache_policy=half, KV block should be transformed to K
+ self.outer_ring_p2p.wait()
+ if self.inner_size > 1:
+ # KV have been transformed in inner ring
+ self.cur_kv.copy_(self.outer_next_kv[1])
+ self.outer_next_kv = self.outer_next_kv[1].clone()
+ else:
+ # KV is not transformed in inner ring
+ self.cur_kv, self.outer_next_kv = self.outer_next_kv, self.cur_kv
+ self.k_out, self.v_out = self.cur_kv[0].clone(), self.cur_kv[1].clone()
+ self.cur_kv = self.cur_kv[1].clone()
+ self.outer_next_kv = self.outer_next_kv[1].clone()
+ else:
+ self.outer_ring_p2p.wait()
+ self.cur_kv, self.outer_next_kv = self.outer_next_kv, self.cur_kv
+
+ # last step, no need to communicate KV
+ is_last_step = index + 1 == self.outer_size
+
+ # only need communicate KV in the first step when full cache
+ first_step_with_full_cache = self.cache_policy == "full" and index > 0
+
+ if not first_step_with_full_cache and not is_last_step:
+ self.outer_ring_p2p.async_send_recv(send_tensor=self.cur_kv, recv_tensor=self.outer_next_kv)
+
+ def communicate_inner_ring_kv(self, index):
+ """
+ Implements of kv communications in inner ring
+
+ Args:
+ index (int): the index of inner for loop
+
+ Returns:
+ cur_k (torch.tensor): k(keys), backward operator input in this iteration
+ cur_v (torch.tensor): v(values), backward operator input in this iteration
+ """
+ total_index = self.outer_index * self.inner_size + index
+
+ # index > 0, using kv after communication
+ if index > 0:
+ if total_index == 1 and self.cache_policy == "half":
+ # special case: index=1, cache_policy=half, KV block should be transformed to K
+ self.inner_ring_p2p.wait()
+ self.cur_kv, self.inner_next_kv = self.inner_next_kv, self.cur_kv
+ self.k_out, self.v_out = self.cur_kv[0].clone(), self.cur_kv[1].clone()
+ self.cur_kv = self.cur_kv[1].clone()
+ self.inner_next_kv = self.inner_next_kv[1].clone()
+ else:
+ self.inner_ring_p2p.wait()
+ self.cur_kv, self.inner_next_kv = self.inner_next_kv, self.cur_kv
+
+ # last step, no need to communicate KV
+ is_last_step = index + 1 == self.inner_size
+
+ # only need communicate KV in the first step when full cache
+ first_step_with_full_cache = self.cache_policy == "full" and total_index > 0
+
+ if not first_step_with_full_cache and not is_last_step:
+ self.inner_ring_p2p.async_send_recv(send_tensor=self.cur_kv, recv_tensor=self.inner_next_kv)
+
+ cache_index = self.cp_size - total_index - 1
+ if self.cache_policy is None:
+ self.k_out, self.v_out = self.cur_kv[0], self.cur_kv[1]
+
+ elif self.cache_policy == "half":
+ if total_index == 0:
+ self.k_out, self.v_out = self.cur_kv[0], self.cur_kv[1]
+ elif total_index > 1:
+ self.k_out, self.v_out = self.k[cache_index], self.cur_kv
+
+ elif self.cache_policy == "full":
+ if total_index <= 1:
+ self.k_out, self.v_out = self.cur_kv[0], self.cur_kv[1]
+ else:
+ self.k_out, self.v_out = self.k[cache_index], self.v[cache_index]
+
+ return self.k_out, self.v_out
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/context_parallel/ring_context_parallel.py b/model/train/yoco_moe/mindspeed/core/context_parallel/ring_context_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..910a72eaf5b726107f646e2fddad420d7804f405
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/context_parallel/ring_context_parallel.py
@@ -0,0 +1,939 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+import torch_npu
+from einops import rearrange
+from mindspeed.ops.fusion_attention_v2 import npu_fusion_attention, npu_fusion_attention_grad
+from mindspeed.core.context_parallel.context_parallel_kv_cache import ContextParallelKVCache
+from .utils import RingP2P, tnd_out_update, causal_out_update, general_out_update, forward_update, sbh_to_tnd, tnd_to_sbh, unflatten_softmax, flatten_softmax, get_selection_indices_for_tnd_softmax_update
+
+
+def causal_forward_fetch(q_block_id, kv_block_id, q, cur_k, cur_v, attn_mask=None):
+ cur_attn_mask = None
+ if q_block_id == kv_block_id:
+ # [2, s, b, h] -> [2s, b, h]
+ cur_attn_mask = attn_mask
+ cur_q, cur_k, cur_v = [x.view(-1, *x.shape[2:]) for x in [q, cur_k, cur_v]]
+ elif kv_block_id <= q_block_id:
+ # [2, s, b, h] -> [2s, b, h]
+ cur_q = q.view(-1, *q.shape[2:])
+ # only k[0] v[0] need to be calculated
+ cur_k, cur_v = [x[0] for x in [cur_k, cur_v]]
+ else:
+ # only q[1] need to be calculated
+ cur_q = q[1]
+ # [2, s, b, h] -> [2s, b, h]
+ cur_k, cur_v = [x.view(-1, *x.shape[2:]) for x in [cur_k, cur_v]]
+
+ return cur_q, cur_k, cur_v, cur_attn_mask
+
+
+def tnd_forward_fetch(q_block_id, kv_block_id, q, cur_k, cur_v, fetch_ptrs, attn_mask=None):
+ seqlen, half_seqlen, q_index, kv_index = fetch_ptrs
+ actual_seq_qlen, actual_seq_kvlen, sub_out_seq_len = seqlen
+ half_actual_seq_qlen, half_actual_seq_kvlen, half_sub_out_seq_len = half_seqlen
+
+ cur_attn_mask = None
+ if q_block_id == kv_block_id:
+ cur_attn_mask = attn_mask
+ cur_q = q
+ cur_seq_qlen, cur_seq_kvlen = actual_seq_qlen, actual_seq_kvlen
+ cur_sub_out_seq_len = sub_out_seq_len
+ elif kv_block_id <= q_block_id:
+ cur_q = q
+ cur_k, cur_v = [torch.index_select(x, 0, kv_index) for x in [cur_k, cur_v]]
+ cur_seq_qlen, cur_seq_kvlen = actual_seq_qlen, half_actual_seq_kvlen
+ cur_sub_out_seq_len = sub_out_seq_len
+ else:
+ cur_q = torch.index_select(q, 0, q_index)
+ cur_seq_qlen, cur_seq_kvlen = half_actual_seq_qlen, actual_seq_kvlen
+ cur_sub_out_seq_len = half_sub_out_seq_len
+
+ return cur_q, cur_k, cur_v, cur_attn_mask, (cur_seq_qlen, cur_seq_kvlen, cur_sub_out_seq_len)
+
+
+def tnd_backward_fetch(q_block_id, kv_block_id, q, cur_k, cur_v, attn_out, dout,
+ softmax_values, seq_lens, index_values, attn_mask=None):
+ # fetch backward output
+ actual_seq_qlen, actual_seq_kvlen, half_actual_seq_kvlen, half_actual_seq_qlen = seq_lens
+ softmax_max, softmax_sum, half_softmax_max, half_softmax_sum = softmax_values
+ q_index, kv_index = index_values
+ cur_attn_mask = None
+ if q_block_id >= kv_block_id:
+ if q_block_id == kv_block_id:
+ cur_attn_mask = attn_mask
+ cur_seq_qlen, cur_seq_kvlen = actual_seq_qlen, actual_seq_kvlen
+ else:
+ cur_k, cur_v = [torch.index_select(x, 0, kv_index) for x in [cur_k, cur_v]]
+ cur_seq_qlen, cur_seq_kvlen = actual_seq_qlen, half_actual_seq_kvlen
+
+ cur_q, cur_attn_out, cur_dout = q, attn_out, dout
+ cur_softmax_max, cur_softmax_sum = softmax_max, softmax_sum
+ else:
+ cur_q, cur_attn_out, cur_dout = [torch.index_select(x, 0, q_index) for x in [q, attn_out, dout]]
+ cur_softmax_max, cur_softmax_sum = half_softmax_max, half_softmax_sum
+ cur_seq_qlen, cur_seq_kvlen = half_actual_seq_qlen, actual_seq_kvlen
+
+ return (cur_q, cur_k, cur_v), cur_attn_out, cur_dout, (cur_softmax_max, cur_softmax_sum), cur_attn_mask, (cur_seq_qlen, cur_seq_kvlen)
+
+
+def causal_backward_fetch(q_block_id, kv_block_id, q, cur_k, cur_v, attn_out, dout,
+ softmax_max, softmax_sum, attn_mask=None):
+ cur_attn_mask = None
+ if q_block_id >= kv_block_id:
+ # [b, n, 2, s, 8] -> [b, n, 2s, 8]
+ cur_softmax_max = softmax_max.view(softmax_max.shape[0], softmax_max.shape[1], -1,
+ softmax_max.shape[-1])
+ cur_softmax_sum = softmax_sum.view(softmax_sum.shape[0], softmax_sum.shape[1], -1,
+ softmax_sum.shape[-1])
+ # [2, s, b, h] -> [2s, b, h]
+ cur_q, cur_attn_out, cur_dout = [x.view(-1, *x.shape[2:]) for x in [q, attn_out, dout]]
+ if q_block_id == kv_block_id:
+ cur_attn_mask = attn_mask
+ # [2, s, b, h] -> [2s, b, h]
+ cur_k, cur_v, = [x.view(-1, *x.shape[2:]) for x in [cur_k, cur_v]]
+ else:
+ cur_k, cur_v = [x[0] for x in [cur_k, cur_v]]
+ else:
+ # [2, s, b, h] -> [2s, b, h]
+ cur_k, cur_v = [x.view(-1, *x.shape[2:]) for x in [cur_k, cur_v]]
+ # only q[1] attn_out[1] and dout[1] need to be calculated
+ cur_q, cur_attn_out, cur_dout = [x[1] for x in [q, attn_out, dout]]
+ cur_softmax_max, cur_softmax_sum = [x[:, :, 1, :, :] for x in [softmax_max, softmax_sum]]
+
+ return cur_q, cur_k, cur_v, cur_attn_out, cur_dout, cur_softmax_max, cur_softmax_sum, cur_attn_mask
+
+
+def tnd_grad_update(q_block_id, kv_block_id, cur_attn_grads, global_attn_grads,
+ q_index, kv_index):
+ cur_dq, cur_dk, cur_dv = cur_attn_grads
+ dq, dk, dv = global_attn_grads
+ if q_block_id == kv_block_id:
+ dq.add_(cur_dq)
+ dk.add_(cur_dk)
+ dv.add_(cur_dv)
+ elif q_block_id > kv_block_id:
+ dq.add_(cur_dq)
+ dk.index_add_(0, kv_index, cur_dk)
+ dv.index_add_(0, kv_index, cur_dv)
+ else:
+ dq.index_add_(0, q_index, cur_dq)
+ dk.add_(cur_dk)
+ dv.add_(cur_dv)
+
+ return dq, dk, dv
+
+
+def causal_grad_update(q_block_id, kv_block_id, cur_dq, cur_dk, cur_dv, dq, dk, dv):
+ if q_block_id == kv_block_id:
+ cur_dq = cur_dq.view(dq.shape)
+ cur_dk = cur_dk.view(dk.shape)
+ cur_dv = cur_dv.view(dv.shape)
+ dq.add_(cur_dq)
+ dk.add_(cur_dk)
+ dv.add_(cur_dv)
+ elif q_block_id > kv_block_id:
+ cur_dq = cur_dq.view(dq.shape)
+ dq.add_(cur_dq)
+ dk[0].add_(cur_dk)
+ dv[0].add_(cur_dv)
+ else:
+ dq[1].add_(cur_dq)
+ cur_dk = cur_dk.view(dk.shape) # [2s, b, h] -> [2, s, b, h]
+ cur_dv = cur_dv.view(dv.shape)
+ dk.add_(cur_dk)
+ dv.add_(cur_dv)
+
+ return dq, dk, dv
+
+
+def cal_row(cur_q, cur_k, cur_v, s, attn_info):
+ # q: [s, b, h], kv: [2s, b, h]
+ n, pse, pse_type, attn_mask, softmax_scale, keep_prob, \
+ q_index_list, kv_index_list = attn_info
+
+ # r1c0
+ cur_attn_mask = None
+ attn_outs_r1c0 = npu_fusion_attention(
+ cur_q, cur_k[:s], cur_v[:s], n, 'SBH',
+ pse=pse,
+ pse_type=pse_type,
+ padding_mask=None,
+ atten_mask=cur_attn_mask,
+ scale=softmax_scale,
+ pre_tokens=s,
+ next_tokens=0 if cur_attn_mask is not None else s,
+ keep_prob=keep_prob,
+ sparse_mode=3 if cur_attn_mask is not None else 0,
+ q_start_idx=[q_index_list[1] * s, ] if q_index_list is not None else q_index_list,
+ kv_start_idx=[kv_index_list[0] * s, ] if kv_index_list is not None else kv_index_list
+ )
+ # r1c1
+ cur_attn_mask = attn_mask
+ attn_outs_r1c1 = npu_fusion_attention(
+ cur_q, cur_k[s:], cur_v[s:], n, 'SBH',
+ pse=pse,
+ pse_type=pse_type,
+ padding_mask=None,
+ atten_mask=cur_attn_mask,
+ scale=softmax_scale,
+ pre_tokens=s,
+ next_tokens=0 if cur_attn_mask is not None else s,
+ keep_prob=keep_prob,
+ sparse_mode=3 if cur_attn_mask is not None else 0,
+ q_start_idx=[q_index_list[1] * s, ] if q_index_list is not None else q_index_list,
+ kv_start_idx=[kv_index_list[1] * s, ] if kv_index_list is not None else kv_index_list
+ )
+
+ # update row1
+ attn_out = attn_outs_r1c0[0]
+ softmax_max = attn_outs_r1c0[1]
+ softmax_sum = attn_outs_r1c0[2]
+ curr_attn_out = attn_outs_r1c1[0]
+ curr_softmax_max = attn_outs_r1c1[1]
+ curr_softmax_sum = attn_outs_r1c1[2]
+ attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(attn_out, softmax_max, softmax_sum,
+ curr_attn_out, curr_softmax_max,
+ curr_softmax_sum)
+ return [attn_out_updated, softmax_max_updated, softmax_sum_updated]
+
+
+def flash_attention_with_alibi_pse(q_block_id, kv_block_id, cur_qkv, attn_info, s):
+ n, pse, pse_type, cur_attn_mask, softmax_scale, keep_prob, \
+ q_index_list, kv_index_list = attn_info
+ cur_q, cur_k, cur_v = cur_qkv
+ if q_block_id == kv_block_id:
+ attn_outs_r0c0 = npu_fusion_attention(
+ cur_q[:s], cur_k[:s], cur_v[:s], n, 'SBH',
+ pse=pse,
+ pse_type=pse_type,
+ padding_mask=None,
+ atten_mask=cur_attn_mask,
+ scale=softmax_scale,
+ pre_tokens=s,
+ next_tokens=0 if cur_attn_mask is not None else s,
+ keep_prob=keep_prob,
+ sparse_mode=3 if cur_attn_mask is not None else 0,
+ q_start_idx=[q_index_list[0] * s, ] if q_index_list is not None else None,
+ kv_start_idx=[kv_index_list[0] * s, ] if kv_index_list is not None else None,
+ )
+ attn_outs_r1 = cal_row(cur_q[s:], cur_k, cur_v, s, attn_info)
+ # get output
+ attn_outs = []
+ attn_outs.append(torch.cat([attn_outs_r0c0[0], attn_outs_r1[0]]))
+ attn_outs.append(torch.cat([attn_outs_r0c0[1], attn_outs_r1[1]], dim=2))
+ attn_outs.append(torch.cat([attn_outs_r0c0[2], attn_outs_r1[2]], dim=2))
+ elif q_block_id > kv_block_id:
+ attn_outs_r0c0 = npu_fusion_attention(
+ cur_q[:s], cur_k, cur_v, n, 'SBH',
+ pse=pse,
+ pse_type=pse_type,
+ padding_mask=None,
+ atten_mask=cur_attn_mask,
+ scale=softmax_scale,
+ pre_tokens=s,
+ next_tokens=0 if cur_attn_mask is not None else s,
+ keep_prob=keep_prob,
+ sparse_mode=3 if cur_attn_mask is not None else 0,
+ q_start_idx=[q_index_list[0] * s, ] if q_index_list is not None else None,
+ kv_start_idx=[kv_index_list[0] * s, ] if kv_index_list is not None else None,
+ )
+ attn_outs_r1c0 = npu_fusion_attention(
+ cur_q[s:], cur_k, cur_v, n, 'SBH',
+ pse=pse,
+ pse_type=pse_type,
+ padding_mask=None,
+ atten_mask=cur_attn_mask,
+ scale=softmax_scale,
+ pre_tokens=s,
+ next_tokens=0 if cur_attn_mask is not None else s,
+ keep_prob=keep_prob,
+ sparse_mode=3 if cur_attn_mask is not None else 0,
+ q_start_idx=[q_index_list[1] * s, ] if q_index_list is not None else None,
+ kv_start_idx=[kv_index_list[0] * s, ] if kv_index_list is not None else None,
+ )
+ # get output
+ attn_outs = []
+ attn_outs.append(torch.cat([attn_outs_r0c0[0], attn_outs_r1c0[0]]))
+ attn_outs.append(torch.cat([attn_outs_r0c0[1], attn_outs_r1c0[1]], dim=2))
+ attn_outs.append(torch.cat([attn_outs_r0c0[2], attn_outs_r1c0[2]], dim=2))
+ else:
+ attn_outs = cal_row(cur_q, cur_k, cur_v, s, attn_info)
+
+ return attn_outs
+
+
+def cal_row_grad(cur_q, cur_k, cur_v, cur_dout, cur_softmax_max, cur_softmax_sum, cur_attn_out,
+ attn_grad_info, s, kv_block_id):
+ n, pse, pse_type, attn_mask, softmax_scale, keep_prob, rng_states, \
+ q_index_list, kv_index_list = attn_grad_info
+
+ cur_attn_mask = None
+ attn_grad_outs_r1c0 = npu_fusion_attention_grad(
+ cur_q, cur_k[:s], cur_v[:s], cur_dout, n, 'SBH',
+ pse=pse,
+ pse_type=pse_type,
+ padding_mask=None,
+ softmax_max=cur_softmax_max,
+ softmax_sum=cur_softmax_sum,
+ attention_in=cur_attn_out,
+ atten_mask=cur_attn_mask,
+ scale=softmax_scale,
+ pre_tokens=s,
+ next_tokens=0 if cur_attn_mask is not None else s,
+ keep_prob=keep_prob,
+ seed=rng_states[kv_block_id][0],
+ offset=rng_states[kv_block_id][1],
+ numels=rng_states[kv_block_id][2],
+ sparse_mode=3 if cur_attn_mask is not None else 0,
+ q_start_idx=[q_index_list[1] * s, ] if q_index_list is not None else q_index_list,
+ kv_start_idx=[kv_index_list[0] * s, ] if kv_index_list is not None else kv_index_list
+ )
+
+ cur_attn_mask = attn_mask
+ attn_grad_outs_r1c1 = npu_fusion_attention_grad(
+ cur_q, cur_k[s:], cur_v[s:], cur_dout, n, 'SBH',
+ pse=pse,
+ pse_type=pse_type,
+ padding_mask=None,
+ softmax_max=cur_softmax_max,
+ softmax_sum=cur_softmax_sum,
+ attention_in=cur_attn_out,
+ atten_mask=cur_attn_mask,
+ scale=softmax_scale,
+ pre_tokens=s,
+ next_tokens=0 if cur_attn_mask is not None else s,
+ keep_prob=keep_prob,
+ seed=rng_states[kv_block_id][0],
+ offset=rng_states[kv_block_id][1],
+ numels=rng_states[kv_block_id][2],
+ sparse_mode=3 if cur_attn_mask is not None else 0,
+ q_start_idx=[q_index_list[1] * s, ] if q_index_list is not None else q_index_list,
+ kv_start_idx=[kv_index_list[1] * s, ] if kv_index_list is not None else kv_index_list
+ )
+
+ return attn_grad_outs_r1c0, attn_grad_outs_r1c1
+
+
+def flash_attention_with_alibi_pse_grad(q_block_id, kv_block_id, cur_qkv, cur_dout, cur_attn_out,
+ cur_softmax_max, cur_softmax_sum, attn_grad_info, s):
+ n, pse, pse_type, cur_attn_mask, softmax_scale, keep_prob, rng_states, \
+ q_index_list, kv_index_list = attn_grad_info
+ cur_q, cur_k, cur_v = cur_qkv
+
+ if q_block_id == kv_block_id:
+ attn_grad_outs_r0c0 = npu_fusion_attention_grad(
+ cur_q[:s], cur_k[:s], cur_v[:s], cur_dout[:s], n, 'SBH',
+ pse=pse,
+ pse_type=pse_type,
+ padding_mask=None,
+ softmax_max=cur_softmax_max[:, :, :s],
+ softmax_sum=cur_softmax_sum[:, :, :s],
+ attention_in=cur_attn_out[:s],
+ atten_mask=cur_attn_mask,
+ scale=softmax_scale,
+ pre_tokens=s,
+ next_tokens=0 if cur_attn_mask is not None else s,
+ keep_prob=keep_prob,
+ seed=rng_states[kv_block_id][0],
+ offset=rng_states[kv_block_id][1],
+ numels=rng_states[kv_block_id][2],
+ sparse_mode=3 if cur_attn_mask is not None else 0,
+ q_start_idx=[q_index_list[0] * s, ] if q_index_list is not None else q_index_list,
+ kv_start_idx=[kv_index_list[0] * s, ] if kv_index_list is not None else kv_index_list
+ )
+ attn_grad_outs_r1c0, attn_grad_outs_r1c1 = cal_row_grad(
+ cur_q[s:], cur_k, cur_v, cur_dout[s:], cur_softmax_max[:, :, s:], cur_softmax_sum[:, :, s:],
+ cur_attn_out[s:], attn_grad_info, s, kv_block_id
+ )
+ attn_grad_outs = []
+ attn_grad_outs.append(torch.cat(
+ [attn_grad_outs_r0c0[0], attn_grad_outs_r1c0[0] + attn_grad_outs_r1c1[0]]))
+ attn_grad_outs.append(torch.cat(
+ [attn_grad_outs_r0c0[1] + attn_grad_outs_r1c0[1], attn_grad_outs_r1c1[1]]))
+ attn_grad_outs.append(torch.cat(
+ [attn_grad_outs_r0c0[2] + attn_grad_outs_r1c0[2], attn_grad_outs_r1c1[2]]))
+
+ elif q_block_id > kv_block_id:
+ attn_grad_outs_r0c0 = npu_fusion_attention_grad(
+ cur_q[:s], cur_k, cur_v, cur_dout[:s], n, 'SBH',
+ pse=pse,
+ pse_type=pse_type,
+ padding_mask=None,
+ softmax_max=cur_softmax_max[:, :, :s],
+ softmax_sum=cur_softmax_sum[:, :, :s],
+ attention_in=cur_attn_out[:s],
+ atten_mask=cur_attn_mask,
+ scale=softmax_scale,
+ pre_tokens=s,
+ next_tokens=0 if cur_attn_mask is not None else s,
+ keep_prob=keep_prob,
+ seed=rng_states[kv_block_id][0],
+ offset=rng_states[kv_block_id][1],
+ numels=rng_states[kv_block_id][2],
+ sparse_mode=3 if cur_attn_mask is not None else 0,
+ q_start_idx=[q_index_list[0] * s, ] if q_index_list is not None else q_index_list,
+ kv_start_idx=[kv_index_list[0] * s, ] if kv_index_list is not None else kv_index_list
+ )
+ attn_grad_outs_r1c0 = npu_fusion_attention_grad(
+ cur_q[s:], cur_k, cur_v, cur_dout[s:], n, 'SBH',
+ pse=pse,
+ pse_type=pse_type,
+ padding_mask=None,
+ softmax_max=cur_softmax_max[:, :, s:],
+ softmax_sum=cur_softmax_sum[:, :, s:],
+ attention_in=cur_attn_out[s:],
+ atten_mask=cur_attn_mask,
+ scale=softmax_scale,
+ pre_tokens=s,
+ next_tokens=0 if cur_attn_mask is not None else s,
+ keep_prob=keep_prob,
+ seed=rng_states[kv_block_id][0],
+ offset=rng_states[kv_block_id][1],
+ numels=rng_states[kv_block_id][2],
+ sparse_mode=3 if cur_attn_mask is not None else 0,
+ q_start_idx=[q_index_list[1] * s, ] if q_index_list is not None else q_index_list,
+ kv_start_idx=[kv_index_list[0] * s, ] if kv_index_list is not None else kv_index_list
+ )
+ attn_grad_outs = []
+ attn_grad_outs.append(torch.cat([attn_grad_outs_r0c0[0], attn_grad_outs_r1c0[0]]))
+ attn_grad_outs.append(attn_grad_outs_r0c0[1] + attn_grad_outs_r1c0[1])
+ attn_grad_outs.append(attn_grad_outs_r0c0[2] + attn_grad_outs_r1c0[2])
+
+ else:
+ attn_grad_outs_r1c0, attn_grad_outs_r1c1 = cal_row_grad(
+ cur_q, cur_k, cur_v, cur_dout, cur_softmax_max, cur_softmax_sum, cur_attn_out,
+ attn_grad_info, s, kv_block_id
+ )
+ attn_grad_outs = []
+ attn_grad_outs.append(attn_grad_outs_r1c0[0] + attn_grad_outs_r1c1[0])
+ attn_grad_outs.append(torch.cat([attn_grad_outs_r1c0[1], attn_grad_outs_r1c1[1]]))
+ attn_grad_outs.append(torch.cat([attn_grad_outs_r1c0[2], attn_grad_outs_r1c1[2]]))
+
+
+ return attn_grad_outs
+
+
+
+
+class AttentionWithCp(torch.autograd.Function):
+ """Attention implementation with context parallelism"""
+
+
+ @staticmethod
+ def forward(ctx, q, k, v, n, cp_para, softmax_scale=None, attn_mask=None, dropout_p=0.,
+ packed_seq_params=None):
+ keep_prob = 1. - dropout_p
+ causal = cp_para['causal']
+ cp_group = cp_para.get("cp_group")
+ cp_size = cp_para.get("cp_size")
+ rank = cp_para.get("rank")
+ cp_global_ranks = cp_para.get("cp_global_ranks")
+ cp_group_for_send_recv_overlap = cp_para.get("cp_group_for_send_recv_overlap")
+ # WARNING: Degrade to original ring attention, if ranks and comm groups for double ring are not provided
+ cp_inner_ranks = cp_para.get("cp_inner_ranks", [torch.distributed.get_rank()])
+ cp_outer_ranks = cp_para.get("cp_outer_ranks", cp_global_ranks)
+ cp_group_for_intra_window = cp_para.get('cp_group_for_intra_window')
+ cp_group_for_intra_window_send_recv_overlap = cp_para.get('cp_group_for_intra_window_send_recv_overlap')
+ megatron_cp_in_bnsd = cp_para.get('megatron_cp_in_bnsd')
+
+ pse = cp_para.get("pse")
+ pse_type = cp_para.get("pse_type")
+
+ cache_policy = cp_para.get("cache_policy")
+
+ inner_ring = RingP2P(cp_inner_ranks, cp_group_for_intra_window, cp_group_for_intra_window_send_recv_overlap)
+ outer_ring = RingP2P(cp_outer_ranks, cp_group, cp_group_for_send_recv_overlap)
+ inner_size = len(cp_inner_ranks)
+ outer_size = cp_size // inner_size
+
+ actual_seq_kvlen = packed_seq_params.cu_seqlens_q.tolist() if packed_seq_params else None
+ actual_seq_qlen = packed_seq_params.cu_seqlens_kv.tolist() if packed_seq_params else None
+ is_eod_reset = (actual_seq_kvlen is not None) and (actual_seq_qlen is not None)
+ seq_len, bsz, hidden = q.shape
+
+ if softmax_scale is None:
+ head_dim = q.shape[-1] // n
+ softmax_scale = head_dim ** (-0.5)
+ if causal and attn_mask is None:
+ attn_mask = torch.ones((2048, 2048), dtype=torch.bool, device=q.device)
+ attn_mask = torch.triu(attn_mask, diagonal=1)
+
+ if causal:
+ if is_eod_reset:
+ # SBH -> TND
+ # fa varlen mode require TND layout
+ q, k, v = [sbh_to_tnd(x, n) for x in [q, k, v]]
+
+ # only first half of each sub sequence KV block need to be calculated when i <= rank
+ kv_index = packed_seq_params.kv_index
+ # only last half of each sub sequence q block need to be calculated when i > rank
+ q_index = packed_seq_params.q_index
+
+ sub_out_seq_len = (torch.tensor([0] + actual_seq_qlen)[1:] - torch.tensor([0] + actual_seq_qlen)[:-1]).tolist()
+ seq_lens = (actual_seq_qlen, actual_seq_kvlen, sub_out_seq_len)
+ half_seq_lens = [[x // 2 for x in lst] for lst in seq_lens]
+ fetch_ptrs = (seq_lens, half_seq_lens, q_index, kv_index)
+
+ softmax_indices = get_selection_indices_for_tnd_softmax_update(q.shape[0], q.shape[1], half_seq_lens[2]).to(q.device)
+ else:
+ # split chunk[i]~chunk[cp_size-i-1] into chunk[i] and chunk[cp_size-i-1],, [2s, b, h] -> [2, s, b, h]
+ q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
+ cur_kv = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) # [2, 2, s, b, h]
+ next_kv = torch.empty_like(cur_kv)
+ next_round_kv = torch.empty_like(cur_kv)
+ attn_out, softmax_max, softmax_sum = None, None, None
+ # (seed, offset, numels) for dropout mask
+ rng_states = [[0, 0, 0] for _ in range(cp_size)]
+ global_attn_outs = [attn_out, softmax_max, softmax_sum, rng_states]
+ q_block_id, kv_block_id, kv_block_id_outer = rank, rank, rank
+
+ # kv cache list
+ k_cache_list = []
+ v_cache_list = []
+
+ for j in range(outer_size):
+ kv_block_id = kv_block_id_outer
+ kv_block_offset = (kv_block_id // inner_size) * inner_size
+ if j < outer_size - 1:
+ outer_ring.async_send_recv(send_tensor=cur_kv, recv_tensor=next_round_kv)
+ for i in range(inner_size):
+ # wait until KV is received from recv_src
+ if i < inner_size - 1:
+ inner_ring.async_send_recv(send_tensor=cur_kv, recv_tensor=next_kv)
+
+ cur_k, cur_v = cur_kv[0], cur_kv[1] # [2, s, b, h]
+
+ # cache kv or k
+ if j * inner_size + i + 2 != cp_size:
+ if cache_policy == "full":
+ k_cache_list.append(cur_kv[0].clone())
+ v_cache_list.append(cur_kv[1].clone())
+ elif cache_policy == "half":
+ k_cache_list.append(cur_kv[0].clone())
+
+ if causal:
+ # flash attention forward
+ cur_sub_out_seq_len = None
+ attn_outs = None
+ if pse is None:
+ if is_eod_reset:
+ cur_q, cur_k, cur_v, cur_attn_mask, cur_seq_lens = tnd_forward_fetch(q_block_id, kv_block_id, q, cur_k, cur_v,
+ fetch_ptrs, attn_mask)
+ cur_seq_qlen, cur_seq_kvlen, cur_sub_out_seq_len = cur_seq_lens
+ # flash attention forward
+ attn_outs = torch_npu.npu_fusion_attention(
+ cur_q, cur_k, cur_v, n, "TND",
+ pse=None,
+ padding_mask=None,
+ atten_mask=cur_attn_mask,
+ scale=softmax_scale,
+ pre_tockens=cur_k.shape[0],
+ next_tockens=0 if cur_attn_mask is not None else cur_k.shape[0],
+ keep_prob=keep_prob,
+ sparse_mode=3 if cur_attn_mask is not None else 0,
+ actual_seq_qlen=cur_seq_qlen,
+ actual_seq_kvlen=cur_seq_kvlen
+ )
+ else:
+ cur_q, cur_k, cur_v, cur_attn_mask = causal_forward_fetch(q_block_id, kv_block_id,
+ q, cur_k, cur_v, attn_mask)
+
+ layout = "SBH"
+ pre_tockens_value = cur_k.shape[0]
+ if megatron_cp_in_bnsd:
+ cur_q = rearrange(cur_q, 's b (h d) -> b h s d', h=n).contiguous()
+ kv_n = cur_v.shape[2] // cur_q.shape[3]
+ cur_k, cur_v = [rearrange(x, 's b (h d) -> b h s d', h=kv_n).contiguous() for x in [cur_k, cur_v]]
+ layout = "BNSD"
+ pre_tockens_value = cur_k.shape[2]
+
+ attn_outs = torch_npu.npu_fusion_attention(
+ cur_q, cur_k, cur_v, n, layout,
+ pse=None,
+ padding_mask=None,
+ atten_mask=cur_attn_mask,
+ scale=softmax_scale,
+ pre_tockens=pre_tockens_value,
+ next_tockens=0 if cur_attn_mask is not None else pre_tockens_value,
+ keep_prob=keep_prob,
+ sparse_mode=3 if cur_attn_mask is not None else 0
+ )
+ if megatron_cp_in_bnsd:
+ attn_outs = rearrange(attn_outs[0], 'b h s d -> s b (h d)').contiguous(), attn_outs[1], attn_outs[2]
+ else:
+ cur_q, cur_k, cur_v, cur_attn_mask = causal_forward_fetch(q_block_id, kv_block_id,
+ q, cur_k, cur_v, attn_mask)
+ q_index_list = [q_block_id, cp_size * 2 - 1 - q_block_id]
+ kv_index_list = [kv_block_id, cp_size * 2 - 1 - kv_block_id]
+ attn_info = [n, pse, pse_type, cur_attn_mask, softmax_scale, keep_prob,
+ q_index_list, kv_index_list]
+ s = q.shape[1]
+ attn_outs = flash_attention_with_alibi_pse(
+ q_block_id, kv_block_id,
+ (cur_q, cur_k, cur_v),
+ attn_info,
+ s
+ )
+ if is_eod_reset:
+ global_attn_outs = tnd_out_update(q_block_id, kv_block_id, attn_outs, global_attn_outs,
+ q_index, softmax_indices, cur_sub_out_seq_len)
+ else:
+ global_attn_outs = causal_out_update(q_block_id, kv_block_id, attn_outs, global_attn_outs)
+ else:
+ # [2s, b, h], [b, n, 2s, 8], [b, n, 2s, 8]
+ this_mask = AttentionWithCp.compute_mask(
+ actual_seq_qlen, actual_seq_kvlen,
+ q_block_id, kv_block_id,
+ attn_mask
+ )
+
+ attn_outs = torch_npu.npu_fusion_attention(
+ q, cur_k, cur_v, n, "SBH",
+ pse=None,
+ padding_mask=None,
+ atten_mask=this_mask,
+ scale=softmax_scale,
+ pre_tockens=cur_k.shape[0],
+ next_tockens=cur_k.shape[0],
+ keep_prob=keep_prob,
+ sparse_mode=1
+ )
+
+ global_attn_outs = general_out_update(q_block_id, kv_block_id, attn_outs, global_attn_outs)
+
+ if inner_ring.wait():
+ cur_kv, next_kv = next_kv, cur_kv # double buffer
+ kv_block_id = (kv_block_id + inner_size - 1) % inner_size + kv_block_offset
+
+ if outer_ring.wait():
+ cur_kv, next_round_kv = next_round_kv, cur_kv # double buffer
+ kv_block_id_outer = (kv_block_id_outer + cp_size - inner_size) % cp_size
+
+ k_cache_list = k_cache_list if k_cache_list else [cur_kv[0].clone()]
+ v_cache_list = v_cache_list if v_cache_list else [cur_kv[1].clone()]
+ attn_mask = attn_mask if isinstance(attn_mask, list) else [attn_mask]
+
+ attn_out, softmax_max, softmax_sum, rng_states = global_attn_outs
+
+ if causal and not is_eod_reset:
+ q = q.view(-1, *q.shape[2:])
+ k_cache_list = [x.view(-1, *x.shape[2:]) for x in k_cache_list]
+ v_cache_list = [x.view(-1, *x.shape[2:]) for x in v_cache_list]
+
+ k_stack = torch.stack(k_cache_list)
+ v_stack = torch.stack(v_cache_list)
+
+ ctx.save_for_backward(q, k_stack, v_stack, *attn_mask, attn_out, softmax_max, softmax_sum)
+ ctx.n = n
+ ctx.causal = causal
+ ctx.softmax_scale = softmax_scale
+ ctx.cp_group = cp_group
+ ctx.cp_size = cp_size
+ ctx.cp_rank = rank
+ ctx.cp_global_ranks = cp_global_ranks
+ ctx.cp_inner_ranks = cp_inner_ranks
+ ctx.cp_outer_ranks = cp_outer_ranks
+ ctx.cp_dkv_outer_ranks = cp_para.get('cp_dkv_outer_ranks', cp_global_ranks)
+ ctx.kv_block_id = kv_block_id
+ ctx.keep_prob = keep_prob
+ ctx.rng_states = rng_states
+ ctx.pse = pse
+ ctx.pse_type = pse_type
+ ctx.cp_group_for_send_recv_overlap = cp_group_for_send_recv_overlap
+ ctx.cp_group_for_intra_window = cp_group_for_intra_window
+ ctx.cp_group_for_intra_window_send_recv_overlap = cp_group_for_intra_window_send_recv_overlap
+ ctx.actual_seq_qlen = actual_seq_qlen
+ ctx.actual_seq_kvlen = actual_seq_kvlen
+ ctx.is_eod_reset = is_eod_reset
+ ctx.megatron_cp_in_bnsd = megatron_cp_in_bnsd
+ ctx.bsz = bsz
+ ctx.cache_policy = cache_policy
+
+ if causal and is_eod_reset:
+ ctx.q_index = q_index
+ ctx.kv_index = kv_index
+ ctx.half_actual_seq_qlen = half_seq_lens[0]
+ ctx.half_actual_seq_kvlen = half_seq_lens[1]
+ ctx.half_sub_out_seq_len = half_seq_lens[2]
+ ctx.sub_out_seq_len = sub_out_seq_len
+ ctx.softmax_indices = softmax_indices
+ return tnd_to_sbh(attn_out, bsz)
+
+ return attn_out
+
+ @staticmethod
+ def backward(ctx, dout):
+ q, k_stack, v_stack, *attn_mask, attn_out, softmax_max, softmax_sum = ctx.saved_tensors
+ attn_mask = attn_mask[0] if len(attn_mask) == 1 else attn_mask
+
+ n = ctx.n
+ causal = ctx.causal
+ softmax_scale = ctx.softmax_scale
+ cp_group = ctx.cp_group
+ cp_size = ctx.cp_size
+ rank = ctx.cp_rank
+ keep_prob = ctx.keep_prob
+ rng_states = ctx.rng_states
+ pse = ctx.pse
+ pse_type = ctx.pse_type
+ megatron_cp_in_bnsd = ctx.megatron_cp_in_bnsd
+ cp_group_for_send_recv_overlap = ctx.cp_group_for_send_recv_overlap
+ cp_group_for_intra_window = ctx.cp_group_for_intra_window
+ cp_group_for_intra_window_send_recv_overlap = ctx.cp_group_for_intra_window_send_recv_overlap
+ cache_policy = ctx.cache_policy
+ is_eod_reset = ctx.is_eod_reset
+ if causal and is_eod_reset:
+ dout = sbh_to_tnd(dout, n)
+ # Reversed order of forward
+ inner_size = len(ctx.cp_inner_ranks)
+ outer_size = len(ctx.cp_outer_ranks)
+
+ intra_kv_comm = RingP2P(ctx.cp_inner_ranks, cp_group_for_intra_window, cp_group_for_intra_window_send_recv_overlap, is_backward=True)
+ intra_dkv_comm = RingP2P(ctx.cp_inner_ranks, cp_group_for_intra_window, cp_group_for_intra_window_send_recv_overlap, is_backward=True)
+ inter_kv_comm = RingP2P(ctx.cp_outer_ranks, cp_group, cp_group_for_send_recv_overlap, is_backward=True)
+ inter_dkv_comm = RingP2P(ctx.cp_dkv_outer_ranks, cp_group, cp_group_for_send_recv_overlap, is_backward=True)
+
+
+ if causal:
+ if is_eod_reset:
+ half_softmax_max = softmax_max.view(-1, 8)[ctx.softmax_indices].view(-1, n, 8)
+ half_softmax_sum = softmax_sum.view(-1, 8)[ctx.softmax_indices].view(-1, n, 8)
+ else:
+ # split chunk[i]~chunk[cp_size-i-1] into chunk[i] and chunk[cp_size-i-1], [2s, b, h] -> [2, s, b, h]
+ q, attn_out, dout = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, attn_out, dout]]
+ k_stack = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in k_stack]
+ v_stack = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in v_stack]
+ # [b, n, 2s, 8] -> [b, n, 2, s, 8]
+ softmax_max = softmax_max.view(softmax_max.shape[0], softmax_max.shape[1],
+ 2, softmax_max.shape[2] // 2, softmax_max.shape[-1])
+ softmax_sum = softmax_sum.view(softmax_sum.shape[0], softmax_sum.shape[1],
+ 2, softmax_sum.shape[2] // 2, softmax_sum.shape[-1])
+
+ def backward_step_helper(q_block_id, kv_block_id, q, cur_k, cur_v):
+ if causal:
+ if pse is None:
+ # flash attention backward
+ if is_eod_reset:
+ softmax_values = (softmax_max, softmax_sum, half_softmax_max, half_softmax_sum)
+ seq_lens = (ctx.actual_seq_qlen, ctx.actual_seq_kvlen, ctx.half_actual_seq_qlen, ctx.half_actual_seq_kvlen)
+ index_values = (ctx.q_index, ctx.kv_index)
+ step_inputs = tnd_backward_fetch(q_block_id, kv_block_id, q, cur_k, cur_v, attn_out, dout,
+ softmax_values, seq_lens, index_values, attn_mask=attn_mask)
+ qkv, cur_attn_out, cur_dout, cur_softmax_values, cur_attn_mask, cur_seq_lens = step_inputs
+ cur_q, cur_k, cur_v = qkv
+ cur_softmax_max, cur_softmax_sum = cur_softmax_values
+ cur_seq_qlen, cur_seq_kvlen = cur_seq_lens
+
+ # flash attention backward
+ attn_grad_outs = torch_npu.npu_fusion_attention_grad(
+ cur_q, cur_k, cur_v, cur_dout, n,
+ "TND",
+ pse=None,
+ padding_mask=None,
+ atten_mask=cur_attn_mask,
+ softmax_max=cur_softmax_max,
+ softmax_sum=cur_softmax_sum,
+ attention_in=cur_attn_out,
+ scale_value=softmax_scale,
+ pre_tockens=cur_k.shape[0],
+ next_tockens=0 if cur_attn_mask is not None else cur_k.shape[0],
+ sparse_mode=3 if cur_attn_mask is not None else 0,
+ actual_seq_qlen=cur_seq_qlen,
+ actual_seq_kvlen=cur_seq_kvlen,
+ keep_prob=keep_prob,
+ seed=rng_states[kv_block_id][0],
+ offset=rng_states[kv_block_id][1],
+ numels=rng_states[kv_block_id][2],
+ )
+ else:
+ step_inputs = causal_backward_fetch(q_block_id, kv_block_id, q, cur_k, cur_v, attn_out, dout,
+ softmax_max, softmax_sum, attn_mask=attn_mask)
+ cur_q, cur_k, cur_v, cur_attn_out, cur_dout, cur_softmax_max, cur_softmax_sum, cur_attn_mask = step_inputs
+ layout = "SBH"
+ pre_tockens_value = cur_k.shape[0]
+ if megatron_cp_in_bnsd:
+ cur_q, cur_dout, cur_attn_out = [rearrange(x, 's b (h d) -> b h s d', h=n).contiguous() for x in [cur_q, cur_dout, cur_attn_out]]
+ kv_n = cur_v.shape[2] // cur_q.shape[3]
+ cur_k, cur_v = [rearrange(x, 's b (h d) -> b h s d', h=kv_n).contiguous() for x in [cur_k, cur_v]]
+ layout = "BNSD"
+ pre_tockens_value = cur_k.shape[2]
+
+ attn_grad_outs = torch_npu.npu_fusion_attention_grad(
+ cur_q, cur_k, cur_v, cur_dout, n,
+ layout,
+ pse=None,
+ padding_mask=None,
+ atten_mask=cur_attn_mask,
+ softmax_max=cur_softmax_max,
+ softmax_sum=cur_softmax_sum,
+ attention_in=cur_attn_out,
+ scale_value=softmax_scale,
+ pre_tockens=pre_tockens_value,
+ next_tockens=0 if cur_attn_mask is not None else pre_tockens_value,
+ sparse_mode=3 if cur_attn_mask is not None else 0,
+ keep_prob=keep_prob,
+ seed=rng_states[kv_block_id][0],
+ offset=rng_states[kv_block_id][1],
+ numels=rng_states[kv_block_id][2],
+ )
+ if megatron_cp_in_bnsd:
+ attn_grad_outs = [rearrange(x, 'b h s d -> s b (h d)').contiguous() for x in [attn_grad_outs[0], attn_grad_outs[1], attn_grad_outs[2]]]
+ else:
+ step_inputs = causal_backward_fetch(q_block_id, kv_block_id, q, cur_k, cur_v, attn_out, dout,
+ softmax_max, softmax_sum, attn_mask=attn_mask)
+ cur_q, cur_k, cur_v, cur_attn_out, cur_dout, cur_softmax_max, cur_softmax_sum, cur_attn_mask = step_inputs
+ q_index_list = [q_block_id, cp_size * 2 - 1 - q_block_id]
+ kv_index_list = [kv_block_id, cp_size * 2 - 1 - kv_block_id]
+ attn_grad_info = [n, pse, pse_type, cur_attn_mask, softmax_scale, keep_prob, rng_states,
+ q_index_list, kv_index_list]
+ s = q.shape[1]
+ attn_grad_outs = flash_attention_with_alibi_pse_grad(
+ q_block_id, kv_block_id,
+ (cur_q, cur_k, cur_v), cur_dout, cur_attn_out,
+ cur_softmax_max, cur_softmax_sum,
+ attn_grad_info, s
+ )
+
+ cur_dq, cur_dk, cur_dv = attn_grad_outs[0], attn_grad_outs[1], attn_grad_outs[2]
+
+
+ else:
+ this_mask = AttentionWithCp.compute_mask(
+ ctx.actual_seq_qlen, ctx.actual_seq_kvlen,
+ q_block_id, kv_block_id,
+ attn_mask
+ )
+ attn_grad_outs = torch_npu.npu_fusion_attention_grad(
+ q, cur_k, cur_v, dout, n,
+ "SBH",
+ pse=None,
+ padding_mask=None,
+ atten_mask=this_mask,
+ softmax_max=softmax_max,
+ softmax_sum=softmax_sum,
+ attention_in=attn_out,
+ scale_value=softmax_scale,
+ pre_tockens=cur_k.shape[0],
+ next_tockens=cur_k.shape[0],
+ sparse_mode=1,
+ keep_prob=keep_prob,
+ seed=rng_states[kv_block_id][0],
+ offset=rng_states[kv_block_id][1],
+ numels=rng_states[kv_block_id][2],
+ )
+ cur_dq, cur_dk, cur_dv = attn_grad_outs[0], attn_grad_outs[1], attn_grad_outs[2]
+
+ return cur_dq, cur_dk, cur_dv
+
+
+ cur_dkv = torch.zeros((2, *k_stack[-1].shape), dtype=k_stack[-1].dtype, device=k_stack[-1].device)
+ next_dkv = cur_dkv.clone()
+ next_round_dkv = cur_dkv.clone()
+
+ q_block_id, kv_block_id, kv_block_id_outer = rank, ctx.kv_block_id, ctx.kv_block_id
+
+ outer_data = (outer_size, inter_kv_comm)
+ inner_data = (inner_size, intra_kv_comm)
+ cp_kv_cache = ContextParallelKVCache(cache_policy, outer_data, inner_data, k_stack, v_stack)
+
+ dq = torch.zeros_like(q) # [2, s, b, h]
+ for j in range(outer_size):
+ kv_block_id = kv_block_id_outer
+ kv_block_offset = (kv_block_id // inner_size) * inner_size
+
+ cp_kv_cache.communicate_outer_ring_kv(j)
+
+ for i in range(inner_size):
+ cur_k, cur_v = cp_kv_cache.communicate_inner_ring_kv(i)
+
+ dq_step, dk_step, dv_step = backward_step_helper(q_block_id, kv_block_id, q, cur_k, cur_v)
+
+ if i == 0 and j > 0: # receive dk dv from last window
+ inter_dkv_comm.wait()
+ cur_dkv, next_round_dkv = next_round_dkv, cur_dkv
+ elif i > 0: # receive dk dv from last step
+ intra_dkv_comm.wait()
+ cur_dkv, next_dkv = next_dkv, cur_dkv
+
+ dk, dv = cur_dkv[0], cur_dkv[1]
+ # update qkv grades
+ if is_eod_reset and causal:
+ tnd_grad_update(q_block_id, kv_block_id, (dq_step, dk_step, dv_step), (dq, dk, dv),
+ ctx.q_index, ctx.kv_index)
+ elif causal:
+ causal_grad_update(q_block_id, kv_block_id, dq_step, dk_step, dv_step, dq, dk, dv)
+ else:
+ dq.add_(dq_step)
+ dk.add_(dk_step)
+ dv.add_(dv_step)
+
+ if i + 1 != inner_size:
+ intra_dkv_comm.async_send_recv(send_tensor=cur_dkv, recv_tensor=next_dkv)
+
+ kv_block_id = (kv_block_id + 1) % inner_size + kv_block_offset
+
+ if intra_dkv_comm.wait():
+ cur_dkv, next_dkv = next_dkv, cur_dkv
+
+ if j + 1 != outer_size:
+ inter_dkv_comm.async_send_recv(send_tensor=cur_dkv, recv_tensor=next_round_dkv)
+
+ kv_block_id_outer = (kv_block_id_outer + inner_size) % cp_size
+
+ if inter_dkv_comm.wait():
+ cur_dkv, next_round_dkv = next_round_dkv, cur_dkv
+
+ dk, dv = cur_dkv[0], cur_dkv[1]
+
+
+ # [2, s, b, h] -> [2s, b, h]
+ if causal and is_eod_reset:
+ dq, dk, dv = [tnd_to_sbh(x, ctx.bsz) for x in [dq, dk, dv]]
+ elif causal:
+ dq, dk, dv = [x.view(-1, *x.shape[2:]) for x in [dq, dk, dv]]
+ return dq, dk, dv, None, None, None, None, None, None, None
+
+ @classmethod
+ def compute_mask(cls, actual_seq_qlen, actual_seq_kvlen, q_block_id, kv_block_id, attn_mask):
+ from bisect import bisect_right
+ from mindspeed.utils import batch_index
+
+ if actual_seq_qlen:
+ seq_len = actual_seq_qlen[-1] // AttentionWithCp.batch_size
+ actual_seq_qlen = batch_index(actual_seq_qlen, seq_len)
+ actual_seq_kvlen = batch_index(actual_seq_kvlen, seq_len)
+ block_size = cls.block_size
+ actual_seq_qlen = [[0] + lst for lst in actual_seq_qlen]
+ sub_seq_qlen = [torch.tensor(x[1:]) - torch.tensor(x[:-1]) for x in actual_seq_qlen]
+ sub_seq_qid = torch.stack([torch.arange(len(lst)).repeat_interleave(lst) for lst in sub_seq_qlen]).npu() # B S
+
+ this_ids = sub_seq_qid[:, q_block_id * block_size:(q_block_id + 1) * block_size].npu()
+ this_tile = this_ids.unsqueeze(dim=2) # B S 1
+
+ actual_seq_kvlen = [[0] + lst for lst in actual_seq_kvlen]
+ sub_seq_kvlen = [torch.tensor(x[1:]) - torch.tensor(x[:-1]) for x in actual_seq_kvlen]
+ sub_seq_kvid = torch.stack([torch.arange(len(lst)).repeat_interleave(lst) for lst in sub_seq_kvlen]).npu() # B S
+ other_ids = sub_seq_kvid[:, kv_block_id * block_size:(kv_block_id + 1) * block_size].npu()
+ other_tile = other_ids.unsqueeze(dim=1) # B 1 S
+
+ mask = this_tile == other_tile # B S S
+ if kv_block_id > q_block_id:
+ mask = torch.zeros_like(mask)
+ elif kv_block_id == q_block_id:
+ mask = torch.tril(mask)
+
+ return torch.logical_not(mask).unsqueeze(dim=1).npu() # B 1 S S
+ else:
+ return attn_mask[kv_block_id] if isinstance(attn_mask, list) else None
+
+
+def ringattn_context_parallel(q, k, v, n, cp_para, softmax_scale=None, attn_mask=None, dropout_p=0.,
+ packed_seq_params=None):
+ AttentionWithCp.block_size = q.shape[0]
+ AttentionWithCp.batch_size = q.shape[1]
+ out = AttentionWithCp.apply(
+ q, k, v, n, cp_para, softmax_scale, attn_mask, dropout_p,
+ packed_seq_params
+ )
+ return out
diff --git a/model/train/yoco_moe/mindspeed/core/context_parallel/ring_context_parallel_for_ampipe.py b/model/train/yoco_moe/mindspeed/core/context_parallel/ring_context_parallel_for_ampipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a69a246eb0b5f8f8500c04ab07d6fe025e15d0e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/context_parallel/ring_context_parallel_for_ampipe.py
@@ -0,0 +1,503 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch_npu
+from megatron.training import get_args
+from megatron.core import mpu
+
+from mindspeed.core.parallel_state import get_context_parallel_prev_rank, get_context_parallel_next_rank
+from .utils import RingP2P, forward_update
+
+
+def flash_attention_backward(qkvn, dy, softmax_max, softmax_sum, atten_out, attn_mask=None, softmax_scale=1.,
+ keep_prob=1., seed=0, offset=0, numels=0):
+ """FlashAttention backward"""
+ q, k, v, n = qkvn
+ next_tockens = 0 if attn_mask is not None else k.shape[0]
+ sparse_mode = 3 if attn_mask is not None else 0
+ output = torch_npu.npu_fusion_attention_grad(
+ q, k, v, dy, n,
+ "SBH",
+ pse=None,
+ padding_mask=None,
+ atten_mask=attn_mask,
+ softmax_max=softmax_max,
+ softmax_sum=softmax_sum,
+ attention_in=atten_out,
+ scale_value=softmax_scale,
+ pre_tockens=k.shape[0],
+ next_tockens=next_tockens,
+ sparse_mode=sparse_mode,
+ keep_prob=keep_prob,
+ seed=seed,
+ offset=offset,
+ numels=numels
+ )
+ return output
+
+
+def flash_attention_forward(qkvn, attn_mask=None, softmax_scale=1., keep_prob=1.):
+ """FlashAttention forward"""
+ q, k, v, n = qkvn
+ next_tockens = 0 if attn_mask is not None else k.shape[0]
+ sparse_mode = 3 if attn_mask is not None else 0
+
+ output = torch_npu.npu_fusion_attention(
+ q, k, v, n, "SBH",
+ pse=None,
+ padding_mask=None,
+ atten_mask=attn_mask,
+ scale=softmax_scale,
+ pre_tockens=k.shape[0],
+ next_tockens=next_tockens,
+ keep_prob=keep_prob,
+ sparse_mode=sparse_mode
+ )
+ return output
+
+
+def attn_with_cp_for_ampipe_forward(ctx, fa_cp_fwd_args,
+ fa_fwd_args,
+ dropout_p=0.):
+ args = get_args()
+ q, k, v = fa_cp_fwd_args.q, fa_cp_fwd_args.k, fa_cp_fwd_args.v
+ tensor_list, n = fa_fwd_args.flash_tensor_list, fa_fwd_args.head_num
+ kv_list, o_max_sum_list, ampipe_idx = fa_fwd_args.kv_list, fa_fwd_args.o_max_sum_list, fa_fwd_args.cur_degree
+ if kv_list is None:
+ kv_list = []
+ if o_max_sum_list is None:
+ o_max_sum_list = []
+ keep_prob = 1. - dropout_p
+ if args.ampipe_degree > 2:
+ raise RuntimeError(f"Context parallel only support ampipe_degree is 2, but got {args.ampipe_degree}")
+
+ head_dim = q.shape[-1] // n
+ softmax_scale = head_dim ** (-0.5)
+
+ rank = mpu.get_context_parallel_rank()
+ cp_global_ranks = mpu.get_context_parallel_global_ranks()
+ prev_rank = get_context_parallel_prev_rank()
+ next_rank = get_context_parallel_next_rank()
+ cp_size = mpu.get_context_parallel_world_size()
+ cp_group = mpu.get_context_parallel_group()
+ cp_group_for_send_recv_overlap = mpu.get_context_parallel_group_for_send_recv_overlap() if args.use_cp_send_recv_overlap else cp_group
+ send_recv_comm = RingP2P(cp_global_ranks, cp_group, cp_group_for_send_recv_overlap)
+ attn_mask = torch.ones((2048, 2048), dtype=torch.bool, device=q.device)
+ attn_mask = torch.triu(attn_mask, diagonal=1)
+ if ampipe_idx == 0:
+ # split chunk[i]~chunk[2cp-1-i] into chunk[i] and chunk[2cp-1-i],, [2s, b, h] -> [2, s, b, h]
+ q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
+ # (seed, offset, numels) for dropout mask
+ rng_states_qa_kva = [[0, 0, 0] for _ in range(cp_size)]
+ rng_states_qb_kva = [[0, 0, 0] for _ in range(cp_size)]
+ rng_states_qb_kvb = [[0, 0, 0] for _ in range(cp_size)]
+ send_kv = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) # [2, 2, s, b, h]
+ recv_kv = None
+ # chunk[i]
+ attn_out_a, softmax_max_a, softmax_sum_a = None, None, None
+ # chunk[2cp-1-i]
+ attn_out_b, softmax_max_b, softmax_sum_b = None, None, None
+
+ for i in range(cp_size):
+ # wait until KV is received from recv_src
+ if send_recv_comm.wait():
+ send_kv = recv_kv
+ kv_list.append(send_kv) # tmp buffer for next ampipe
+ if i < cp_size - 1:
+ recv_kv = torch.empty_like(send_kv)
+ send_recv_comm.async_send_recv(send_kv, recv_kv)
+ if i == 0:
+ qa, ka, va = [x[0] for x in [q, k, v]]
+ qb, kb, vb = [x[1] for x in [q, k, v]]
+
+ attn_outs_a = flash_attention_forward((qa, ka, va, n),
+ attn_mask=attn_mask, softmax_scale=softmax_scale,
+ keep_prob=keep_prob)
+ attn_outs_b = flash_attention_forward((qb, kb, vb, n),
+ attn_mask=attn_mask, softmax_scale=softmax_scale,
+ keep_prob=keep_prob)
+ attn_out_a, softmax_max_a, softmax_sum_a = attn_outs_a[0], attn_outs_a[1], attn_outs_a[2]
+ attn_out_b, softmax_max_b, softmax_sum_b = attn_outs_b[0], attn_outs_b[1], attn_outs_b[2]
+ # seed, offset, numels (for dropout)
+ rng_states_qa_kva[i] = (attn_outs_a[4], attn_outs_a[5], attn_outs_a[6])
+ rng_states_qb_kvb[i] = (attn_outs_b[4], attn_outs_b[5], attn_outs_b[6])
+ else:
+ cur_k, cur_v = send_kv[0], send_kv[1] # [2, s, b, h]
+
+ if i <= rank:
+ qa, ka, va = [x[0] for x in [q, cur_k, cur_v]]
+ attn_outs_a = flash_attention_forward((qa, ka, va, n),
+ attn_mask=None, softmax_scale=softmax_scale,
+ keep_prob=keep_prob)
+ cur_attn_out_a, cur_softmax_max_a, cur_softmax_sum_a = attn_outs_a[0], attn_outs_a[1], attn_outs_a[
+ 2]
+ rng_states_qa_kva[i] = (attn_outs_a[4], attn_outs_a[5], attn_outs_a[6])
+ attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
+ attn_out_a, softmax_max_a, softmax_sum_a,
+ cur_attn_out_a, cur_softmax_max_a, cur_softmax_sum_a
+ )
+ attn_out_a, softmax_max_a, softmax_sum_a = attn_out_updated, softmax_max_updated, softmax_sum_updated
+ else:
+ kv_idx = i - rank - 1
+ kv = kv_list[kv_idx]
+ cur_k, cur_v = kv[0], kv[1]
+ qb = q[1]
+ ka, va = [x[0] for x in [cur_k, cur_v]]
+
+ attn_outs_b = flash_attention_forward((qb, ka, va, n),
+ attn_mask=None, softmax_scale=softmax_scale)
+ cur_attn_out_b, cur_softmax_max_b, cur_softmax_sum_b = attn_outs_b[0], attn_outs_b[1], attn_outs_b[
+ 2]
+ rng_states_qb_kva[kv_idx] = (attn_outs_b[4], attn_outs_b[5], attn_outs_b[6])
+
+ attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
+ attn_out_b, softmax_max_b, softmax_sum_b,
+ cur_attn_out_b, cur_softmax_max_b, cur_softmax_sum_b
+ )
+ attn_out_b, softmax_max_b, softmax_sum_b = attn_out_updated, softmax_max_updated, softmax_sum_updated
+
+ attn_out_all = torch.cat((attn_out_a.unsqueeze(0), attn_out_b.unsqueeze(0)), dim=0)
+ softmax_max_all = torch.cat((softmax_max_a.unsqueeze(0), softmax_max_b.unsqueeze(0)), dim=0)
+ softmax_sum_all = torch.cat((softmax_sum_a.unsqueeze(0), softmax_sum_b.unsqueeze(0)), dim=0)
+ o_max_sum_list.append(attn_out_all)
+ o_max_sum_list.append(softmax_max_all)
+ o_max_sum_list.append(softmax_sum_all)
+
+ k, v = send_kv[0], send_kv[1]
+ q, k, v = [x.view(-1, *x.shape[2:]) for x in [q, k, v]] # [2s, b, h]
+ attn_out = attn_out_a
+ else:
+ q = q.view(2, q.shape[0] // 2, *q.shape[1:])
+ qb = q[1]
+ attn_out_all, softmax_max_all, softmax_sum_all = o_max_sum_list
+ attn_out_b, softmax_max_b, softmax_sum_b = attn_out_all[1], softmax_max_all[1], softmax_sum_all[1]
+ rng_states_qa_kva = ctx.rng_states_qa_kva
+ rng_states_qb_kva = ctx.rng_states_qb_kva
+ rng_states_qb_kvb = ctx.rng_states_qb_kvb
+
+ start_a_idx = cp_size - rank - 1
+ start_b_idx = rank + 1
+
+ for i in range(cp_size):
+ cur_kv = kv_list[i]
+ cur_k, cur_v = cur_kv[0], cur_kv[1]
+ if i >= start_a_idx:
+ ka, va = cur_k[0], cur_v[0]
+
+ attn_outs_b = flash_attention_forward((qb, ka, va, n),
+ attn_mask=None, softmax_scale=softmax_scale)
+ cur_attn_out_b, cur_softmax_max_b, cur_softmax_sum_b = attn_outs_b[0], attn_outs_b[1], attn_outs_b[2]
+ rng_states_qb_kva[i] = (attn_outs_b[4], attn_outs_b[5], attn_outs_b[6])
+ attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
+ attn_out_b, softmax_max_b, softmax_sum_b,
+ cur_attn_out_b, cur_softmax_max_b, cur_softmax_sum_b
+ )
+ attn_out_b, softmax_max_b, softmax_sum_b = attn_out_updated, softmax_max_updated, softmax_sum_updated
+ if i >= start_b_idx:
+ kb, vb = cur_k[1], cur_v[1]
+ attn_outs_b = flash_attention_forward((qb, kb, vb, n),
+ attn_mask=None, softmax_scale=softmax_scale)
+ cur_attn_out_b, cur_softmax_max_b, cur_softmax_sum_b = attn_outs_b[0], attn_outs_b[1], attn_outs_b[2]
+ rng_states_qb_kvb[i] = (attn_outs_b[4], attn_outs_b[5], attn_outs_b[6])
+ attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
+ attn_out_b, softmax_max_b, softmax_sum_b,
+ cur_attn_out_b, cur_softmax_max_b, cur_softmax_sum_b
+ )
+ attn_out_b, softmax_max_b, softmax_sum_b = attn_out_updated, softmax_max_updated, softmax_sum_updated
+ kv = kv_list[-1]
+ k, v = kv[0], kv[1]
+ q, k, v = [x.view(-1, *x.shape[2:]) for x in [q, k, v]] # [2s, b, h]
+ attn_out = attn_out_b
+ attn_out_all[1], softmax_max_all[1], softmax_sum_all[1] = attn_out_b, softmax_max_b, softmax_sum_b
+
+ tensor_list.extend([q, k, v, attn_mask, softmax_max_all, softmax_sum_all])
+
+ ctx.n = n
+ ctx.rank = rank
+ ctx.keep_prob = keep_prob
+ ctx.cp_size = cp_size
+ ctx.cp_group = cp_group
+ ctx.prev_rank = prev_rank
+ ctx.next_rank = next_rank
+ ctx.cp_group_for_send_recv_overlap = cp_group_for_send_recv_overlap
+ ctx.softmax_scale = softmax_scale
+ ctx.rng_states_qa_kva = rng_states_qa_kva
+ ctx.rng_states_qb_kva = rng_states_qb_kva
+ ctx.rng_states_qb_kvb = rng_states_qb_kvb
+ return attn_out
+
+
+def attn_with_cp_for_ampipe_backward(ctx, attn_out, saved_tensor_list, dout, fa_bwd_args):
+ args = get_args()
+ kv_list, dkv_list, dout_list, ampipe_idx = (fa_bwd_args.kv_list, fa_bwd_args.dkv_list,
+ fa_bwd_args.dout_list, fa_bwd_args.cur_degree)
+
+ if kv_list is None:
+ kv_list = []
+ if dkv_list is None:
+ dkv_list = []
+ if dout_list is None:
+ dout_list = []
+ if args.ampipe_degree > 2:
+ raise RuntimeError(f"Context parallel only support ampipe_degree is 2, but got {args.ampipe_degree}")
+
+ q, k, v, attn_mask, softmax_max, softmax_sum = saved_tensor_list
+ n = ctx.n
+ rank = ctx.rank
+ softmax_scale = ctx.softmax_scale
+ cp_size = ctx.cp_size
+ cp_group = ctx.cp_group
+ cp_group_for_send_recv_overlap = ctx.cp_group_for_send_recv_overlap
+ cp_global_ranks = mpu.get_context_parallel_global_ranks()
+ keep_prob = ctx.keep_prob
+ rng_states_qa_kva = ctx.rng_states_qa_kva
+ rng_states_qb_kva = ctx.rng_states_qb_kva
+ rng_states_qb_kvb = ctx.rng_states_qb_kvb
+ # [2s, b, h] -> [2, s, b, h]
+ q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]]
+
+ attn_out_a, softmax_max_a, softmax_sum_a = attn_out[0], softmax_max[0], softmax_sum[0]
+ attn_out_b, softmax_max_b, softmax_sum_b = attn_out[1], softmax_max[1], softmax_sum[1]
+
+ if ampipe_idx == 0:
+ send_recv_comm = RingP2P(cp_global_ranks, cp_group, cp_group_for_send_recv_overlap, is_backward=True)
+ dq, dk, dv = None, None, None
+ recv_kv_dkv = None
+ recv_kv = None
+ recv_dkv = None
+ # [s, b, h]
+ qa, ka, va = [x[0] for x in [q, k, v]]
+ qb, kb, vb = [x[1] for x in [q, k, v]]
+ dq_b = torch.zeros_like(qb)
+ dk = torch.zeros_like(k)
+ dv = torch.zeros_like(v)
+ kv = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
+ send_kv_dkv = torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device)
+
+ for i in range(cp_size):
+ # wait until KV is received from recv_src
+ if send_recv_comm.wait():
+ # only received kv in the second loop
+ if i == 1:
+ send_kv = recv_kv
+ send_kv_dkv[0].copy_(send_kv)
+ else:
+ send_kv_dkv = recv_kv_dkv
+ if i > 0:
+ dkv = torch.cat((dk.unsqueeze(0), dv.unsqueeze(0)), dim=0)
+ send_kv_dkv[1].copy_(dkv)
+
+ # just send-recv kv in the first loop
+ if i == 0:
+ send_kv = kv
+ recv_kv = torch.empty_like(send_kv)
+ send_recv_comm.async_send_recv(send_kv, recv_kv)
+ kv_list.append(send_kv)
+ # just send-recv dkv in the last loop
+ elif i == cp_size - 1:
+ send_dkv = send_kv_dkv[1]
+ recv_dkv = torch.empty_like(send_dkv)
+ send_recv_comm.async_send_recv(send_dkv, recv_dkv)
+ cur_k, cur_v = send_kv_dkv[0][0], send_kv_dkv[0][1]
+ ka, va = cur_k[0], cur_v[0]
+ kv_list.append(send_kv_dkv[0])
+ else:
+ recv_kv_dkv = torch.empty_like(send_kv_dkv)
+ send_recv_comm.async_send_recv(send_kv_dkv, recv_kv_dkv)
+ cur_k, cur_v = send_kv_dkv[0][0], send_kv_dkv[0][1]
+ ka, va = cur_k[0], cur_v[0]
+ kv_list.append(send_kv_dkv[0])
+
+ attn_grad_outs_b = flash_attention_backward(
+ (qb, ka, va, n),
+ dout, softmax_max_b, softmax_sum_b, attn_out_b,
+ None, softmax_scale, keep_prob, rng_states_qb_kva[cp_size - i - 1][0],
+ rng_states_qb_kva[cp_size - i - 1][1], rng_states_qb_kva[cp_size - i - 1][2]
+ )
+
+ cur_dq_b, cur_dk_a, cur_dv_a = attn_grad_outs_b[0], attn_grad_outs_b[1], attn_grad_outs_b[2]
+ if i == 0:
+ dq_b = cur_dq_b
+ dk[0].copy_(cur_dk_a)
+ dv[0].copy_(cur_dv_a)
+ else:
+ # wait until dKV is received from recv_src
+ send_recv_comm.wait()
+ # only received dkv in the last loop
+ if i == cp_size - 1:
+ dkv = recv_dkv
+ else:
+ send_kv_dkv = recv_kv_dkv
+ dkv = send_kv_dkv[1]
+ dk, dv = dkv[0], dkv[1]
+ dq_b.add_(cur_dq_b)
+ dk[0].add_(cur_dk_a)
+ dv[0].add_(cur_dv_a)
+ dkv_list.append(dq_b)
+ dkv_list.append(dk[0])
+ dkv_list.append(dv[0])
+ dout_list.append(dout)
+ else:
+ send_recv_comm = RingP2P(cp_global_ranks, cp_group, cp_group_for_send_recv_overlap)
+ kv_list.reverse()
+
+ recv_dkv = None
+ # [s, b, h]
+ qa, ka, va = [x[0] for x in [q, k, v]]
+ qb, kb, vb = [x[1] for x in [q, k, v]]
+ dq_a, dk_a, dv_a, dq_b, dk_b, dv_b = [torch.zeros_like(x) for x in [qa, ka, va, qb, kb, vb]]
+ send_dkv = torch.empty((2, 2, *ka.shape), dtype=ka.dtype, device=ka.device)
+
+ for i in range(cp_size):
+ # the first loop no send-recv
+ if i > 0:
+ if i <= rank + 1:
+ if i <= rank:
+ dkv_a = torch.cat((dk_a.unsqueeze(0), dv_a.unsqueeze(0)), dim=0)
+ # send_dkv = dkv_a
+ send_dkv[0].copy_(dkv_a)
+ else:
+ dkv_b = torch.cat((dk_b.unsqueeze(0), dv_b.unsqueeze(0)), dim=0)
+ # send_dkv = dkv_b
+ send_dkv[1].copy_(dkv_b)
+ else:
+ dkv_a = torch.cat((dk_a.unsqueeze(0), dv_a.unsqueeze(0)), dim=0)
+ dkv_b = torch.cat((dk_b.unsqueeze(0), dv_b.unsqueeze(0)), dim=0)
+ dkv = torch.cat((dkv_a.unsqueeze(0), dkv_b.unsqueeze(0)), dim=0)
+ send_dkv = dkv
+
+ recv_dkv = torch.empty_like(send_dkv)
+ send_recv_comm.async_send_recv(send_dkv, recv_dkv)
+
+ if i == cp_size - 1:
+ cur_kv = kv_list[0]
+ ka, va = cur_kv[0][0], cur_kv[1][0]
+ kb, vb = cur_kv[0][1], cur_kv[1][1]
+ attn_grad_outs_a = flash_attention_backward(
+ (qa, ka, va, n),
+ dout, softmax_max_a, softmax_sum_a, attn_out_a,
+ attn_mask, softmax_scale, keep_prob,
+ rng_states_qa_kva[0][0], rng_states_qa_kva[0][1], rng_states_qa_kva[0][2]
+ )
+ attn_grad_outs_b = flash_attention_backward(
+ (qb, kb, vb, n),
+ dout_list[0], softmax_max_b, softmax_sum_b, attn_out_b,
+ attn_mask, softmax_scale, keep_prob,
+ rng_states_qb_kvb[0][0], rng_states_qb_kvb[0][1], rng_states_qb_kvb[0][2]
+ )
+ cur_dq_a, cur_dk_a, cur_dv_a = attn_grad_outs_a[0], attn_grad_outs_a[1], attn_grad_outs_a[2]
+ cur_dq_b, cur_dk_b, cur_dv_b = attn_grad_outs_b[0], attn_grad_outs_b[1], attn_grad_outs_b[2]
+ elif i < rank:
+ cur_kv = kv_list[i + 1]
+ ka, va = cur_kv[0][0], cur_kv[1][0]
+ attn_grad_outs_a = flash_attention_backward(
+ (qa, ka, va, n),
+ dout, softmax_max_a, softmax_sum_a, attn_out_a,
+ None, softmax_scale, keep_prob,
+ rng_states_qa_kva[i + 1][0], rng_states_qa_kva[i + 1][1], rng_states_qa_kva[i + 1][2]
+ )
+ cur_dq_a, cur_dk_a, cur_dv_a = attn_grad_outs_a[0], attn_grad_outs_a[1], attn_grad_outs_a[2]
+ else:
+ cur_kv = kv_list[i + 1]
+ kb, vb = cur_kv[0][1], cur_kv[1][1]
+ attn_grad_outs_b = flash_attention_backward(
+ (qb, kb, vb, n),
+ dout_list[0], softmax_max_b, softmax_sum_b, attn_out_b,
+ None, softmax_scale, keep_prob,
+ rng_states_qb_kvb[i + 1][0], rng_states_qb_kvb[i + 1][1], rng_states_qb_kvb[i + 1][2]
+ )
+ cur_dq_b, cur_dk_b, cur_dv_b = attn_grad_outs_b[0], attn_grad_outs_b[1], attn_grad_outs_b[2]
+
+ if i == 0:
+ if rank == 0:
+ dq_b, dk_b, dv_b = cur_dq_b, cur_dk_b, cur_dv_b
+ else:
+ dq_a, dk_a, dv_a = cur_dq_a, cur_dk_a, cur_dv_a
+ else:
+ # wait until dKV is received from recv_src
+ send_recv_comm.wait()
+
+ if i < cp_size - 1:
+ if rank == 0:
+ dkv_a = recv_dkv[0]
+ dk_a, dv_a = dkv_a[0], dkv_a[1]
+
+ dq_b.add_(cur_dq_b)
+ dk_b, dv_b = cur_dk_b, cur_dv_b
+ elif i <= rank:
+ if i == rank:
+ dkv_b = recv_dkv[1]
+ dk_b, dv_b = dkv_b[0], dkv_b[1]
+
+ dq_b.add_(cur_dq_b)
+ dk_b.add_(cur_dk_b)
+ dv_b.add_(cur_dv_b)
+ else:
+ dkv_a = recv_dkv[0]
+ dk_a, dv_a = dkv_a[0], dkv_a[1]
+
+ dq_a.add_(cur_dq_a)
+ dk_a.add_(cur_dk_a)
+ dv_a.add_(cur_dv_a)
+ else:
+ dkv = recv_dkv
+ dkv_a, dkv_b = dkv[0], dkv[1]
+ dk_a, dv_a = dkv_a[0], dkv_a[1]
+ dk_b, dv_b = dkv_b[0], dkv_b[1]
+
+ dq_b.add_(cur_dq_b)
+ dk_b.add_(cur_dk_b)
+ dv_b.add_(cur_dv_b)
+ else:
+ prev_dq_b, prev_dk_a, prev_dv_a = dkv_list
+ if rank == 0:
+ dkv_a = recv_dkv[0]
+ dk_a, dv_a = dkv_a[0], dkv_a[1]
+
+ dq_a = cur_dq_a
+ dk_a.add_(cur_dk_a)
+ dv_a.add_(cur_dv_a)
+ dk_b, dv_b = cur_dk_b, cur_dv_b
+ elif rank == cp_size - 1:
+ dkv_b = recv_dkv[1]
+ dk_b, dv_b = dkv_b[0], dkv_b[1]
+
+ dq_a.add_(cur_dq_a)
+ dk_a, dv_a = cur_dk_a, cur_dv_a
+ dk_b.add_(cur_dk_b)
+ dv_b.add_(cur_dv_b)
+ else:
+ dkv = recv_dkv
+ dkv_a, dkv_b = dkv[0], dkv[1]
+ dk_a, dv_a = dkv_a[0], dkv_a[1]
+ dk_b, dv_b = dkv_b[0], dkv_b[1]
+
+ dq_a.add_(cur_dq_a)
+ dk_a.add_(cur_dk_a)
+ dv_a.add_(cur_dv_a)
+ dk_b.add_(cur_dk_b)
+ dv_b.add_(cur_dv_b)
+
+ dk_a.add_(prev_dk_a)
+ dv_a.add_(prev_dv_a)
+ dq_b.add_(cur_dq_b)
+ dq_b.add_(prev_dq_b)
+
+ dq = torch.cat((dq_a.unsqueeze(0), dq_b.unsqueeze(0)), dim=0)
+ dk = torch.cat((dk_a.unsqueeze(0), dk_b.unsqueeze(0)), dim=0)
+ dv = torch.cat((dv_a.unsqueeze(0), dv_b.unsqueeze(0)), dim=0)
+ dq, dk, dv = [x.view(-1, *x.shape[2:]) for x in [dq, dk, dv]]
+
+ return dq, dk, dv
diff --git a/model/train/yoco_moe/mindspeed/core/context_parallel/ulysses_context_parallel.py b/model/train/yoco_moe/mindspeed/core/context_parallel/ulysses_context_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c3266ea00b81cc16ee3ac332c6fa7c30631d7b3
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/context_parallel/ulysses_context_parallel.py
@@ -0,0 +1,745 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+from abc import ABC, abstractmethod
+from typing import Any, Tuple, Optional
+import torch
+import torch_npu
+from torch import Tensor
+from torch.nn import Module
+from megatron.training import get_args
+import mindspeed
+from mindspeed.core.tensor_parallel.comm_group_api import CollectiveCommIntf
+from mindspeed.core.tensor_parallel.comm_utils import sync_gather_along_first_dim, sync_reduce_scatter_along_first_dim
+from mindspeed.core.context_parallel.unaligned_cp.mapping import all_to_all, cal_split_sizes
+
+
+try:
+ from einops import rearrange
+except ImportError:
+ rearrange = None
+
+
+class GatherSizeCalculator(ABC):
+ """Abstract base class defining an interface for calculating the gather size in distributed operations.
+
+ The gather size usually refers to the size of the output tensor in the `gather_idx` dimension after all-to-all
+ communication (in the Ulysses mechanism).
+ """
+
+ @abstractmethod
+ def calculate(self) -> Optional[int]:
+ """Calculates the gather size based on current context such as batch size or sequence length.
+
+ Returns:
+ Optional[int]: The calculated gather size if applicable, otherwise None.
+ """
+ pass
+
+
+class DefaultGatherSizeCalculator(GatherSizeCalculator):
+ """Default implementation where the gather size is always None. If gather_size is None, it
+ will be calculated as the product of the original size of the `gather_idx` of the input tensor and the
+ `world_size`."""
+ def calculate(self, *args, **kwargs) -> Optional[int]:
+ return None
+
+
+class DynamicGatherSizeCalculator(GatherSizeCalculator):
+ """Dynamic implementation that calculates gather size based on the current batch attention mask sequence length."""
+
+ def calculate(self, *args: Any, **kwargs: Any) -> Optional[int]:
+ """Calculates the gather size based on the attention mask sequence length.
+ """
+ # Check if the first argument is a tensor; general masks (which type is list) do not support dynamic gather size
+ if not isinstance(args[0], torch.Tensor):
+ return None
+
+ atten_mask_seq_len = args[0].shape[-1]
+ return atten_mask_seq_len
+
+
+class UlyssesCollectiveComm(CollectiveCommIntf):
+ group = None
+
+ def __init__(self, group, name="ulysses"):
+ super().__init__(name)
+ UlyssesCollectiveComm.group = group
+
+ @classmethod
+ def get_comm_rank(cls):
+ return torch.distributed.get_rank(group=cls.group)
+
+ @classmethod
+ def get_comm_group_world_size(cls):
+ return torch.distributed.get_world_size(group=cls.group)
+
+ @classmethod
+ def get_comm_group(cls):
+ return cls.group
+
+
+def single_all_to_all(input_, scatter_idx, gather_idx, group):
+ seq_world_size = torch.distributed.get_world_size(group)
+ inp_shape = list(input_.shape)
+ inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
+ if scatter_idx < 2:
+ input_t = input_.reshape(
+ [seq_world_size, inp_shape[scatter_idx]] + \
+ inp_shape[scatter_idx + 1:]
+ ).contiguous()
+ else:
+ # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
+ input_t = input_.reshape(
+ [-1, seq_world_size, inp_shape[scatter_idx]] + \
+ inp_shape[scatter_idx + 1:]
+ ).transpose(0, 1).contiguous()
+
+ output = torch.empty_like(input_t)
+ torch.distributed.all_to_all_single(output, input_t, group=group)
+
+ # if scattering the seq-dim, transpose the heads back to the original dimension
+ # e.g., [cp, s/cp, b, n/cp, d] -> [s/cp, b, cp, n/cp, d]
+ if scatter_idx < 2:
+ output = output.transpose(0, 1).transpose(1, 2).contiguous()
+
+ return output.reshape(
+ inp_shape[: gather_idx] + [inp_shape[gather_idx] * seq_world_size, ] + inp_shape[gather_idx + 1:]).contiguous()
+
+
+class _SeqAllToAll(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, group: torch.distributed.ProcessGroup, input_: Tensor, scatter_idx: int,
+ gather_idx: int) -> Tensor:
+ ctx.group = group
+ ctx.scatter_idx = scatter_idx
+ ctx.gather_idx = gather_idx
+
+ return single_all_to_all(input_, scatter_idx, gather_idx, group)
+
+ @staticmethod
+ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
+ return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
+
+
+class UlyssesContextAttention(torch.nn.Module):
+ """Implementation of Ulysses Context Attention mechanism.
+ """
+
+ def __init__(
+ self,
+ local_attention: Module,
+ sequence_process_group: torch.distributed.ProcessGroup,
+ scatter_idx: int = 2,
+ gather_idx: int = 0,
+ gather_size_calculator: GatherSizeCalculator = DefaultGatherSizeCalculator(), # Injected dependency
+ ) -> None:
+ """Initialization
+
+ Args:
+ local_attention (Module): An instance of a local attention mechanism
+ sequence_process_group (ProcessGroup): A PyTorch ProcessGroup object representing the process group for context parallelism.
+ scatter_idx (int): Index specifying along which dimension the data should be scattered during all-to-all communication.
+ gather_idx (int): Index specifying along which dimension the data should be gathered during all-to-all communication.
+ gather_size_calculator (GatherSizeCalculator): A callable object responsible for calculating the gather_size,
+ which is the total size of the all-to-all output tensor along the `gather_idx`.
+ Defaults to DefaultGatherSizeCalculator().
+ """
+ super(UlyssesContextAttention, self).__init__()
+ self.local_attn = local_attention
+ self.local_attn.ulysses_comm_para = {
+ 'spg': sequence_process_group,
+ 'scatter_idx': scatter_idx,
+ 'gather_idx': gather_idx,
+ 'gather_size_calculator': gather_size_calculator
+ }
+
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs: Any) -> Tensor:
+ """ forward
+
+ Arguments:
+ query (Tensor): query input to the layer
+ key (Tensor): key input to the layer
+ value (Tensor): value input to the layer
+ args: other args
+
+ Returns:
+ * output (Tensor): context output
+ """
+ global_args = get_args()
+ use_custom_ulysses_backward = (
+ global_args.context_parallel_size > 1 and
+ global_args.context_parallel_algo == "ulysses_cp_algo" and
+ not global_args.use_legacy_models and
+ global_args.context_parallel_kv_cache_policy
+ )
+ if use_custom_ulysses_backward:
+ output = self.local_attn(query, key, value, *args, **kwargs)
+ else:
+ spg = self.local_attn.ulysses_comm_para.get('spg')
+ scatter_idx = self.local_attn.ulysses_comm_para.get('scatter_idx')
+ gather_idx = self.local_attn.ulysses_comm_para.get('gather_idx')
+ seq_world_size = torch.distributed.get_world_size(spg)
+
+ # Handle cases where the sequence length of keys/values needs to be adjusted to match queries.
+ if seq_world_size > key.shape[scatter_idx] and query.shape[scatter_idx] % key.shape[scatter_idx] == 0:
+ key = key.repeat_interleave(query.shape[scatter_idx] // key.shape[scatter_idx], dim=scatter_idx)
+ value = value.repeat_interleave(query.shape[scatter_idx] // value.shape[scatter_idx], dim=scatter_idx)
+
+ # Calculate the gather size using the injected gather size calculator
+ gather_size = self.local_attn.ulysses_comm_para.get('gather_size_calculator').calculate(*args, **kwargs)
+
+ # The gather size usually refers to the size of the output tensor in the `gather_idx` dimension after
+ # the all-to-all communication
+ # in shape : e.g., [s/p:h:]
+ query_layer = all_to_all(query, spg, scatter_idx, gather_idx, gather_size)
+ key_layer = all_to_all(key, spg, scatter_idx, gather_idx, gather_size)
+ value_layer = all_to_all(value, spg, scatter_idx, gather_idx, gather_size)
+
+ # out shape : e.g., [s:h/p:]
+ context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
+
+ # Reshape the context layer if necessary to align dimensions properly
+ if gather_size:
+ context_shape = context_layer.shape
+ scatter_sizes_query = cal_split_sizes(query.shape[scatter_idx], seq_world_size)
+
+ # To reshape the context_layer tensor to ensure context_layer.size(gather_idx) and context_layer.size(scatter_idx)
+ # has the correct value.
+ context_layer = context_layer.reshape(context_shape[0], context_shape[1],
+ scatter_sizes_query[torch.distributed.get_rank(spg)], -1).contiguous()
+
+ output = all_to_all(context_layer, spg, gather_idx, scatter_idx, query.size(scatter_idx))
+
+ # Final reshape to maintain correct dimensions after all-to-all communication
+ if gather_size:
+ output = output.reshape(output.shape[0], output.shape[1], -1).contiguous()
+
+ # out e.g., [s/p::h]
+ return output
+
+
+class AttnQKVReshape:
+ """Ulysses Attention Reshape QKV Implementation"""
+
+ def __init__(self, attn_para):
+ self.attn_para = attn_para
+
+ def reshape_forward(self, query, key, value):
+ """
+ Implements of qkv reshape in forward of ulysses attention
+
+ Args:
+ query (Tensor): query input to the attention layer with shape [s, b, h, d]
+ key (Tensor): key input to the attention layer with shape [s, b, h, d]
+ value (Tensor): value input to the attention layer with shape [s, b, h, d]
+
+ Returns:
+ query (Tensor): query input to the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ key (Tensor): key input to the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ value (Tensor): value input to the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ attn_para (Dict): the parameters used in attention computation
+ """
+ # q, k, v: [s, b, h, d]
+
+ # attention parameters
+ packed_seq_params = self.attn_para.get('packed_seq_params')
+ seq_length, bsz, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3]
+
+ self.attn_para['n_head'] = n_head
+ self.attn_para['q_seq_len'] = seq_length
+ self.attn_para['k_head'] = key.shape[2]
+ self.attn_para['v_head'] = value.shape[2]
+ self.attn_para['k_seq_len'] = key.shape[0]
+ self.attn_para['v_seq_len'] = value.shape[0]
+
+ # reshape [s, b, h, d] to SBH([s, b, h*d]) or TND([s*b, h, d])
+ if packed_seq_params is not None: # TND
+ actual_seq_qlen = packed_seq_params.cu_seqlens_q.tolist()
+ actual_seq_kvlen = packed_seq_params.cu_seqlens_kv.tolist()
+ query, key, value = [rearrange(x, 's b h d -> (b s) h d') for x in [query, key, value]]
+ shape_order = 'TND'
+ else: # SBH
+ actual_seq_qlen = None
+ actual_seq_kvlen = None
+ query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]]
+ shape_order = 'SBH'
+
+ self.attn_para['shape_order'] = shape_order
+ self.attn_para['actual_seq_qlen'] = actual_seq_qlen
+ self.attn_para['actual_seq_kvlen'] = actual_seq_kvlen
+
+ return query, key, value, self.attn_para
+
+ def reshape_backward(self, dq, dk, dv):
+ """
+ Implements of qkv reshape in backward of ulysses attention
+
+ Args:
+ dq (Tensor): query grad output of the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ dk (Tensor): key grad output of the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ dv (Tensor): value grad output of the attention layer with shape [s, b, h*d] or [s*b, h, d]
+
+ Returns:
+ dq (Tensor): query grad output of the attention layer with shape [s, b, h, d]
+ dk (Tensor): key grad output of the attention layer with shape [s, b, h, d]
+ dv (Tensor): value grad output of the attention layer with shape [s, b, h, d]
+ """
+ # dq, dk, dv: [s, b, h*d] or [s*b, h, d]
+
+ # attention parameters
+ packed_seq_params = self.attn_para.get('packed_seq_params')
+ q_seq_len = self.attn_para.get('q_seq_len')
+ k_seq_len = self.attn_para.get('k_seq_len')
+ v_seq_len = self.attn_para.get('v_seq_len')
+ n_head = self.attn_para.get('n_head')
+ k_head = self.attn_para.get('k_head')
+ v_head = self.attn_para.get('v_head')
+
+ # reshape SBH([s, b, h*d]) or TND([s*b, h, d]) back to [s, b, h, d]
+ if packed_seq_params is not None: # TND
+ s, b = q_seq_len, dq.shape[0] // q_seq_len
+ dq = rearrange(dq, '(b s) h d -> s b h d', s=s, b=b)
+ s, b = k_seq_len, dk.shape[0] // k_seq_len
+ dk = rearrange(dk, '(b s) h d -> s b h d', s=s, b=b)
+ s, b = v_seq_len, dv.shape[0] // v_seq_len
+ dv = rearrange(dv, '(b s) h d -> s b h d', s=s, b=b)
+ else: # SBH
+ h, d = n_head, dq.shape[2] // n_head
+ dq = rearrange(dq, 's b (h d) -> s b h d', h=h, d=d)
+ h, d = k_head, dk.shape[2] // k_head
+ dk = rearrange(dk, 's b (h d) -> s b h d', h=h, d=d)
+ h, d = v_head, dv.shape[2] // v_head
+ dv = rearrange(dv, 's b (h d) -> s b h d', h=h, d=d)
+
+ return dq, dk, dv
+
+
+class RepeatAll2AllComm:
+ """Ulysses Attention Repeat All2All Communication Implementation"""
+
+ def __init__(self, ulysses_comm_para, attn_para):
+ self.ulysses_comm_para = ulysses_comm_para
+ self.attn_para = attn_para
+ self.qkv_reshape = AttnQKVReshape(attn_para)
+
+ def comm_forward(self, query, key, value):
+ """
+ Implements of Repeat-All2All communication in forward of ulysses attention
+
+ Args:
+ query (Tensor): query input to the attention layer with shape [s, b, h, d]
+ key (Tensor): key input to the attention layer with shape [s, b, h, d]
+ value (Tensor): value input to the attention layer with shape [s, b, h, d]
+
+ Returns:
+ query (Tensor): query input to the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ key (Tensor): key input to the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ value (Tensor): value input to the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ attn_para (Dict): the parameters used in attention computation
+ """
+ # q, k, v: [s, b, h, d]
+
+ # communication parameters
+ spg = self.ulysses_comm_para.get('spg')
+ scatter_idx = self.ulysses_comm_para.get('scatter_idx')
+ gather_idx = self.ulysses_comm_para.get('gather_idx')
+ cache_policy = self.ulysses_comm_para.get('cache_policy')
+
+ # repeat parameters
+ seq_world_size = torch.distributed.get_world_size(spg)
+ do_repeat = seq_world_size > key.shape[scatter_idx] and query.shape[scatter_idx] % key.shape[scatter_idx] == 0
+ self.ulysses_comm_para['do_repeat'] = do_repeat
+ self.ulysses_comm_para['repeat_num'] = query.shape[scatter_idx] // key.shape[scatter_idx]
+
+ # if forward repeat, [s, b, h, d] -> [s, b, h*cp, d]
+ if do_repeat:
+ key = key.repeat_interleave(query.shape[scatter_idx] // key.shape[scatter_idx], dim=scatter_idx)
+ value = value.repeat_interleave(query.shape[scatter_idx] // value.shape[scatter_idx], dim=scatter_idx)
+ elif cache_policy is not None:
+ raise AssertionError(
+ 'KV Cache dose not suggest to use when key and value do not repeat'
+ )
+
+ # all2all communication forward, [s, b, h, d] -> [s*cp, b, h//cp, d]
+ query = single_all_to_all(query, scatter_idx, gather_idx, spg)
+ key = single_all_to_all(key, scatter_idx, gather_idx, spg)
+ value = single_all_to_all(value, scatter_idx, gather_idx, spg)
+
+ # reshape [s, b, h, d] to SBH([s, b, h*d]) or TND([s*b, h, d])
+ query, key, value, self.attn_para = self.qkv_reshape.reshape_forward(query, key, value)
+
+ return query, key, value, self.attn_para
+
+ def comm_backward(self, dq, dk, dv):
+ """
+ Implements of Repeat-All2All communication in backward of ulysses attention
+
+ Args:
+ dq (Tensor): query grad output of the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ dk (Tensor): key grad output of the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ dv (Tensor): value grad output of the attention layer with shape [s, b, h*d] or [s*b, h, d]
+
+ Returns:
+ dq (Tensor): query grad output of the attention layer with shape [s, b, h, d]
+ dk (Tensor): key grad output of the attention layer with shape [s, b, h, d]
+ dv (Tensor): value grad output of the attention layer with shape [s, b, h, d]
+ """
+ # dq, dk, dv: SBH([s, b, h*d]) or TND([s*b, h, d])
+
+ # reshape SBH([s, b, h*d]) or TND([s*b, h, d]) back to [s, b, h, d]
+ dq, dk, dv = self.qkv_reshape.reshape_backward(dq, dk, dv)
+
+ # communication parameters
+ spg = self.ulysses_comm_para.get('spg')
+ scatter_idx = self.ulysses_comm_para.get('scatter_idx')
+ gather_idx = self.ulysses_comm_para.get('gather_idx')
+ do_repeat = self.ulysses_comm_para.get('do_repeat')
+ repeat_num = self.ulysses_comm_para.get('repeat_num')
+
+ # all2all communication backward, [s, b, h, d] -> [s//cp, b, h*cp, d]
+ dq = single_all_to_all(dq, gather_idx, scatter_idx, spg)
+ dk = single_all_to_all(dk, gather_idx, scatter_idx, spg)
+ dv = single_all_to_all(dv, gather_idx, scatter_idx, spg)
+
+ # if backward repeat, [s, b, h, d] -> [s, b, h//cp, d]
+ if do_repeat:
+ dk = dk.view(
+ *dk.shape[:scatter_idx], dk.shape[scatter_idx] // repeat_num, repeat_num, *dk.shape[scatter_idx + 1:]
+ ).sum(dim=scatter_idx + 1)
+ dv = dv.view(
+ *dv.shape[:scatter_idx], dv.shape[scatter_idx] // repeat_num, repeat_num, *dv.shape[scatter_idx + 1:]
+ ).sum(dim=scatter_idx + 1)
+
+ return dq, dk, dv
+
+ def recomm_backward(self, input_tensor):
+ """
+ Implements of Repeat-All2All re-communication in backward of ulysses attention
+
+ Args:
+ input_tensor (Tensor): key or value input of the attention layer with shape [s, b, h, d]
+
+ Returns:
+ output (Tensor): key or value input of the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ """
+ # k, v: [s, b, h, d]
+
+ # communication parameters
+ spg = self.ulysses_comm_para.get('spg')
+ scatter_idx = self.ulysses_comm_para.get('scatter_idx')
+ gather_idx = self.ulysses_comm_para.get('gather_idx')
+ do_repeat = self.ulysses_comm_para.get('do_repeat')
+ repeat_num = self.ulysses_comm_para.get('repeat_num')
+
+ # attention parameters
+ packed_seq_params = self.attn_para.get('packed_seq_params')
+
+ # if repeat, [s, b, h, d] -> [s, b, h*cp, d]
+ if do_repeat:
+ input_tensor = input_tensor.repeat_interleave(repeat_num, dim=scatter_idx)
+
+ # all2all re-communication, [s, b, h, d] -> [s*cp, b, h//cp, d]
+ output = single_all_to_all(input_tensor, scatter_idx, gather_idx, spg)
+
+ # reshape [s, b, h, d] to SBH([s, b, h*d]) or TND([s*b, h, d])
+ if packed_seq_params is not None:
+ output = rearrange(output, 's b h d -> (b s) h d')
+ else: # SBH
+ output = rearrange(output, 's b h d -> s b (h d)')
+
+ return output
+
+
+class AllGatherComm:
+ """Ulysses Attention AllGather KV + All2All Q Communication Implementation"""
+
+ def __init__(self, ulysses_comm_para, attn_para):
+ self.ulysses_comm_para = ulysses_comm_para
+ self.attn_para = attn_para
+ self.qkv_reshape = AttnQKVReshape(attn_para)
+ spg = self.ulysses_comm_para.get('spg')
+ self.ulysses_collective_comm = UlyssesCollectiveComm(spg)
+
+ def comm_forward(self, query, key, value):
+ """
+ Implements of AllGather KV + All2All Q communication in forward of ulysses attention
+
+ Args:
+ query (Tensor): query input to the attention layer with shape [s, b, h, d]
+ key (Tensor): key input to the attention layer with shape [s, b, h, d]
+ value (Tensor): value input to the attention layer with shape [s, b, h, d]
+
+ Returns:
+ query (Tensor): query input to the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ key (Tensor): key input to the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ value (Tensor): value input to the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ attn_para (Dict): the parameters used in attention computation
+ """
+ # q, k, v: [s, b, h, d]
+
+ # communication parameters
+ spg = self.ulysses_comm_para.get('spg')
+ scatter_idx = self.ulysses_comm_para.get('scatter_idx')
+ gather_idx = self.ulysses_comm_para.get('gather_idx')
+
+ # query all2all communication forward, [s, b, h, d] -> [s*cp, b, h//cp, d]
+ query = single_all_to_all(query, scatter_idx, gather_idx, spg)
+
+ # key and value allgather communication forward, [s, b, h, d] -> [s*cp, b, h, d]
+ key = sync_gather_along_first_dim(key, self.ulysses_collective_comm)
+ value = sync_gather_along_first_dim(value, self.ulysses_collective_comm)
+
+ # reshape [s, b, h, d] to SBH([s, b, h*d]) or TND([s*b, h, d])
+ query, key, value, self.attn_para = self.qkv_reshape.reshape_forward(query, key, value)
+
+ return query, key, value, self.attn_para
+
+ def comm_backward(self, dq, dk, dv):
+ """
+ Implements of AllGather KV + All2All Q communication in backward of ulysses attention
+
+ Args:
+ dq (Tensor): query grad output of the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ dk (Tensor): key grad output of the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ dv (Tensor): value grad output of the attention layer with shape [s, b, h*d] or [s*b, h, d]
+
+ Returns:
+ dq (Tensor): query grad output of the attention layer with shape [s, b, h, d]
+ dk (Tensor): key grad output of the attention layer with shape [s, b, h, d]
+ dv (Tensor): value grad output of the attention layer with shape [s, b, h, d]
+ """
+ # dq, dk, dv: SBH([s, b, h*d]) or TND([s*b, h, d])
+
+ # reshape SBH([s, b, h*d]) or TND([s*b, h, d]) back to [s, b, h, d]
+ dq, dk, dv = self.qkv_reshape.reshape_backward(dq, dk, dv)
+
+ # communication parameters
+ spg = self.ulysses_comm_para.get('spg')
+ scatter_idx = self.ulysses_comm_para.get('scatter_idx')
+ gather_idx = self.ulysses_comm_para.get('gather_idx')
+
+ # query all2all communication backward, [s, b, h, d] -> [s//cp, b, h*cp, d]
+ dq = single_all_to_all(dq, gather_idx, scatter_idx, spg)
+
+ # key and value allgather communication backward, [s, b, h, d] -> [s//cp, b, h, d]
+ dk = sync_reduce_scatter_along_first_dim(dk, self.ulysses_collective_comm)
+ dv = sync_reduce_scatter_along_first_dim(dv, self.ulysses_collective_comm)
+
+ return dq, dk, dv
+
+ def recomm_backward(self, input_tensor):
+ """
+ Implements of AllGather KV + All2All Q re-communication in backward of ulysses attention
+
+ Args:
+ input_tensor (Tensor): key or value input of the attention layer with shape [s, b, h, d]
+
+ Returns:
+ output (Tensor): key or value input of the attention layer with shape [s, b, h*d] or [s*b, h, d]
+ """
+ # k, v: [s, b, h, d]
+
+ # attention parameters
+ packed_seq_params = self.attn_para.get('packed_seq_params')
+
+ # allgather re-communication, [s, b, h, d] -> [s*cp, b, h, d]
+ output = sync_gather_along_first_dim(input_tensor, self.ulysses_collective_comm)
+
+ # reshape [s, b, h, d] to SBH([s, b, h*d]) or TND([s*b, h, d])
+ if packed_seq_params is not None: # TND
+ output = rearrange(output, 's b h d -> (b s) h d')
+ else: # SBH
+ output = rearrange(output, 's b h d -> s b (h d)')
+
+ return output
+
+
+class UlyssesAttnWithKVCache(torch.autograd.Function):
+ """Ulysses Attention With KV Cache Implementation"""
+
+ @staticmethod
+ def forward(ctx, query, key, value, attn_para, ulysses_comm_para) -> Tensor:
+ """
+ Implements of Ulysses Attention With KV Cache forward
+
+ Args:
+ query (Tensor): query input to the attention layer with shape [s, b, h, d]
+ key (Tensor): key input to the attention layer with shape [s, b, h, d]
+ value (Tensor): value input to the attention layer with shape [s, b, h, d]
+
+ Returns:
+ output (Tensor): ulysses attention output with shape [s, b, h*d] or [s*b, h, d]
+ """
+ # q, k, v: [s, b, h, d]
+
+ # communication parameters
+ spg = ulysses_comm_para.get('spg')
+ scatter_idx = ulysses_comm_para.get('scatter_idx')
+ gather_idx = ulysses_comm_para.get('gather_idx')
+ cache_policy = ulysses_comm_para.get('cache_policy')
+ use_ulysses_allgather_kv = ulysses_comm_para.get('use_ulysses_allgather_kv')
+
+ # repeat-all2all or allgather kv + all2all q
+ if use_ulysses_allgather_kv:
+ if key.shape[2] != 1:
+ raise AssertionError(
+ 'When either the head number of key or value is not equal to 1, '
+ 'use all2all communication to get better performance.'
+ )
+ # allgather kv + all2all q communication forward
+ ulysses_comm = AllGatherComm(ulysses_comm_para, attn_para)
+ else:
+ # repeat-all2all communication forward
+ ulysses_comm = RepeatAll2AllComm(ulysses_comm_para, attn_para)
+
+ # communication forward
+ q, k, v = query.clone(), key.clone(), value.clone()
+ q, k, v, attn_para = ulysses_comm.comm_forward(q, k, v)
+
+ # attention parameters
+ packed_seq_params = attn_para.get('packed_seq_params')
+ attention_mask = attn_para.get('attention_mask')
+ scale = attn_para.get('scale')
+ pre_tokens = attn_para.get('pre_tokens')
+ next_tokens = attn_para.get('next_tokens')
+ keep_prob = attn_para.get('keep_prob')
+ sparse_mode = attn_para.get('sparse_mode')
+ n_head = attn_para.get('n_head')
+ shape_order = attn_para.get('shape_order')
+ actual_seq_len = attn_para.get('actual_seq_qlen')
+ actual_seq_kvlen = attn_para.get('actual_seq_kvlen')
+ seq_length = attn_para.get('q_seq_len')
+
+ # kv cache
+ if cache_policy == "full":
+ k_cache, v_cache = key.clone(), value.clone()
+ elif cache_policy == "half":
+ k_cache, v_cache = key.clone(), v.clone()
+ else:
+ k_cache, v_cache = k.clone(), v.clone()
+
+ # attention forward
+ res = torch_npu.npu_fusion_attention(
+ q, k, v, n_head, shape_order,
+ pse=None,
+ padding_mask=None,
+ atten_mask=attention_mask,
+ scale=scale,
+ pre_tockens=pre_tokens,
+ next_tockens=next_tokens,
+ keep_prob=keep_prob,
+ inner_precise=0,
+ sparse_mode=sparse_mode,
+ actual_seq_qlen=actual_seq_len,
+ actual_seq_kvlen=actual_seq_kvlen
+ )
+
+ attn_out, softmax_max, softmax_sum = res[0], res[1], res[2]
+
+ # if TND, reshape TND([b*s, h, d]) to SBH([s, b, h*d])
+ if packed_seq_params is not None:
+ s, b = seq_length, attn_out.shape[0] // seq_length
+ attn_out = rearrange(attn_out, '(b s) h d -> s b (h d)', s=s, b=b)
+
+ # output all2all communication forward
+ output = single_all_to_all(attn_out, gather_idx, scatter_idx, spg)
+
+ ctx.save_for_backward(q, k_cache, v_cache, attn_out, softmax_max, softmax_sum, attention_mask)
+ ctx.ulysses_comm = ulysses_comm
+ ctx.ulysses_comm_para = ulysses_comm_para
+ ctx.attn_para = attn_para
+
+ return output
+
+ @staticmethod
+ def backward(ctx, dout):
+ """
+ Implements of Ulysses Attention With KV Cache backward
+
+ Args:
+ dout (Tensor): the attention layer output grad with shape [s, b, h*d] or [s*b, h, d]
+
+ Returns:
+ dq (Tensor): query grad output of the attention layer with shape [s, b, h, d]
+ dk (Tensor): key grad output of the attention layer with shape [s, b, h, d]
+ dv (Tensor): value grad output of the attention layer with shape [s, b, h, d]
+ """
+ # input, attention output grad: [s, b, h*d] or [s*b, h, d]
+
+ # get forward parameters
+ query, k_cache, v_cache, attn_out, softmax_max, softmax_sum, attention_mask = ctx.saved_tensors
+ ulysses_comm = ctx.ulysses_comm
+ ulysses_comm_para = ctx.ulysses_comm_para
+ attn_para = ctx.attn_para
+
+ # communication parameters
+ spg = ulysses_comm_para.get('spg')
+ scatter_idx = ulysses_comm_para.get('scatter_idx')
+ gather_idx = ulysses_comm_para.get('gather_idx')
+ cache_policy = ulysses_comm_para.get('cache_policy')
+
+ # attention parameters
+ packed_seq_params = attn_para.get('packed_seq_params')
+ attention_mask = attn_para.get('attention_mask')
+ scale = attn_para.get('scale')
+ pre_tokens = attn_para.get('pre_tokens')
+ next_tokens = attn_para.get('next_tokens')
+ keep_prob = attn_para.get('keep_prob')
+ sparse_mode = attn_para.get('sparse_mode')
+ n_head = attn_para.get('n_head')
+ shape_order = attn_para.get('shape_order')
+ actual_seq_len = attn_para.get('actual_seq_qlen')
+ actual_seq_kvlen = attn_para.get('actual_seq_kvlen')
+
+ # output all2all communication backward
+ dout = single_all_to_all(dout, scatter_idx, gather_idx, spg)
+
+ # if TND, reshape SBH([s, b, h*d]) to TND([b*s, h, d])
+ if packed_seq_params is not None:
+ h, d = n_head, dout.shape[2] // n_head
+ dout = rearrange(dout, 's b (h d) -> (b s) h d', h=h, d=d)
+ attn_out = rearrange(attn_out, 's b (h d) -> (b s) h d', h=h, d=d)
+
+ # kv cache re-communication
+ if cache_policy == "full":
+ key = ulysses_comm.recomm_backward(k_cache)
+ value = ulysses_comm.recomm_backward(v_cache)
+ elif cache_policy == "half":
+ key = ulysses_comm.recomm_backward(k_cache)
+ value = v_cache
+ else:
+ key = k_cache
+ value = v_cache
+
+ # attention backward
+ attn_grad_outs = torch_npu.npu_fusion_attention_grad(
+ query, key, value, dout, n_head,
+ shape_order,
+ pse=None,
+ padding_mask=None,
+ atten_mask=attention_mask,
+ softmax_max=softmax_max,
+ softmax_sum=softmax_sum,
+ attention_in=attn_out,
+ scale_value=scale,
+ pre_tockens=pre_tokens,
+ next_tockens=next_tokens,
+ sparse_mode=sparse_mode,
+ keep_prob=keep_prob,
+ actual_seq_qlen=actual_seq_len,
+ actual_seq_kvlen=actual_seq_kvlen
+ )
+
+ dq, dk, dv = attn_grad_outs[0], attn_grad_outs[1], attn_grad_outs[2]
+
+ dq, dk, dv = ulysses_comm.comm_backward(dq, dk, dv)
+
+ return dq, dk, dv, None, None
+
+
+def ulyssesattn_context_parallel(query, key, value, attn_para, ulysses_comm_para):
+ out = UlyssesAttnWithKVCache.apply(query, key, value, attn_para, ulysses_comm_para)
+ return out
diff --git a/model/train/yoco_moe/mindspeed/core/context_parallel/unaligned_cp/__init__.py b/model/train/yoco_moe/mindspeed/core/context_parallel/unaligned_cp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/context_parallel/unaligned_cp/mapping.py b/model/train/yoco_moe/mindspeed/core/context_parallel/unaligned_cp/mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..44f097e5844e466e74935d654f627dd4f5e93bf4
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/context_parallel/unaligned_cp/mapping.py
@@ -0,0 +1,623 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+from typing import Optional, List
+
+import torch
+import torch.distributed as dist
+
+
+PERMUTE_DIMS1 = {
+ 4: (1, 2, 3, 0),
+ 5: (1, 2, 3, 0, 4),
+}
+
+
+PERMUTE_DIMS2 = {
+ 4: (1, 2, 0, 3),
+ 5: (1, 2, 0, 3, 4),
+}
+
+
+def adjust_tensor_dimensions(tensor, scatter_idx, gather_idx):
+ """
+ Adjusts the dimensions of a tensor to move scatter_idx and gather_idx to dim 0 and dim 1 respectively.
+
+ Args:
+ tensor (torch.Tensor): The input tensor.
+ scatter_idx (int): The index of the dimension to scatter.
+ gather_idx (int): The index of the dimension to gather.
+
+ Returns:
+ tuple: A tuple containing the adjusted tensor and the list of adjusted dimensions.
+ """
+ dims = list(range(tensor.dim()))
+ assert scatter_idx != gather_idx
+ if gather_idx == 0:
+ if scatter_idx != 1:
+ dims[1], dims[gather_idx] = dims[gather_idx], dims[1]
+ dims[0], dims[scatter_idx] = dims[scatter_idx], dims[0]
+ else:
+ dims[scatter_idx], dims[gather_idx] = dims[gather_idx], dims[scatter_idx]
+
+ elif gather_idx == 1:
+ if scatter_idx != 0:
+ # If scatter_idx is not 0, move it to 0
+ dims[0], dims[scatter_idx] = dims[scatter_idx], dims[0]
+ else:
+ if scatter_idx == 0:
+ dims[1], dims[gather_idx] = dims[gather_idx], dims[1]
+ else:
+ dims[0], dims[scatter_idx] = dims[scatter_idx], dims[0]
+ dims[1], dims[gather_idx] = dims[gather_idx], dims[1]
+ return tensor.permute(dims).contiguous(), dims
+
+
+def unadjust_tensor_dimensions(tensor, adjusted_dims):
+ """
+ Reverses the dimension adjustments using the list of adjusted dimensions.
+
+ Args:
+ tensor (torch.Tensor): The tensor whose dimensions need to be restored.
+ adjusted_dims (list): The list of adjusted dimensions used during the adjustment process.
+
+ Returns:
+ torch.Tensor: The tensor with its dimensions reverted to the original order.
+ """
+ inverse_dims = [0] * len(adjusted_dims)
+
+ for new_pos, old_pos in enumerate(adjusted_dims):
+ inverse_dims[old_pos] = new_pos
+
+ # Restore the dimension order
+ unadjusted_tensor = tensor.permute(inverse_dims).contiguous()
+ return unadjusted_tensor
+
+
+def _all_to_all(
+ input_: torch.Tensor,
+ group: dist.ProcessGroup,
+ scatter_dim: int,
+ gather_dim: int,
+ gather_size: Optional[int] = None
+):
+ """
+ Helper function to perform the all-to-all operation. It scatters the input tensor along the specified scatter
+ dimension and then gathers it along the specified gather dimension. The function supports aligned and unaligned
+ data.
+ Args:
+ input_ (torch.Tensor): The input tensor to be processed.
+ group (dist.ProcessGroup): The process group perform the operation within.
+ scatter_dim (int): The index of the dimension that needs to be scattered.
+ gather_dim (int): The index of the dimension that needs to be gathered.
+ gather_size (Optional[int]): The total size of the output tensor along the `gather_dim`. If not provided, it
+ will be calculated as the product of the original size of the `gather_dim` of the input tensor and the
+ `world_size`.
+
+ Returns:
+ torch.Tensor: The resulting tensor after performing the all-to-all operation.
+
+ Note:
+ - The tensor will be split into `world_size` chunks along the `scatter_dim`. Each process will receive one
+ chunk. If the total size of the `scatter_dim` is not divisible by `world_size`, the extra elements will be
+ distributed to the first few processes, ensuring that no process receives more than one additional element
+ compared to the others.
+ - The tensor will be gathered along the `gather_dim`, with each process contributing its part to form the
+ final output tensor. The gathering process also supports unaligned data, where the remainder elements
+ are distributed to the first few processes.
+ """
+ assert 3 <= input_.dim() <= 4
+ world_size = dist.get_world_size(group)
+ if world_size == 1:
+ return input_
+
+ scatter_size = input_.size(scatter_dim)
+ if gather_size is None:
+ gather_size = input_.size(gather_dim) * world_size
+ gather_mod = gather_size % world_size
+ scatter_mod = scatter_size % world_size
+
+ if gather_mod == 0 and scatter_mod == 0:
+ # In the case of aligned data (both scatter_size and gather_size are divisible by world_size),
+ # _aligned_all_to_all function performs better than _partial_unaligned_all_to_all function
+ return _aligned_all_to_all(input_, group, scatter_dim, gather_dim)
+ elif gather_mod != 0 and scatter_mod != 0:
+ return _full_unaligned_all_to_all(input_, group, scatter_dim, gather_dim, gather_size)
+ else:
+ return _partial_unaligned_all_to_all(input_, group, scatter_dim, gather_dim, gather_size)
+
+
+def _full_unaligned_all_to_all(
+ input_: torch.Tensor,
+ group: dist.ProcessGroup,
+ scatter_dim: int,
+ gather_dim: int,
+ gather_size: Optional[int] = None
+):
+ """
+ Helper function to perform the all-to-all operation. It scatters the input tensor along the specified scatter
+ dimension and then gathers it along the specified gather dimension. This function supports unaligned scatter
+ and gather sizes.
+
+ Args:
+ input_ (torch.Tensor): The input tensor to be processed.
+ world_size (int): The number of processes in the process group.
+ group (dist.ProcessGroup): The process group to perform the operation within.
+ scatter_dim (int): The index of the dimension that needs to be scattered.
+ gather_dim (int): The index of the dimension that needs to be gathered.
+ gather_size (Optional[int]): The total size of the output tensor along the `gather_dim`. If not provided, it
+ will be calculated as the product of the original size of the `gather_dim` of the input tensor and the
+ `world_size`.
+
+ Returns:
+ torch.Tensor: The resulting tensor after performing the all-to-all operation.
+ """
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+
+ scatter_sizes = cal_split_sizes(dim_size=input_.size(scatter_dim), world_size=world_size)
+ input_list = [t.contiguous() for t in torch.split(input_, scatter_sizes, scatter_dim)]
+
+ gather_sizes = cal_split_sizes(dim_size=gather_size, world_size=world_size)
+ output_list = []
+ tensor_shape_base = input_list[rank].size()
+ for i in range(world_size):
+ tensor_shape = list(tensor_shape_base)
+ tensor_shape[gather_dim] = gather_sizes[i]
+ output_list.append(torch.empty(tensor_shape, dtype=input_.dtype, device=input_.device))
+
+ dist.all_to_all(output_list, input_list, group=group)
+
+ return torch.cat(output_list, dim=gather_dim).contiguous()
+
+
+def _aligned_all_to_all(
+ input_: torch.Tensor,
+ group: dist.ProcessGroup,
+ scatter_dim: int,
+ gather_dim: int,
+):
+ """
+ Helper function to perform the all-to-all operation. It scatters the input tensor along the specified scatter
+ dimension and then gathers it along the specified gather dimension.
+ Special note: The function only supports aligned data (both scatter_size and gather_size are divisible by
+ world_size)
+ """
+ world_size = dist.get_world_size(group)
+ inp_shape = list(input_.shape)
+ inp_shape[scatter_dim] = inp_shape[scatter_dim] // world_size
+ if scatter_dim == 0:
+ input_t = input_.reshape([world_size] + inp_shape).contiguous()
+ else:
+ input_t = input_.reshape([-1, world_size] + inp_shape[scatter_dim:]).transpose(0, 1).contiguous()
+
+ output = torch.empty_like(input_t)
+
+ dist.all_to_all_single(output, input_t, group=group)
+
+ output = output.view([world_size] + inp_shape).contiguous()
+ output_dim = output.dim()
+ if gather_dim == 1:
+ # the shape of input_t is (world_size, inp_shape[0], inp_shape[gather_dim], *inp_shape[2:])
+ output = output.transpose(0, 1).contiguous()
+ # the shape of output is (inp_shape[0], world_size, inp_shape[gather_dim], *inp_shape[2:])
+ elif gather_dim == 2:
+ # the shape of input_t is (world_size, inp_shape[0], inp_shape[1], *inp_shape[gather_dim:])
+ output = output.permute(*PERMUTE_DIMS2[output_dim]).contiguous()
+ # the shape of output is (inp_shape[0], inp_shape[1], world_size, *inp_shape[gather_dim:])
+ elif gather_dim == 3:
+ # the shape of input_t is (world_size, inp_shape[0], inp_shape[1], inp_shape[2], inp_shape[gather_dim])
+ output = output.permute(*PERMUTE_DIMS1[output_dim]).contiguous()
+ # the shape of output is (inp_shape[0], inp_shape[1], inp_shape[2], world_size, inp_shape[gather_dim])
+ # The last case: gather_dim == 0:
+ # the shape of input_t is (world_size, inp_shape[gather_dim], inp_shape[0], *inp_shape[1:])
+ # output requires no action
+ # the shape of output is (world_size, inp_shape[gather_dim], inp_shape[0], *inp_shape[1:])
+ output = output.view(inp_shape[:gather_dim] + [inp_shape[gather_dim] * world_size, ] + inp_shape[gather_dim + 1:]
+ ).contiguous()
+
+ return output
+
+
+def _partial_unaligned_all_to_all(
+ input_: torch.Tensor,
+ group: dist.ProcessGroup,
+ scatter_dim: int,
+ gather_dim: int,
+ gather_size: Optional[int] = None
+):
+ """
+ Helper function to perform the all-to-all operation. It scatters the input tensor along the specified scatter
+ dimension and then gathers it along the specified gather dimension. The function supports aligned and unaligned
+ data.
+ Special note: In the case of aligned data (both scatter_size and gather_size are divisible by world_size),
+ _partial_unaligned_all_to_all function performs worse than _aligned_all_to_all function. Therefore, in the case of
+ aligning data, it is recommended to use _aligned_all_to_all function.
+ """
+ world_size = dist.get_world_size(group)
+ input_ = input_.contiguous()
+ rank = dist.get_rank(group=group)
+
+ scatter_size = input_.size(scatter_dim)
+ if gather_size is None:
+ gather_size = input_.size(gather_dim) * world_size
+ assert not (gather_size % world_size != 0 and scatter_size % world_size != 0)
+
+ scatter_size_per_rank = scatter_size // world_size
+ scatter_size_remainder = scatter_size % world_size
+ input_split_sizes = [scatter_size_per_rank + (1 if i < scatter_size_remainder else 0) for i in range(world_size)]
+
+ gather_size_per_rank = gather_size // world_size
+ gather_size_remainder = gather_size % world_size
+ output_split_sizes = [gather_size_per_rank + (1 if i < gather_size_remainder else 0) for i in range(world_size)]
+
+ # Adjusts the dimensions of a tensor to move scatter_idx and gather_idx to dim 0 and dim 1 respectively.
+ reshaped_input, reshaped_input_dims = adjust_tensor_dimensions(input_, scatter_dim, gather_dim)
+ reshaped_input_shape = list(reshaped_input.shape)
+ # the shape of reshaped_input is (input_.size(scatter_dim), input_.size(gather_dim), *reshaped_input_shape[2:])
+
+ if scatter_size % world_size == 0:
+ reshaped_input = reshaped_input.view(
+ [world_size, input_.size(scatter_dim) // world_size, input_.size(gather_dim)] + reshaped_input_shape[2:]
+ ).transpose(1, 2).contiguous()
+
+ output_dims = reshaped_input_dims
+ # Relative to reshaped_input(the return value of adjust_tensor_dimensions func),
+ # which shape is (input_.size(scatter_dim), input_.size(gather_dim), *reshaped_input_shape[2:]),
+ # output just swaps the 0th and 1st axes.
+ output_dims[1], output_dims[0] = output_dims[0], output_dims[1]
+ output = torch.empty((gather_size, input_split_sizes[rank], *reshaped_input_shape[2:]),
+ dtype=input_.dtype, device=input_.device)
+ output_shape = list(output.shape)
+
+ dist.all_to_all_single(
+ output,
+ reshaped_input,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes if scatter_size % world_size != 0 else [1 for _ in range(world_size)],
+ group=group,
+ )
+
+ if gather_size % world_size == 0 and scatter_size % world_size != 0:
+ output = output.view(
+ [world_size, input_split_sizes[rank], gather_size // world_size] + reshaped_input_shape[2:]
+ ).transpose(1, 2).reshape(output_shape).contiguous()
+
+ # Reverses the dimension adjustments using the list of adjusted dimensions.
+ unadjust_output_ = unadjust_tensor_dimensions(output, output_dims)
+
+ return unadjust_output_
+
+
+class _AllToAll(torch.autograd.Function):
+ """Custom autograd function that performs an all-to-all communication.
+ This function supports both aligned and unaligned data.
+ """
+ @staticmethod
+ def forward(ctx, input_, process_group, scatter_dim, gather_dim, gather_size=None):
+ """
+ Forward pass: Perform all-to-all communication by scattering the input tensor along the specified scatter
+ dimension and then gathering it along the specified gather dimension.
+
+ Args:
+ input_ (torch.Tensor): The input tensor to be processed.
+ process_group (dist.ProcessGroup): The process group to perform the operation within.
+ scatter_dim (int): The index of the dimension that needs to be scattered.
+ gather_dim (int): The index of the dimension that needs to be gathered.
+ gather_size (int): The size of the gather dimension.
+
+ Returns:
+ torch.Tensor: The resulting tensor after performing the all-to-all operation.
+ """
+ ctx.process_group = process_group
+ ctx.scatter_dim = scatter_dim
+ ctx.scatter_size = input_.size(scatter_dim)
+ ctx.gather_dim = gather_dim
+ ctx.gather_size = gather_size
+ output = _all_to_all(
+ input_, process_group, scatter_dim, gather_dim, gather_size
+ )
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ """
+ Backward pass: Perform the reverse all-to-all communication
+
+ Args:
+ grad_output (torch.Tensor): The gradient of the output with respect to the loss.
+
+ Returns:
+ tuple: The gradient of the input with respect to the loss and `None` for other arguments.
+ """
+ grad_output = _all_to_all(
+ grad_output,
+ ctx.process_group,
+ ctx.gather_dim,
+ ctx.scatter_dim,
+ ctx.scatter_size
+ )
+ return (
+ grad_output,
+ None,
+ None,
+ None,
+ None,
+ None
+ )
+
+
+def _split(
+ input_: torch.Tensor,
+ pg: dist.ProcessGroup,
+ dim: int = -1,
+ split_sizes: Optional[List[int]] = None
+) -> torch.Tensor:
+ """
+ Splits a tensor across the specified dimension and returns the part corresponding to the current rank,
+ supporting aligned and unaligned data.
+
+ Args:
+ input_ (torch.Tensor): The input tensor to be split.
+ pg (dist.ProcessGroup): The process group to perform the operation within.
+ dim (int, optional): The dimension along which to split the tensor. Defaults to -1 (last dimension).
+ split_sizes (Optional[List[int]], optional): A list of sizes for each part of the tensor to be split.
+ If not provided, the tensor will be split equally among the processes, with the remainder
+ distributed to the first few processes. Defaults to None.
+
+ Returns:
+ torch.Tensor: The part of the tensor corresponding to the current rank in the process group.
+ """
+ # Ensure split_sizes is a list if provided
+ assert split_sizes is None or isinstance(split_sizes, list)
+
+ # skip if only one rank involved
+ world_size = dist.get_world_size(pg)
+
+ if world_size == 1:
+ return input_
+
+ # Calculate split sizes if not provided
+ if split_sizes is None:
+ dim_size = input_.size(dim)
+ base_size = dim_size // world_size
+ remainder = dim_size % world_size
+
+ # Calculate the size for each process
+ split_sizes = [base_size + 1 if i < remainder else base_size for i in range(world_size)]
+
+ tensor_list = torch.split(input_, split_sizes, dim=dim)
+
+ # Get the part corresponding to the current rank
+ rank = dist.get_rank(pg)
+ output = tensor_list[rank].contiguous()
+
+ return output
+
+
+def _gather(input_: torch.Tensor,
+ pg: dist.ProcessGroup,
+ dim: int = -1,
+ gather_sizes: Optional[List[int]] = None):
+ """
+ Gathers tensors from all processes in the process group and concatenates them along the specified dimension,
+ supporting aligned and unaligned data.
+
+ Args:
+ input_ (torch.Tensor): The input tensor to be gathered.
+ pg (dist.ProcessGroup): The process group to perform the operation within.
+ dim (int, optional): The dimension along which to concatenate the gathered tensors. Defaults to -1 (last dimension).
+ gather_sizes (Optional[List[int]], optional): A list of sizes for each part of the tensor to be gathered.
+ If not provided, it is assumed that all tensors have the same shape as the input tensor. Defaults to None.
+
+ Returns:
+ torch.Tensor: The concatenated tensor after gathering from all processes in the process group.
+ """
+ # Ensure gather_sizes is a list if provided
+ assert gather_sizes is None or isinstance(gather_sizes, list)
+
+ # Skip if only one rank is involved
+ world_size = dist.get_world_size(pg)
+ if world_size == 1:
+ return input_
+
+ input_ = input_.contiguous()
+
+ # Prepare the output list with appropriate shapes
+ if gather_sizes:
+ tensor_list = []
+ tensor_shape_base = input_.size()
+ for i in range(world_size):
+ tensor_shape = list(tensor_shape_base)
+ tensor_shape[dim] = gather_sizes[i]
+ tensor_list.append(torch.empty(tensor_shape, dtype=input_.dtype, device=input_.device))
+ else:
+ tensor_list = [torch.empty_like(input_, dtype=input_.dtype, device=input_.device) for _ in range(world_size)]
+
+ assert input_.device.type == "cuda" or input_.device.type == "npu"
+ torch.distributed.all_gather(tensor_list, input_, group=pg)
+
+ # concat
+ output = torch.cat(tensor_list, dim=dim).contiguous()
+ return output
+
+
+class _GatherForwardSplitBackward(torch.autograd.Function):
+ """
+ Custom autograd function that gathers the input tensor from all processes in the model parallel region and
+ concatenates them.
+ During the backward pass, it splits the gradients and scales them according to the gradient scaling mode.
+
+ """
+
+ @staticmethod
+ def symbolic(graph, input_, process_group, dim, gather_sizes):
+ """
+ Define the symbolic representation of the custom operation.
+ """
+ return _gather(input_, process_group, dim, gather_sizes)
+
+ @staticmethod
+ def forward(ctx, input_, process_group, dim, gather_sizes, grad_scale="up"):
+ """
+ Forward pass: Gathers tensors from all processes in the specified process group and concatenates them along the specified dimension.
+
+ Args:
+ input_ (torch.Tensor): The input tensor to be processed.
+ process_group (dist.ProcessGroup): The process group to perform the operation within.
+ dim (int): The dimension along which to concatenate the gathered tensors.
+ gather_sizes (Optional[List[int]], optional): A list of sizes for each part of the tensor to be gathered.
+ grad_scale (str, optional): Gradient scaling mode. Can be "up", "down", or None. Defaults to "up".
+
+ Returns:
+ torch.Tensor: The resulting tensor after gathering and concatenating.
+ """
+ ctx.mode = process_group
+ ctx.dim = dim
+ ctx.grad_scale = grad_scale
+
+ ctx.gather_sizes = gather_sizes
+ return _gather(input_, process_group, dim, ctx.gather_sizes)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ """
+ Backward pass: Distribute the gradients to the input tensors and scales them according to the gradient scaling mode.
+
+ Args:
+ grad_output (torch.Tensor): The gradient of the output.
+
+ Returns:
+ torch.Tensor: The gradient of the input with respect to the loss.
+ """
+ if ctx.grad_scale == "up":
+ grad_output = grad_output * dist.get_world_size(ctx.mode)
+ elif ctx.grad_scale == "down":
+ grad_output = grad_output / dist.get_world_size(ctx.mode)
+
+ return _split(grad_output, ctx.mode, ctx.dim, ctx.gather_sizes), None, None, None, None
+
+
+class _SplitForwardGatherBackward(torch.autograd.Function):
+ """
+ Custom autograd function that splits the input tensor and keeps only the corresponding chunk for the current rank.
+ During the backward pass, it gathers the gradients and scales them according to the gradient scaling mode.
+
+ """
+ @staticmethod
+ def symbolic(graph, input_, process_group, dim, split_sizes):
+ return _split(input_, process_group, dim, split_sizes)
+
+ @staticmethod
+ def forward(ctx, input_, process_group, dim, split_sizes, grad_scale):
+ ctx.mode = process_group
+ ctx.dim = dim
+ ctx.grad_scale = grad_scale
+
+ ctx.split_sizes = split_sizes
+
+ return _split(input_, process_group, dim, ctx.split_sizes)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if ctx.grad_scale == "up":
+ grad_output = grad_output * dist.get_world_size(ctx.mode)
+ elif ctx.grad_scale == "down":
+ grad_output = grad_output / dist.get_world_size(ctx.mode)
+ return _gather(grad_output, ctx.mode, ctx.dim, ctx.split_sizes), None, None, None, None
+
+
+def all_to_all(
+ input_: torch.Tensor,
+ process_group: dist.ProcessGroup,
+ scatter_dim: int = 2,
+ gather_dim: int = 1,
+ gather_size: Optional[int] = None
+):
+ """
+ Performs an all-to-all operation on the input tensor. The input tensor is scattered along the specified scatter
+ dimension and then gathered along the specified gather dimension.
+ This function supports both aligned and unaligned data.
+
+ Args:
+ input_ (torch.Tensor): The input tensor to be processed.
+ process_group (dist.ProcessGroup): The process group to perform the operation within.
+ scatter_dim (int, optional): The index of the dimension that needs to be scattered. Defaults to 2.
+ gather_dim (int, optional): The index of the dimension that needs to be gathered. Defaults to 1.
+ gather_size (Optional[int]): The total size of the output tensor along the `gather_dim`. If not provided, it
+ will be calculated as the product of the original size of the `gather_dim` of the input tensor and the
+ `world_size`.
+
+ Returns:
+ torch.Tensor: The resulting tensor after performing the all-to-all operation.
+ """
+ return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, gather_size)
+
+
+def split_forward_gather_backward(
+ input_: torch.Tensor,
+ process_group: dist.ProcessGroup,
+ dim: int,
+ split_sizes: Optional[List[int]] = None,
+ grad_scale: str = "down"
+
+) -> torch.Tensor:
+ """
+ Splits the input tensor and keeps only the corresponding chunk for the current rank.
+ During the backward pass, it gathers the gradients and scales them according to the gradient scaling mode.
+ This function supports both aligned and unaligned data.
+ Args:
+ input_ (torch.Tensor): The input tensor to be processed.
+ process_group (dist.ProcessGroup): The process group to perform the operation within.
+ dim (int): The dimension along which to split the tensor.
+ split_sizes (Optional[List[int]], optional): A list of sizes for each part of the tensor to be split.
+ If not provided, the tensor will be split equally among the processes. Defaults to None.
+ grad_scale (str, optional): Gradient scaling mode. Can be "up", "down", or None. Defaults to "down".
+
+ Returns:
+ torch.Tensor: The resulting tensor after splitting and keeping only the corresponding chunk.
+ """
+ return _SplitForwardGatherBackward.apply(input_, process_group, dim, split_sizes, grad_scale)
+
+
+def gather_forward_split_backward(
+ input_: torch.Tensor,
+ process_group: dist.ProcessGroup,
+ dim: int,
+ gather_sizes: Optional[List[int]] = None,
+ grad_scale: str = "up"
+) -> torch.Tensor:
+ """
+ Gathers the input tensor from all processes in the model parallel region and concatenates them along the specified
+ dimension. During the backward pass, it splits the gradients and scales them according to the gradient scaling mode.
+ This function handles both aligned and unaligned data during the gather and scatter operations.
+ Args:
+ input_ (torch.Tensor): The input tensor to be processed.
+ process_group (dist.ProcessGroup): The process group to perform the operation within.
+ dim (int): The dimension along which to concatenate the gathered tensors.
+ gather_sizes (Optional[List[int]], optional): A list of sizes for each part of the tensor to be gathered.
+ If not provided, it is assumed that all tensors have the same shape as the input tensor. Defaults to None.
+ grad_scale (str, optional): Gradient scaling mode. Can be "up", "down", or None. Defaults to "up".
+
+ Returns:
+ torch.Tensor: The resulting tensor after gathering and concatenating.
+ """
+ return _GatherForwardSplitBackward.apply(input_, process_group, dim, gather_sizes, grad_scale)
+
+
+def cal_split_sizes(dim_size, world_size):
+ """
+ Calculate the split sizes for a given dimension size and number of processes.
+
+ This function divides the dimension size into `world_size` parts, distributing any remainder
+ among the first few parts.
+
+ Args:
+ dim_size (int): The total size of the dimension to be split.
+ world_size (int): The number of processes (or parts) to split the dimension into.
+
+ Returns:
+ List[int]: A list of integers representing the size of each part after splitting.
+ """
+ split_size = dim_size // world_size
+ remainder = dim_size % world_size
+ size_list = [split_size + (1 if i < remainder else 0) for i in range(world_size)]
+ return size_list
diff --git a/model/train/yoco_moe/mindspeed/core/context_parallel/utils.py b/model/train/yoco_moe/mindspeed/core/context_parallel/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f99970c59e74e2c5fac32f63285dede027fe6037
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/context_parallel/utils.py
@@ -0,0 +1,661 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+import torch_npu
+import torch.distributed as dist
+import numpy as np
+from einops import rearrange
+from scipy.sparse.linalg import eigsh
+from megatron.training import get_args
+from megatron.core.parallel_state import get_context_parallel_global_ranks
+from mindspeed.ops.npu_ring_attention_update import npu_ring_attention_update
+from mindspeed.core.parallel_state import get_context_parallel_for_hybrid_ring_global_ranks
+from mindspeed.op_builder import AdaptiveCpOpBuilder
+
+
+ADAPTIVE_CP_SCHEDULING_INFO = None
+ADAPTIVE_CP_SEQ_ORDER = None
+CACHED_GRID_MASK = None
+CACHED_SEQ = None
+CACHED_MASK_LIST = []
+CACHED_SCHEDULING = None
+COMM_THRESHOLD = 6
+ADAPTIVE_CP_DEFAULT_SHAPE = 1024
+ADAPTIVE_CP_MASK_LIST_SET_BY_USER = None
+ADAPTIVE_CP_GRID_MASK_SET_BY_USER = None
+
+
+# SBH -> TND
+def sbh_to_tnd(x, n):
+ s, b, h = x.shape
+ d, t = h // n, int(b * s)
+ return x.transpose(0, 1).view(t, h).view(t, n, d)
+
+
+# TND -> SBH
+def tnd_to_sbh(x, b):
+ t, n, d = x.shape
+ s, h = t // b, int(n * d)
+ return x.view(b, s, n, d).transpose(0, 1).view(s, b, h)
+
+
+def get_selection_indices_for_tnd_softmax_update(t, n, sub_seq_len):
+ full_indices = list(range(t * n))
+ cur_seq_start_idx = 0
+ indices = []
+ seq_start = 0
+ for seq_len in sub_seq_len:
+ for i in range(n):
+ start = seq_start + seq_len * 2 * i + seq_len
+ end = seq_start + seq_len * 2 * (i + 1)
+ indices.extend(full_indices[start:end])
+ seq_start += seq_len * n * 2
+
+ return torch.tensor(indices)
+
+
+def flatten_softmax(x, sub_seq_len):
+ orig_shape = x.shape
+ section_len = [s * orig_shape[1] for s in sub_seq_len]
+ splits = x.view(-1, orig_shape[-1]).split(section_len, dim=0)
+ merged = [item.view(orig_shape[1], -1, orig_shape[-1]).transpose(0, 1) for item in splits]
+ merged = torch.cat(merged, dim=0)
+ return merged
+
+
+def unflatten_softmax(x, sub_seq_len):
+ orig_shape = x.shape
+ section_len = [s * orig_shape[1] for s in sub_seq_len]
+ splits = x.view(-1, orig_shape[-1]).split(section_len, dim=0)
+ merged = [item.view(-1, orig_shape[1], orig_shape[-1]).transpose(0, 1) \
+ .view(-1, orig_shape[-1]) for item in splits]
+ merged = torch.cat(merged, dim=0)
+ return merged.view(*orig_shape)
+
+
+def forward_update_without_fused(prev_attn_out, prev_softmax_max, prev_softmax_sum,
+ cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout='SBH'):
+ if layout == 'TND':
+ cur_softmax_max = flatten_softmax(cur_softmax_max, actual_seq_qlen)
+ cur_softmax_sum = flatten_softmax(cur_softmax_sum, actual_seq_qlen)
+ prev_softmax_max = flatten_softmax(prev_softmax_max, actual_seq_qlen)
+ prev_softmax_sum = flatten_softmax(prev_softmax_sum, actual_seq_qlen)
+ # update softmax_max
+ origin_dtype = prev_attn_out.dtype
+ softmax_max = torch.maximum(prev_softmax_max, cur_softmax_max)
+ prev_scale = torch.exp(prev_softmax_max - softmax_max)
+ cur_scale = torch.exp(cur_softmax_max - softmax_max)
+
+ # update softmax_sum
+ prev_softmax_sum_scaled = prev_softmax_sum * prev_scale
+ cur_softmax_sum_scaled = cur_softmax_sum * cur_scale
+ softmax_sum = prev_softmax_sum_scaled + cur_softmax_sum_scaled
+
+ # out updating scale
+ prev_out_scale = prev_softmax_sum_scaled / softmax_sum
+ cur_out_scale = cur_softmax_sum_scaled / softmax_sum
+
+ # [b, n, s, 8] -> [s, b, h]
+ if layout == 'SBH':
+ n = prev_out_scale.shape[1]
+ h = prev_attn_out.shape[-1]
+ d = h // n
+ prev_out_scale = prev_out_scale[..., 0].unsqueeze(3).repeat(1, 1, 1, d)
+ prev_out_scale = rearrange(prev_out_scale, 'b n s d -> s b (n d)').contiguous()
+ cur_out_scale = cur_out_scale[..., 0].unsqueeze(3).repeat(1, 1, 1, d)
+ cur_out_scale = rearrange(cur_out_scale, 'b n s d -> s b (n d)').contiguous()
+ elif layout == 'TND':
+ d = prev_attn_out.shape[-1]
+ prev_out_scale = prev_out_scale[..., 0].unsqueeze(2).repeat(1, 1, d)
+ cur_out_scale = cur_out_scale[..., 0].unsqueeze(2).repeat(1, 1, d)
+
+ # update output
+ attn_out = prev_attn_out * prev_out_scale + cur_attn_out * cur_out_scale
+ attn_out = attn_out.to(origin_dtype)
+ if layout == 'TND':
+ softmax_max = unflatten_softmax(softmax_max, actual_seq_qlen)
+ softmax_sum = unflatten_softmax(softmax_sum, actual_seq_qlen)
+ return attn_out, softmax_max, softmax_sum
+
+
+class RingP2P:
+ def __init__(self, ring_global_ranks, group, group_for_send_recv_overlap=None, is_backward=False) -> None:
+ self.group = group
+ self.group_for_send_recv_overlap = group
+ if group_for_send_recv_overlap is not None:
+ self.group_for_send_recv_overlap = group_for_send_recv_overlap
+
+ global_rank = dist.get_rank()
+ ring_rank = ring_global_ranks.index(global_rank)
+ ring_size = len(ring_global_ranks)
+ self.next = ring_global_ranks[(ring_rank + 1) % ring_size]
+ self.prev = ring_global_ranks[(ring_rank + ring_size - 1) % ring_size]
+ self.ring_rank = ring_rank
+ if is_backward:
+ self.next, self.prev = self.prev, self.next
+
+ self.send_recv_ops = []
+
+ def async_send_recv(self, send_tensor, recv_tensor):
+ if self.ring_rank % 2 == 0:
+ send_op = dist.isend(send_tensor, self.next, self.group)
+ recv_op = dist.irecv(recv_tensor, self.prev, self.group_for_send_recv_overlap)
+ self.send_recv_ops.append(send_op)
+ self.send_recv_ops.append(recv_op)
+ else:
+ recv_op = dist.irecv(recv_tensor, self.prev, self.group)
+ send_op = dist.isend(send_tensor, self.next, self.group_for_send_recv_overlap)
+ self.send_recv_ops.append(recv_op)
+ self.send_recv_ops.append(send_op)
+
+ def wait(self):
+ if len(self.send_recv_ops) > 0:
+ for op in self.send_recv_ops:
+ op.wait()
+ self.send_recv_ops = []
+ return 1
+ else:
+ return 0
+
+
+def forward_update(prev_attn_out, prev_softmax_max, prev_softmax_sum,
+ cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout='SBH'):
+ """
+ Updates the attention output and softmax statistics for the ring attention mechanism,
+ with added parameters for enhanced flexibility and extensibility.
+
+ This function is designed to update the attention output and related softmax statistics
+ for a given sequence length in a ring attention mechanism. It handles the merging of
+ previous and current attention outputs and their corresponding softmax statistics.
+ The introduction of `actual_seq_qlen` and `layout` parameters allows for greater flexibility
+ in handling variable sequence lengths and different tensor layouts, respectively.
+
+ Parameters:
+ - prev_attn_out (Tensor): The attention output from the previous process.
+ - prev_softmax_max (Tensor): The maximum value of the softmax distribution from the previous process.
+ - prev_softmax_sum (Tensor): The sum of the softmax distribution from the previous process.
+ - cur_attn_out (Tensor): The attention output from the current process.
+ - cur_softmax_max (Tensor): The maximum value of the softmax distribution from the current process.
+ - cur_softmax_sum (Tensor): The sum of the softmax distribution from the current process.
+ - actual_seq_qlen (Tensor, optional): The actual sequence length for the query. This parameter
+ is crucial for handling variable-length sequences and ensuring
+ that the attention mechanism operates correctly under such conditions.
+ If not provided, it defaults to the length of the current attention output.
+ - layout (str, optional): The layout format of the input tensors. This parameter allows for the specification
+ of different tensor layouts, enhancing the function's versatility across various
+ model architectures. Default is 'SBH', where:
+ - S: Sequence length
+ - B: Batch size
+ - H: Hidden size (number of attention heads)
+
+ Returns:
+ - updated_attn_out (Tensor): The updated attention output after merging previous and current process.
+ - updated_softmax_max (Tensor): The updated maximum value of the softmax distribution.
+ - updated_softmax_sum (Tensor): The updated sum of the softmax distribution.
+ """
+ _args = get_args()
+ if hasattr(_args, 'use_fused_ring_attention_update') and _args.use_fused_ring_attention_update:
+ def accumulate_list(input_list):
+ """
+ 借助numpy库将列表转换为numpy数组进行元素累加,再转换回列表并在开头添加0
+ """
+ np_array = np.array(input_list)
+ cumsum_result = np.cumsum(np_array)
+ return torch.tensor([0] + list(cumsum_result), dtype=torch.int64).to(prev_attn_out.device)
+
+ if layout == "TND":
+ actual_seq_qlen = accumulate_list(actual_seq_qlen)
+ return npu_ring_attention_update(prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out,
+ cur_softmax_max, cur_softmax_sum, actual_seq_qlen, layout)
+
+ return forward_update_without_fused(prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out,
+ cur_softmax_max, cur_softmax_sum, actual_seq_qlen, layout)
+
+
+def tnd_out_update(q_block_id, kv_block_id, cur_attn_outs, global_attn_outs, q_index, softmax_indices, cur_sub_out_seq_len):
+ cur_attn_out, cur_softmax_max, cur_softmax_sum = cur_attn_outs[0], cur_attn_outs[1], cur_attn_outs[2]
+ attn_out, softmax_max, softmax_sum, rng_states = global_attn_outs
+
+ layout = 'TND'
+
+ if len(cur_attn_outs) > 3:
+ rng_states[kv_block_id] = (cur_attn_outs[4], cur_attn_outs[5], cur_attn_outs[6])
+
+ if q_block_id == kv_block_id:
+ attn_out = cur_attn_out
+ softmax_max = cur_softmax_max
+ softmax_sum = cur_softmax_sum
+ elif kv_block_id <= q_block_id:
+ attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
+ attn_out, softmax_max, softmax_sum,
+ cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=cur_sub_out_seq_len, layout=layout
+ )
+ attn_out, softmax_max, softmax_sum = attn_out_updated, softmax_max_updated, softmax_sum_updated
+ else:
+ n = attn_out.shape[1]
+ t = attn_out.shape[0]
+ prev_softmax_max = softmax_max.view(-1, 8)[softmax_indices].view(-1, n, 8)
+ prev_softmax_sum = softmax_sum.view(-1, 8)[softmax_indices].view(-1, n, 8)
+
+ attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
+ torch.index_select(attn_out, 0, q_index), prev_softmax_max, prev_softmax_sum,
+ cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=cur_sub_out_seq_len, layout=layout
+ )
+ attn_out.index_copy_(0, q_index, attn_out_updated)
+ softmax_max = softmax_max.view(-1, 8).index_copy(0, softmax_indices, softmax_max_updated.view(-1, 8)).view(-1, n, 8)
+ softmax_sum = softmax_sum.view(-1, 8).index_copy(0, softmax_indices, softmax_sum_updated.view(-1, 8)).view(-1, n, 8)
+
+
+ return [attn_out, softmax_max, softmax_sum, rng_states]
+
+
+def causal_out_update(q_block_id, kv_block_id, cur_attn_outs, global_attn_outs):
+ cur_attn_out, cur_softmax_max, cur_softmax_sum = cur_attn_outs[0], cur_attn_outs[1], cur_attn_outs[2]
+ attn_out, softmax_max, softmax_sum, rng_states = global_attn_outs
+ layout = 'SBH'
+ if len(cur_attn_outs) > 3:
+ rng_states[kv_block_id] = (cur_attn_outs[4], cur_attn_outs[5], cur_attn_outs[6])
+
+ if q_block_id == kv_block_id:
+ attn_out = cur_attn_out
+ softmax_max = cur_softmax_max
+ softmax_sum = cur_softmax_sum
+ elif kv_block_id <= q_block_id:
+ attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
+ attn_out, softmax_max, softmax_sum,
+ cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout=layout
+ )
+ attn_out, softmax_max, softmax_sum = attn_out_updated, softmax_max_updated, softmax_sum_updated
+ else:
+ # [2s, b, h] -> [2, s, b, h]
+ attn_out = attn_out.view(2, attn_out.shape[0] // 2, *attn_out.shape[1:])
+ # [b, n, 2s, 8] -> [b, n, 2, s, 8]
+ softmax_max = softmax_max.view(softmax_max.shape[0], softmax_max.shape[1],
+ 2, softmax_max.shape[2] // 2, softmax_max.shape[-1])
+ softmax_sum = softmax_sum.view(softmax_sum.shape[0], softmax_sum.shape[1],
+ 2, softmax_sum.shape[2] // 2, softmax_sum.shape[-1])
+ attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
+ attn_out[1], softmax_max[:, :, 1, :, :], softmax_sum[:, :, 1, :, :],
+ cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout=layout
+ )
+ attn_out[1].copy_(attn_out_updated)
+ softmax_max[:, :, 1, :, :].copy_(softmax_max_updated)
+ softmax_sum[:, :, 1, :, :].copy_(softmax_sum_updated)
+ # [2, s, b, h] -> [2s, b, h]
+ attn_out = attn_out.view(-1, *attn_out.shape[2:])
+ # [b, n, 2, s, 8] -> [b, n, 2s, 8]
+ softmax_max = softmax_max.view(softmax_max.shape[0], softmax_max.shape[1], -1,
+ softmax_max.shape[-1])
+ softmax_sum = softmax_sum.view(softmax_sum.shape[0], softmax_sum.shape[1], -1,
+ softmax_sum.shape[-1])
+
+ return [attn_out, softmax_max, softmax_sum, rng_states]
+
+
+def general_out_update(q_block_id, kv_block_id, cur_attn_outs, global_attn_outs):
+ cur_attn_out, cur_softmax_max, cur_softmax_sum = cur_attn_outs[0], cur_attn_outs[1], cur_attn_outs[2]
+ attn_out, softmax_max, softmax_sum, rng_states = global_attn_outs
+ layout = 'SBH'
+ rng_states[kv_block_id] = (cur_attn_outs[4], cur_attn_outs[5], cur_attn_outs[6])
+ if q_block_id == kv_block_id:
+ attn_out = cur_attn_out
+ softmax_max = cur_softmax_max
+ softmax_sum = cur_softmax_sum
+ else:
+ attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
+ attn_out, softmax_max, softmax_sum,
+ cur_attn_out, cur_softmax_max, cur_softmax_sum, layout=layout
+ )
+ attn_out, softmax_max, softmax_sum = attn_out_updated, softmax_max_updated, softmax_sum_updated
+
+ return [attn_out, softmax_max, softmax_sum, rng_states]
+
+
+class SchedulingInfo:
+ def __init__(self, round_idx, recv_q_src: int = -1, recv_kv_src: int = -1, recv_o_src: list = None,
+ send_q_dst=None, send_kv_dst: list = None, send_o_dst: int = -1, comm_unit_limit=6):
+ self.round_idx = round_idx
+ self.recv_q_src = recv_q_src # 下一轮计算需要的来自别处的Q,-1代表不需要
+ self.recv_kv_src = recv_kv_src # 下一轮计算需要的来自别处的KV,-1代表不需要
+ self.recv_o_src = [] if recv_o_src is None else recv_o_src # 本轮计算中哪些device帮本机算了
+ self.send_q_dst = [] if send_q_dst is None else send_q_dst # 下一轮计算中哪些device需要本机的Q
+ self.send_kv_dst = [] if send_kv_dst is None else send_kv_dst # 下一轮计算中哪些device需要本机的KV
+ self.send_o_dst = send_o_dst # 本轮计算帮哪个device算
+ self.comm_unit_limit = comm_unit_limit
+ self.cnt_comm_unit_forward = -1
+ self.check_eligibility()
+
+ def check_eligibility(self):
+ # 检查不能同时收Q和KV
+ if self.recv_q_src > -1 and self.recv_kv_src > -1:
+ raise ValueError("only receive one of q and kv in a single round")
+ # 检查总通信量是否符合限制
+ self.count_comm_units()
+ if self.cnt_comm_unit_forward > self.comm_unit_limit:
+ raise ValueError(f"comm unit exceed limit: round {self.round_idx}, device {torch.npu.current_device()}")
+
+ def count_comm_units(self):
+ sum_recv_units = self.recv_q_src > -1 + (self.recv_kv_src > -1) * 2 + len(self.recv_o_src)
+ sum_send_units = len(self.send_q_dst) + len(self.send_kv_dst) * 2 + self.send_o_dst > -1
+ self.cnt_comm_unit_forward = sum_recv_units + sum_send_units
+
+
+def coarsen_attn_mask_npu(attn_mask, coarse_ratio):
+ # 输出mask中0为需要计算的,1为不需要计算的
+ orig_size = attn_mask.shape[0]
+ attn_mask_reshaped = (~attn_mask)
+ attn_mask_reshaped = attn_mask_reshaped.view(orig_size // coarse_ratio, coarse_ratio,
+ orig_size // coarse_ratio, coarse_ratio).permute(0, 2, 1, 3)
+ coarse_attn_mask = ~torch.any(torch.any(attn_mask_reshaped, dim=3), dim=2)
+ return coarse_attn_mask
+
+
+def set_scheduling_info(cp_rank, scheduling):
+ global ADAPTIVE_CP_SCHEDULING_INFO
+ if ADAPTIVE_CP_SCHEDULING_INFO is None or get_args().adaptive_cp_dynamic_attn_mask:
+ ADAPTIVE_CP_SCHEDULING_INFO = process_scheduling_info(cp_rank, scheduling)[1:]
+
+
+def get_scheduling_info():
+ if ADAPTIVE_CP_SCHEDULING_INFO is None:
+ raise RuntimeError("Trying to get scheduling info before setting it, ADAPTIVE_CP_SCHEDULING_INFO is still None")
+ return ADAPTIVE_CP_SCHEDULING_INFO
+
+
+def set_remapped_seq_order(seq_order):
+ global ADAPTIVE_CP_SEQ_ORDER
+ ADAPTIVE_CP_SEQ_ORDER = seq_order
+
+
+def get_remapped_seq_order():
+ if ADAPTIVE_CP_SEQ_ORDER is None:
+ raise RuntimeError("Trying to get optimized sequence before setting it, ADAPTIVE_CP_SEQ_ORDER is still None")
+ return ADAPTIVE_CP_SEQ_ORDER
+
+
+def set_adaptive_cp_mask_list_by_user(mask_list):
+ global ADAPTIVE_CP_MASK_LIST_SET_BY_USER
+ ADAPTIVE_CP_MASK_LIST_SET_BY_USER = mask_list
+
+
+def get_adaptive_cp_mask_list_by_user():
+ global ADAPTIVE_CP_MASK_LIST_SET_BY_USER
+ if ADAPTIVE_CP_MASK_LIST_SET_BY_USER is None:
+ raise RuntimeError("Trying to get mask list before setting it, ADAPTIVE_CP_MASK_LIST_SET_BY_USER is still None")
+ return ADAPTIVE_CP_MASK_LIST_SET_BY_USER
+
+
+def generate_adaptive_cp_mask_list_by_user(opt_seq, scheduling_info, cp_rank, cp_size):
+ mask_list = None # replace with customized function to generate mask list
+ set_adaptive_cp_mask_list_by_user(mask_list)
+
+
+def set_adaptive_cp_grid_mask_by_user(grid_mask):
+ global ADAPTIVE_CP_GRID_MASK_SET_BY_USER
+ ADAPTIVE_CP_GRID_MASK_SET_BY_USER = grid_mask
+
+
+def get_adaptive_cp_grid_mask_by_user():
+ global ADAPTIVE_CP_GRID_MASK_SET_BY_USER
+ if ADAPTIVE_CP_GRID_MASK_SET_BY_USER is None:
+ raise RuntimeError("Trying to get grid mask before setting it, ADAPTIVE_CP_GRID_MASK_SET_BY_USER is still None")
+ return ADAPTIVE_CP_GRID_MASK_SET_BY_USER
+
+
+def generate_adaptive_cp_grid_mask_by_user(cp_size):
+ grid_mask = None # replace with customized function to generate grid mask
+ set_adaptive_cp_grid_mask_by_user(grid_mask)
+
+
+def process_scheduling_info(local_rank, orig_scheduling, comm_limit=6):
+ round_num = len(orig_scheduling)
+ device_num = len(orig_scheduling[0])
+ processed_scheduling_info = [SchedulingInfo(round_idx=i, comm_unit_limit=comm_limit) for i in range(round_num + 1)]
+ for rnd_idx in range(round_num):
+ process_single_scheduling_info(local_rank, device_num, rnd_idx, orig_scheduling[rnd_idx],
+ processed_scheduling_info)
+ return processed_scheduling_info
+
+
+def process_single_scheduling_info(local_rank, device_num, round_idx, round_scheduling_info, processed_scheduling_info):
+ if get_args().context_parallel_algo == 'adaptive_cp_algo':
+ rank_list = get_context_parallel_global_ranks()
+ else:
+ rank_list = get_context_parallel_for_hybrid_ring_global_ranks()
+ for execute_device_id, task_id in enumerate(round_scheduling_info): # 当前任务和实际执行当前任务的设备
+ if task_id == -1:
+ continue
+ origin_device_id = rank_list[int(task_id / device_num)] # 原本应该执行当前任务的设备
+ kv_device_id = rank_list[task_id % device_num] # 存储当前任务kv的设备
+ execute_device_id = rank_list[execute_device_id]
+ if execute_device_id != origin_device_id: # 需要收发qo
+ if execute_device_id == local_rank: # 当前rank对应的device是执行任务的device
+ processed_scheduling_info[round_idx].recv_q_src = origin_device_id
+ processed_scheduling_info[round_idx + 1].send_o_dst = origin_device_id
+ elif origin_device_id == local_rank: # 当前rank对应的device是原始的device
+ processed_scheduling_info[round_idx].send_q_dst.append(execute_device_id)
+ processed_scheduling_info[round_idx + 1].recv_o_src.append(execute_device_id)
+ else: # 需要收发kv
+ if execute_device_id == local_rank: # 当前rank对应的device是执行任务的device
+ processed_scheduling_info[round_idx].recv_kv_src = kv_device_id
+ elif kv_device_id == local_rank: # 当前rank对应的device是存储kv的device
+ processed_scheduling_info[round_idx].send_kv_dst.append(execute_device_id)
+ processed_scheduling_info[round_idx].check_eligibility()
+
+
+def adaptive_reschedule_task(grid_mask, cp_size):
+ scheduling_info = []
+ total_task = torch.sum(grid_mask)
+ round_idx = 0
+ next_comm = np.zeros(cp_size)
+ while total_task > 0:
+ scheduling_info.append([-1 for _ in range(cp_size)])
+ cur_comm = next_comm
+ next_comm = np.zeros(cp_size)
+ total_task -= execute_scheduling(grid_mask, cp_size, round_idx, cur_comm, next_comm, scheduling_info[round_idx])
+ round_idx += 1
+ return scheduling_info
+
+
+def execute_scheduling(grid_mask, cp_size, round_idx, cur_comm, next_comm, scheduling_info):
+ count = 0
+ is_free = np.ones(cp_size)
+ for device_id in range(cp_size):
+ row, col = find_kv_task(grid_mask, cp_size, round_idx, cur_comm, device_id, is_free)
+ if row != -1 and col != -1:
+ scheduling_info[device_id] = row * cp_size + col
+ grid_mask[row][col] = 0
+ count += 1
+ is_send_q = np.zeros(cp_size, dtype=int)
+ for device_id in range(cp_size):
+ if is_free[device_id] == 0:
+ continue
+ row, col = find_qo_task(grid_mask, cp_size, cur_comm, next_comm, device_id, is_send_q)
+ if row != -1 and col != -1:
+ scheduling_info[device_id] = row * cp_size + col
+ grid_mask[row][col] = 0
+ count += 1
+ return count
+
+
+def find_kv_task(grid_mask, cp_size, round_idx, cur_comm, device_id, is_free):
+ is_free[device_id] = 0
+ row = device_id
+ col = (device_id + round_idx) % cp_size
+ if grid_mask[row][col] == 1:
+ cur_comm[row] = cur_comm[row] + 2 # recv KV
+ cur_comm[col] = cur_comm[col] + 2 # send KV
+ return row, col
+ for i in range(1, cp_size): # find kv task
+ row = device_id
+ col = (device_id - i + cp_size) % cp_size
+ if grid_mask[row][col] == 1 and cur_comm[row] <= COMM_THRESHOLD - 2 and cur_comm[col] <= COMM_THRESHOLD - 2:
+ cur_comm[row] += 2 # recv KV
+ cur_comm[col] += 2 # send KV
+ return row, col
+ is_free[device_id] = 1
+ return -1, -1
+
+
+def find_qo_task(grid_mask, cp_size, cur_comm, next_comm, device_id, is_send_q):
+ for i in range(1, cp_size): # find qo task
+ row = (device_id + i) % cp_size
+ col = device_id
+ if grid_mask[row][col] == 1 and cur_comm[row] <= COMM_THRESHOLD - 1 and \
+ cur_comm[col] <= COMM_THRESHOLD - 1 and is_send_q[row] != 1:
+ is_send_q[row] = 1
+ cur_comm[row] += 1 # send Q
+ cur_comm[col] += 1 # recv Q
+ next_comm[row] += 1 # recv O
+ next_comm[col] += 1 # send O
+ return row, col
+ return -1, -1
+
+
+def clear_global_info():
+ global CACHED_SEQ, CACHED_GRID_MASK, CACHED_MASK_LIST, CACHED_SCHEDULING, ADAPTIVE_CP_SCHEDULING_INFO
+ CACHED_SEQ, CACHED_GRID_MASK, CACHED_MASK_LIST, CACHED_SCHEDULING, ADAPTIVE_CP_SCHEDULING_INFO = (None, None, [],
+ None, None)
+
+
+class AdaptiveCpOps:
+ def __init__(self):
+ self.ops = AdaptiveCpOpBuilder().load()
+
+ def coarsen_attn_mask_cpu(self, attn_mask, sampling_ratio):
+ if not attn_mask.is_contiguous():
+ attn_mask = attn_mask.contiguous()
+ mask_size_after_sampling = attn_mask.shape[0] // sampling_ratio
+ coarse_mask = torch.ones((mask_size_after_sampling, mask_size_after_sampling), dtype=torch.bool)
+ self.ops.coarsen_mask(attn_mask, mask_size_after_sampling, coarse_mask)
+ return coarse_mask
+
+ def get_grid_mask(self, attn_mask, cp_size):
+ if not attn_mask.is_contiguous():
+ attn_mask = attn_mask.contiguous()
+ if get_args().attention_mask_on_cpu:
+ grid_mask = torch.ones((cp_size, cp_size), dtype=torch.bool)
+ self.ops.coarsen_mask(attn_mask, cp_size, grid_mask)
+ else:
+ grid_mask = coarsen_attn_mask_npu(attn_mask, attn_mask.shape[0] // cp_size)
+ grid_mask = ~grid_mask
+ return grid_mask
+
+ def search_kmeans_cpu(self, attn_mask, reduced_mask, cp_size, num_iters=100):
+ tmp_attn_mask = torch.ones_like(attn_mask)
+ tmp_grid_mask = torch.ones((cp_size, cp_size), dtype=torch.bool)
+ optimal_attn_mask = torch.ones_like(attn_mask)
+ optimal_grid_mask = torch.ones((cp_size, cp_size), dtype=torch.bool)
+ optimal_num_cluster = [-1]
+ optimal_sorted_indices = self.ops.search_kmeans(attn_mask, reduced_mask, tmp_attn_mask, tmp_grid_mask,
+ optimal_grid_mask, optimal_attn_mask,
+ optimal_num_cluster, cp_size, num_iters)
+ return optimal_sorted_indices, optimal_grid_mask, optimal_attn_mask, optimal_num_cluster
+
+ def adaptive_remap(self, attn_mask, cp_size, truncated_dim=10):
+ args = get_args()
+ if attn_mask.dim() != 2 or attn_mask.shape[0] != attn_mask.shape[1]:
+ raise RuntimeError("Only 2-dimensional self-attention mask supported in adaptive cp")
+
+ if args.adaptive_cp_without_coarse:
+ sampling_ratio = 1
+ if args.attention_mask_on_cpu:
+ coarse_mask = attn_mask
+ else:
+ coarse_mask = attn_mask.cpu()
+ else:
+ if attn_mask.shape[0] % ADAPTIVE_CP_DEFAULT_SHAPE != 0:
+ raise RuntimeError("Shape of attention mask needs to be a multiple of 1024 if not enable "
+ "args.adaptive_cp_without_coarse in adaptive cp")
+ if args.attention_mask_on_cpu:
+ sampling_ratio = attn_mask.shape[0] // ADAPTIVE_CP_DEFAULT_SHAPE
+ coarse_mask = self.coarsen_attn_mask_cpu(attn_mask, sampling_ratio)
+ else:
+ sampling_ratio = attn_mask.shape[0] // ADAPTIVE_CP_DEFAULT_SHAPE
+ coarse_mask = coarsen_attn_mask_npu(attn_mask, sampling_ratio).cpu()
+
+ coarse_mask_np = coarse_mask.to(torch.float16).numpy()
+ mean_matrix = np.mean(coarse_mask_np, axis=0)
+ centered_matrix = (coarse_mask_np - mean_matrix).astype(float)
+ cov_matrix = np.matmul(centered_matrix.T, centered_matrix)
+ eigenvalues, eigenvectors = eigsh(cov_matrix, k=truncated_dim, which='LM')
+ feature_matrix = np.matmul(coarse_mask_np, eigenvectors).tolist()
+
+ optimal_seq, optimal_grid_mask, optimal_coarsen_attn_mask, optimal_num_cluster = (
+ self.search_kmeans_cpu(coarse_mask, feature_matrix, cp_size))
+
+ if args.adaptive_cp_without_coarse:
+ final_opt_seq = optimal_seq
+ else:
+ final_opt_seq = sampling_ratio * torch.tensor(optimal_seq)[:, None] + torch.arange(sampling_ratio)
+ final_opt_seq = final_opt_seq.view(-1).tolist()
+
+ optimal_grid_mask = ~optimal_grid_mask
+
+ return optimal_grid_mask, final_opt_seq
+
+ def get_adaptive_cp_info(self, attn_mask, cp_size):
+ args = get_args()
+ global CACHED_GRID_MASK, CACHED_SEQ
+ if args.attention_mask_on_cpu != (attn_mask.device.type == 'cpu'):
+ raise RuntimeError("args.attention_mask_on_cpu does not match the device of set attention mask")
+
+ # 生成重映射后的序列和重排后的gird mask,输出tensor(npu/cpu) opt_grid_mask和list opt_seq
+ if not args.adaptive_cp_only_reschedule:
+ if args.adaptive_cp_dynamic_attn_mask or CACHED_GRID_MASK is None:
+ opt_grid_mask, opt_seq = self.adaptive_remap(attn_mask, cp_size)
+ if not args.adaptive_cp_dynamic_attn_mask:
+ CACHED_GRID_MASK, CACHED_SEQ = opt_grid_mask, opt_seq
+ else:
+ opt_grid_mask, opt_seq = CACHED_GRID_MASK, CACHED_SEQ
+ else:
+ opt_seq = list(range(attn_mask.shape[0]))
+ if args.adaptive_cp_dynamic_attn_mask or CACHED_GRID_MASK is None:
+ opt_grid_mask = self.get_grid_mask(attn_mask, cp_size)
+ CACHED_GRID_MASK = opt_grid_mask
+ else:
+ opt_grid_mask = CACHED_GRID_MASK
+
+ # 生成调度方案
+ opt_scheduling = adaptive_reschedule_task(opt_grid_mask, cp_size)
+
+ return opt_seq, opt_scheduling
+
+ def get_mask_list(self, attn_mask, opt_scheduling, opt_seq, cp_rank, cp_size):
+ args = get_args()
+ global CACHED_MASK_LIST
+ if not args.adaptive_cp_dynamic_attn_mask and len(CACHED_MASK_LIST) > 0:
+ return CACHED_MASK_LIST
+ round_num = len(opt_scheduling)
+ grid_size = attn_mask.shape[0] // cp_size
+ mask_list = []
+
+ for rnd_idx in range(round_num):
+ task_id = opt_scheduling[rnd_idx][cp_rank]
+ if task_id == -1:
+ mask_list.append(None)
+ continue
+ q_device_id = task_id // cp_size
+ kv_device_id = task_id % cp_size
+ if args.attention_mask_on_cpu:
+ mask_list.append(torch.empty((grid_size, grid_size), dtype=torch.bool, device='cpu'))
+ if args.adaptive_cp_only_reschedule:
+ grid_inds = [q_device_id, kv_device_id]
+ self.ops.get_mask_list_without_remap(attn_mask, mask_list[rnd_idx], grid_inds, cp_size)
+ else:
+ q_token_list = opt_seq[grid_size * q_device_id: grid_size * (q_device_id + 1)]
+ kv_token_list = opt_seq[grid_size * kv_device_id: grid_size * (kv_device_id + 1)]
+ self.ops.get_mask_list_with_remap(attn_mask, mask_list[rnd_idx], q_token_list, kv_token_list)
+ else:
+ q_token_list = opt_seq[grid_size * q_device_id: grid_size * (q_device_id + 1)]
+ kv_token_list = opt_seq[grid_size * kv_device_id: grid_size * (kv_device_id + 1)]
+ mask_list.append(attn_mask[q_token_list, :][:, kv_token_list])
+
+ if args.attention_mask_on_cpu:
+ for rnd_idx in range(round_num):
+ if mask_list[rnd_idx] is not None:
+ mask_list[rnd_idx] = mask_list[rnd_idx].npu(non_blocking=True)
+
+ CACHED_MASK_LIST = mask_list
+ return mask_list
+
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/data_parallel/distributed_data_parallel.py b/model/train/yoco_moe/mindspeed/core/data_parallel/distributed_data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..1292a347a886d7ec321dfaa62adf2cf210b63194
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/data_parallel/distributed_data_parallel.py
@@ -0,0 +1,455 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Bytedance Inc. All rights reserved.
+import logging
+from functools import wraps
+from collections import deque
+from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
+from megatron.core.parallel_state import (
+ get_data_parallel_world_size,
+ get_data_parallel_group,
+ get_tensor_model_parallel_world_size,
+ get_global_memory_buffer)
+from megatron.legacy.model.transformer import FlashSelfAttention
+
+from megatron.training import get_args
+from megatron.core.distributed.distributed_data_parallel import DistributedDataParallel, logger
+from megatron.core.distributed.param_and_grad_buffer import ParamAndGradBuffer
+from megatron.core import parallel_state
+from megatron.core.utils import log_single_rank
+import torch
+
+
+@torch.no_grad()
+def all_gather_param(param, wait_buffer):
+ dp_size = get_data_parallel_world_size()
+ group = get_data_parallel_group()
+ dim_size = list(param.data.size())
+ dim_size[0] = dim_size[0] * dp_size
+ param.ds_tensor = param.data
+ param.data = torch.empty(dim_size, dtype=param.data.dtype, device=torch.cuda.current_device())
+ wait_buffer.append(torch.distributed._all_gather_base(param.data, param.ds_tensor.contiguous(), async_op=True, group=group))
+
+
+@torch.no_grad()
+def reduce_scatter_grad(param, wait_grad_buffer):
+ dp_size = get_data_parallel_world_size()
+ scale = 1.0
+ if dp_size > 0 :
+ scale = scale / dp_size
+ param.full_grad.data *= scale
+ group = get_data_parallel_group()
+ param.grad_data_buffer = torch.empty(param.ds_tensor.shape, dtype=param.full_grad.dtype, device=torch.cuda.current_device())
+ wait_grad_buffer.append(torch.distributed._reduce_scatter_base(param.grad_data_buffer, param.full_grad.data.contiguous(), async_op=True, group=group))
+
+
+@torch.no_grad()
+def release_param_data(param):
+ param.data = param.ds_tensor
+
+
+def wait_grad(param, wait_grad_buffer):
+ wait_grad_buffer.popleft().wait()
+ param.main_grad.add_(param.grad_data_buffer)
+ param.grad_data_buffer = None
+ param.full_grad = None
+ param.grad = None
+
+
+def set_model_fw_bw_hook(modules):
+ wait_buffer = deque()
+ wait_grad_buffer = deque()
+ dp_size = get_data_parallel_world_size()
+ if dp_size == 1:
+ return
+ module_list = []
+ fa_module = False
+ for module in modules:
+ fa_module |= isinstance(module, FlashSelfAttention)
+ if isinstance(module, (ColumnParallelLinear, RowParallelLinear)):
+ module.pre_module_id = module.next_module_id = None
+ module_list.append(module)
+ if fa_module:
+ # Send h_to_4h information in advance for communication masking.
+ module.light_weight = True
+ fa_module = False
+ if len(module_list) > 0:
+ module_list[0].zero_start = True
+ module_list[-1].zero_end = True
+ for i in range(len(module_list) - 1):
+ module_list[i].next_module_id = i + 1
+ module_list[i + 1].pre_module_id = i
+
+
+ def forward_pre_hook(module, *arg):
+ if hasattr(module, 'zero_start'):
+ all_gather_param(module.weight, wait_buffer)
+ wait_buffer.popleft().wait()
+ if hasattr(module, 'light_weight'):
+ return
+ next_module_id = module.next_module_id
+ if next_module_id is not None:
+ next_module = module_list[next_module_id]
+ all_gather_param(next_module.weight, wait_buffer)
+ if hasattr(next_module, 'light_weight') and next_module.next_module_id is not None:
+ all_gather_param(module_list[next_module.next_module_id].weight, wait_buffer)
+
+
+ def forward_hook(module, *args):
+ release_param_data(module.weight)
+
+
+ def backward_pre_hook(module, *args):
+ if hasattr(module, 'zero_end'):
+ all_gather_param(module.weight, wait_buffer)
+ wait_buffer.popleft().wait()
+ if hasattr(module, 'light_weight'):
+ return
+ pre_module_id = module.pre_module_id
+ if pre_module_id is not None:
+ pre_module = module_list[pre_module_id]
+ all_gather_param(pre_module.weight, wait_buffer)
+ if hasattr(pre_module, 'light_weight') and pre_module.pre_module_id is not None:
+ all_gather_param(module_list[pre_module.pre_module_id].weight, wait_buffer)
+
+
+ def backward_hook(module, *arg):
+ release_param_data(module.weight)
+ reduce_scatter_grad(module.weight, wait_grad_buffer)
+ if hasattr(module, 'light_weight'):
+ return
+ next_module_id = module.next_module_id
+ if next_module_id is not None:
+ next_module = module_list[next_module_id]
+ if hasattr(next_module, 'light_weight') and next_module.next_module_id is not None:
+ wait_grad(module_list[next_module.next_module_id].weight, wait_grad_buffer)
+ wait_grad(next_module.weight, wait_grad_buffer)
+ if hasattr(module, 'zero_start'):
+ wait_grad(module.weight, wait_grad_buffer)
+
+ for module in module_list:
+ module.register_forward_pre_hook(hook=forward_pre_hook)
+ module.register_forward_hook(hook=forward_hook)
+ module.register_full_backward_pre_hook(hook=backward_pre_hook)
+ module.register_full_backward_hook(hook=backward_hook)
+
+
+def distributed_data_parallel_init_zero3(
+ self,
+ config,
+ module,
+ data_parallel_group,
+ accumulate_allreduce_grads_in_fp32: bool,
+ overlap_grad_reduce: bool,
+ use_distributed_optimizer: bool,
+ expert_data_parallel_group,
+ disable_bucketing: bool = False,
+ check_for_nan_in_grad: bool = False,
+ bucket_size: int = 40000000,
+):
+ super(DistributedDataParallel, self).__init__(config)
+ self.module = module
+ if get_args().enable_zero3:
+ set_model_fw_bw_hook(self.module.modules())
+
+ # Set bucket_size to infinity if overlap_grad_reduce is False.
+ self.overlap_grad_reduce = overlap_grad_reduce
+ self.use_distributed_optimizer = use_distributed_optimizer
+
+ # Turn off bucketing if overlap_grad_reduce is False, if we are on a pipeline stage
+ # that is not the first (since data-parallel communication on these stages is not on
+ # the critical path), or if disable_bucketing is True (e.g., we might not want to
+ # break up model parameters into buckets for model chunks after the first
+ # in the interleaved schedule).
+ if not self.overlap_grad_reduce:
+ bucket_size = None
+ if parallel_state.get_pipeline_model_parallel_rank() > 0:
+ bucket_size = None
+ if disable_bucketing:
+ bucket_size = None
+
+ self.check_for_nan_in_grad = check_for_nan_in_grad
+ self.bucket_size = bucket_size
+
+ self.module = module
+ self.param_to_buffer = {}
+ self.zero3_param = []
+
+ # Group parameters by their gradient type.
+ param_to_name = {}
+ dense_params = []
+ expert_parallel_params = []
+ for name, param in self.module.named_parameters():
+ if not param.requires_grad:
+ continue
+ dtype = param.dtype
+ param.grad_added_to_main_grad = False
+ param_to_name[param] = name
+
+ if hasattr(param, 'enable_zero3') and param.enable_zero3:
+ param.main_grad = torch.zeros_like(param, dtype=dtype)
+ self.zero3_param.append(param)
+ continue
+
+ if getattr(param, 'allreduce', True):
+ dense_params.append(param)
+ else:
+ expert_parallel_params.append(param)
+
+
+ def allocate_buffers_for_parameters(
+ input_params, data_parallel_group, gradient_scaling_factor=1.0,
+ ):
+ param_and_grad_dtype_to_params = {}
+
+ # Group parameters by their gradient type.
+ for param in input_params:
+ if not param.requires_grad:
+ continue
+
+ param_dtype = param.dtype
+ grad_dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype
+
+ params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), [])
+ params.append(param)
+ param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params
+
+ # Allocate the grad buffers and map the grads.
+ buffers = []
+ for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
+ buffers.append(
+ ParamAndGradBuffer(
+ param_dtype,
+ grad_dtype,
+ params,
+ data_parallel_group,
+ bucket_size,
+ param_to_name,
+ self.overlap_grad_reduce,
+ self.use_distributed_optimizer,
+ gradient_scaling_factor,
+ self.check_for_nan_in_grad,
+ )
+ )
+ for param in params:
+ self.param_to_buffer[param] = buffers[-1]
+
+ return buffers
+
+ data_parallel_world_size = torch.distributed.get_world_size(data_parallel_group)
+ # Allocate the param+grad buffers for dense params' grads.
+ self.buffers = allocate_buffers_for_parameters(
+ dense_params,
+ data_parallel_group,
+ gradient_scaling_factor=1.0 / data_parallel_world_size,
+ )
+
+ # Allocate separate param+grad buffers for expert parallel params' grads.
+ self.expert_parallel_buffers = allocate_buffers_for_parameters(
+ expert_parallel_params,
+ expert_data_parallel_group,
+ gradient_scaling_factor=1.0 / data_parallel_world_size,
+ )
+
+ # Delete references to weight_tensor if they exist since we don't want two parameter copies
+ # if we re-mapped parameters (which happens when we use the distributed optimizer).
+ # This is a temporary workaround around a TE bug that is fixed with
+ # https://github.com/NVIDIA/TransformerEngine/pull/719.
+ if self.use_distributed_optimizer:
+
+ @torch.no_grad()
+ def unmap_weight_tensor(m):
+ if hasattr(m, 'weight_tensor'):
+ m.weight_tensor = None
+
+ self.module.apply(unmap_weight_tensor)
+
+ # Register backward hook.
+ # Accumulation function for the gradients need to be stored so they
+ # don't go out of scope.
+ self.grad_accs = []
+ for param in self.module.parameters():
+ if param.requires_grad:
+ # Expand so we get access to grad_fn.
+ param_tmp = param.expand_as(param)
+ # Get the gradient accumulator function.
+ grad_acc = param_tmp.grad_fn.next_functions[0][0]
+ if not (hasattr(param, 'enable_zero3') and param.enable_zero3):
+ grad_acc.register_hook(self._make_param_hook(param, self.param_to_buffer))
+ self.grad_accs.append(grad_acc)
+
+
+def distributed_data_parallel_zero_grad_wrapper(function):
+ @wraps(function)
+ def distributed_data_parallel_zero_grad(self, *args, **kwargs):
+ function(self, *args, **kwargs)
+ for p in self.zero3_param:
+ p.main_grad.data.zero_()
+ return distributed_data_parallel_zero_grad
+
+
+def distributed_data_parallel_init_with_cp(
+ self,
+ config,
+ ddp_config,
+ module: torch.nn.Module,
+ disable_bucketing: bool = False,
+):
+ super(DistributedDataParallel, self).__init__(config)
+ self.module = module
+
+ # If bucket_size is not provided as an input, use sane default.
+ # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
+ # ring-reduce implementations are large enough to remain bandwidth-bound rather than
+ # latency-bound.
+ if ddp_config.bucket_size is None:
+ ddp_config.bucket_size = max(
+ 40000000, 1000000 * parallel_state.get_data_parallel_world_size()
+ )
+ # Set bucket_size to infinity if overlap_grad_reduce is False.
+ if not ddp_config.overlap_grad_reduce:
+ ddp_config.bucket_size = None
+
+ self.ddp_config = ddp_config
+ log_single_rank(
+ logger,
+ logging.INFO,
+ f'Setting up DistributedDataParallel with config {self.ddp_config}',
+ )
+
+ # Turn off bucketing if we are on a pipeline stage that is not the first (since
+ # data-parallel communication on these stages is not on the critical path), or if
+ # disable_bucketing is True (e.g., we might not want to break up model parameters
+ # into buckets for model chunks after the first in the interleaved schedule).
+ self.bucket_size = self.ddp_config.bucket_size
+ if parallel_state.get_pipeline_model_parallel_rank() > 0:
+ self.bucket_size = None
+ if disable_bucketing:
+ self.bucket_size = None
+
+ self.module = module
+ self.param_to_buffer = {}
+
+ # Group parameters by their gradient type.
+ param_to_name = {}
+ dense_params = []
+ expert_parallel_params = []
+ for name, param in self.module.named_parameters():
+ if not param.requires_grad:
+ continue
+
+ param.grad_added_to_main_grad = False
+ param_to_name[param] = name
+
+ if getattr(param, 'allreduce', True):
+ dense_params.append(param)
+ else:
+ expert_parallel_params.append(param)
+
+ def allocate_buffers_for_parameters(
+ input_params,
+ data_parallel_group,
+ gradient_scaling_factor,
+ ):
+ param_and_grad_dtype_to_params = {}
+
+ # Group parameters by their gradient type.
+ for param in input_params:
+ if not param.requires_grad:
+ continue
+
+ param_dtype = param.dtype
+ grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype
+
+ params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), [])
+ params.append(param)
+ param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params
+
+ if not config.calculate_per_token_loss:
+ target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size(
+ with_context_parallel=True
+ )
+ if self.ddp_config.average_in_collective:
+ # Collective is averaging gradients in collective with data_parallel_group.
+ assert (
+ gradient_scaling_factor
+ / torch.distributed.get_world_size(group=data_parallel_group)
+ == target_gradient_scaling_factor
+ )
+ else:
+ assert gradient_scaling_factor == target_gradient_scaling_factor
+
+ # Allocate the grad buffers and map the grads.
+ buffers = []
+ for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
+ buffers.append(
+ ParamAndGradBuffer(
+ self.ddp_config,
+ param_dtype,
+ grad_dtype,
+ params,
+ data_parallel_group,
+ self.bucket_size,
+ param_to_name,
+ gradient_scaling_factor,
+ )
+ )
+ for param in params:
+ self.param_to_buffer[param] = buffers[-1]
+
+ return buffers
+
+ if config.calculate_per_token_loss:
+ gradient_scaling_factor = 1.0
+ expert_gradient_scaling_factor = 1.0
+ else:
+ if self.ddp_config.average_in_collective:
+ gradient_scaling_factor = 1.0
+ expert_gradient_scaling_factor = (
+ 1.0 / parallel_state.get_expert_model_parallel_world_size()
+ )
+ else:
+ data_parallel_world_size = parallel_state.get_data_parallel_world_size(
+ with_context_parallel=True
+ )
+ gradient_scaling_factor = 1.0 / data_parallel_world_size
+ expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
+
+ # Allocate the param+grad buffers for dense params' grads.
+ self.buffers = allocate_buffers_for_parameters(
+ dense_params,
+ parallel_state.get_data_parallel_group(with_context_parallel=True),
+ gradient_scaling_factor=gradient_scaling_factor,
+ )
+
+ # Allocate separate param+grad buffers for expert parallel params' grads.
+ self.expert_parallel_buffers = allocate_buffers_for_parameters(
+ expert_parallel_params,
+ parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True),
+ gradient_scaling_factor=expert_gradient_scaling_factor,
+ )
+
+ # Delete references to weight_tensor if they exist since we don't want two parameter copies
+ # if we re-mapped parameters (which happens when we use the distributed optimizer).
+ # This is a temporary workaround around a TE bug that is fixed with
+ if self.ddp_config.use_distributed_optimizer:
+
+ @torch.no_grad()
+ def unmap_weight_tensor(m):
+ if hasattr(m, 'weight_tensor'):
+ m.weight_tensor = None
+
+ self.module.apply(unmap_weight_tensor)
+
+ # Register backward hook.
+ # Accumulation function for the gradients need to be stored so they
+ # don't go out of scope.
+ self.grad_accs = []
+ for param in self.module.parameters():
+ if param.requires_grad:
+ # Expand so we get access to grad_fn.
+ param_tmp = param.expand_as(param)
+ # Get the gradient accumulator function.
+ grad_acc = param_tmp.grad_fn.next_functions[0][0]
+ grad_acc.register_hook(self._make_param_hook(param, self.param_to_buffer))
+ self.grad_accs.append(grad_acc)
diff --git a/model/train/yoco_moe/mindspeed/core/datasets/gpt_dataset.py b/model/train/yoco_moe/mindspeed/core/datasets/gpt_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ca57c1d7c450116c129d57e3c43e1b28de988c7
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/datasets/gpt_dataset.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+from functools import wraps
+import torch
+
+
+def _get_ltor_masks_and_position_ids(
+ data: torch.Tensor,
+ eod_token: int,
+ reset_position_ids: bool,
+ reset_attention_mask: bool,
+ eod_mask_loss: bool,
+ create_attention_mask: bool,
+):
+ """Build masks and position id for left to right model.
+
+ Args:
+ data (torch.Tensor): The data tenor that holds the tokens from the dataset
+
+ eod_token (int): ID of the token to that is considered the EOD
+
+ reset_position_ids (bool): Switch to reset the document position ID's
+
+ reset_attention_mask (bool): Switch to reset the attention mask
+
+ eod_mask_loss (bool): Switch to enable the EOD mask loss
+
+ create_attention_mask (bool): Switch to enable the attention masks generation. Can be disabled if attention kernel generates masks by itself.
+
+ Returns:
+ torch.Tensor: Attention mask needed to be used for Attention
+
+ torch.Tensor: The mask used for loss value during training
+
+ torch.Tensor: The position ID's of the token
+ """
+ seq_length = data.numel()
+
+ if create_attention_mask:
+ attention_mask = torch.tril(
+ torch.ones((seq_length, seq_length), device=data.device)
+ ).unsqueeze(0)
+ else:
+ attention_mask = None
+
+ # Loss mask.
+ loss_mask = torch.ones(seq_length, dtype=torch.float, device=data.device)
+ if eod_mask_loss:
+ loss_mask[data == eod_token] = 0.0
+
+ # Position ids.
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
+ # We need to clone as the ids will be modifed based on batch index.
+ if reset_position_ids:
+ position_ids = position_ids.clone()
+
+ if reset_attention_mask:
+ # Find indices where EOD token is.
+ eod_index = position_ids[data == eod_token]
+ # Detach indices from positions if going to modify positions.
+ if reset_position_ids:
+ eod_index = eod_index.clone()
+
+ # Loop through EOD indices:
+ prev_index = 0
+ for j in range(eod_index.numel()):
+ i = eod_index[j]
+ # Mask attention loss.
+ if reset_attention_mask and attention_mask is not None:
+ attention_mask[0, (i + 1) :, : (i + 1)] = 0
+ # Reset positions.
+ if reset_position_ids:
+ position_ids[(i + 1) :] -= i + 1 - prev_index
+ prev_index = i + 1
+
+ if attention_mask is not None:
+ # Convert attention mask to binary:
+ attention_mask = attention_mask < 0.5
+
+ seq_length_tensor = torch.tensor([seq_length])
+ actual_seq_len = torch.cat([eod_index + 1, seq_length_tensor])
+
+ return attention_mask, loss_mask, (position_ids, actual_seq_len)
+
+
+def collate_wrapper(fn):
+ @wraps(fn)
+ def wrapper(samples):
+ actual_seq_len = [elem['position_ids'][1] for elem in samples]
+ samples = [{key: val if key != 'position_ids' else val[0] for key, val in elem.items()} for elem in samples]
+ batch = fn(samples)
+ seq_len = actual_seq_len[0][-1]
+ actual_seq_len = [elem + i * seq_len for i, elem in enumerate(actual_seq_len)]
+ batch['actual_seq_len'] = torch.cat(actual_seq_len)
+ return batch
+
+ return wrapper
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/__init__.py b/model/train/yoco_moe/mindspeed/core/distributed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/__init__.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..30a4ad6c85014acc60912805727a0926b88a7fcd
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/__init__.py
@@ -0,0 +1,4 @@
+from .zero3 import LayerZeRO3
+from .config import layerzero_setup_model_and_optimizer_wrapper, layerzero_initialize_model_parallel_wrapper
+from .megatron_adaptor import mga_finalize_model_grads_wrapper
+from .state.mga_checkpoint import save_checkpoint
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/comm/hookwrap.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/comm/hookwrap.py
new file mode 100644
index 0000000000000000000000000000000000000000..46596b6624ea5aa18a06625e0a467376f6729b32
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/comm/hookwrap.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+from typing import Dict, Tuple, TYPE_CHECKING, Optional, Deque
+from collections import deque
+from abc import abstractmethod
+from contextlib import contextmanager
+
+import torch
+
+
+class EventQueueBase:
+
+ def __init__(self) -> None:
+ pass
+
+ @abstractmethod
+ @contextmanager
+ def block(self):
+ ...
+
+ @abstractmethod
+ def empty(self):
+ ...
+
+ @abstractmethod
+ def enqueue(self, free_event: torch.cuda.Event) -> None:
+ ...
+
+ @abstractmethod
+ def pop_left(self) -> Optional[torch.cuda.Event]:
+ ...
+
+
+class CriticalPathEventQueue(EventQueueBase):
+
+ def __init__(self):
+ super().__init__()
+ self._queue: Deque[torch.cuda.Event] = deque()
+ self._buffer: Deque[torch.cuda.Event] = deque()
+ self.__blocked = False
+
+ @contextmanager
+ def block(self):
+ try:
+ self.__blocked = True
+ yield
+ finally:
+ for event in self._buffer:
+ self.enqueue(event)
+ self._buffer.clear()
+ self.__blocked = False
+
+
+ def empty(self):
+ return len(self._queue) == 0
+
+ def enqueue(self, free_event: torch.cuda.Event) -> None:
+ if self.__blocked:
+ self._buffer.append(free_event)
+ else:
+ self._queue.append(free_event)
+
+ @abstractmethod
+ def pop_left(self) -> Optional[torch.cuda.Event]:
+ if self._queue:
+ event = self._queue.popleft()
+ return event
+ return None
+
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/config.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a377397e8904d82aae607e4ae8f1d0e6c3b5942
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/config.py
@@ -0,0 +1,415 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+import gc
+import dataclasses
+import importlib
+from functools import wraps
+from typing import Tuple, Literal, Union, Iterable, Optional
+
+import yaml
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+
+from megatron.core import mpu
+from megatron.core.optimizer import OptimizerConfig
+from megatron.training.training import get_optimizer_param_scheduler, get_model
+from megatron.training.global_vars import get_args, get_timers
+from megatron.training.utils import (
+ print_rank_0,
+ unwrap_model,
+)
+from megatron.core.utils import get_model_config
+from megatron.training.checkpointing import load_checkpoint
+
+from mindspeed.core.distributed.layerzero.zero3 import LayerZeRO3
+from mindspeed.core.distributed.layerzero.zero3.wrap import ModuleWrapPolicy
+from mindspeed.core.distributed.layerzero.zero3.api import (
+ BackwardPrefetch,
+ BackwardReduceScatter,
+ MixedPrecision,
+)
+from mindspeed.core.distributed.layerzero.megatron_adaptor import get_optimizer
+from mindspeed.core.distributed.layerzero.state.mga_checkpoint import save_checkpoint, load_layerzero_checkpoint
+from . import constants
+#!===============Globals============================
+_ZERO1_PROCESS_GROUP = None
+_ZERO3_PROCESS_GROUP = None
+_ZERO1_PROCESS_GROUP_RANKS = None
+_ZERO3_PROCESS_GROUP_RANKS = None
+_TP_ZERO1_PROCESS_GROUP = None
+_TP_ZERO1_PROCESS_GROUP_RANKS = None
+_TP_ZERO3_PROCESS_GROUP = None
+_TP_ZERO3_PROCESS_GROUP_RANKS = None
+
+
+@dataclasses.dataclass
+class LayerzeroConfig:
+ zero3_size: int = 8
+ transformer_layers: Optional[Iterable[torch.nn.Module]] = None
+ backward_prefetch: Literal["BACKWARD_PRE",
+ "BACKWARD_POST"] = 'BACKWARD_PRE'
+ backward_reduce_scatter: Literal["BACKWARD_PRE",
+ "BACKWARD_POST"] = 'BACKWARD_PRE'
+ param_dtype: Optional[Literal["fp16", "bf16", "fp32"]] = "fp16"
+ reduce_dtype: Optional[Literal["fp16", "bf16", "fp32"]] = "fp16"
+ buffer_dtype: Optional[Literal["fp16", "bf16", "fp32"]] = None
+ ignored_modules: Optional[Iterable[torch.nn.Module]] = None
+ param_init_fn: Optional[str] = None,
+ forward_prefetch: bool = True
+ limit_all_gathers: bool = True
+ offload_grads: bool = False
+ ckpt_load_path: str = None
+ autocast_input: bool = True
+ autocast_output: bool = True
+
+ def __post_init__(self):
+ if self.zero3_size <= 0 or not isinstance(self.zero3_size, int):
+ raise ValueError("zero3_size must be a non-negative int value")
+
+ @classmethod
+ def load_from_yaml(cls, yml_file: str):
+ with open(yml_file, 'r') as f:
+ config = yaml.safe_load(f)
+ kwargs = {}
+ for f in dataclasses.fields(cls):
+ if f.name in config:
+ kwargs[f.name] = config[f.name]
+ print_rank_0(kwargs)
+ return cls(**kwargs)
+
+ def to_dict(self):
+ process_group = self._process_group()
+ wrap_policy = self._wrap_policy()
+ mixed_precision = self._mp_policy()
+ backward_prefetch = self._backward_prefetch()
+ backward_rs = self._backward_reduce_scatter()
+ kwargs = {
+ "process_group": process_group,
+ "tp_zero_process_group": self._tp_process_group(),
+ "auto_wrap_policy": wrap_policy,
+ "mixed_precision": mixed_precision,
+ "device_id": torch.cuda.current_device(),
+ "backward_prefetch": backward_prefetch,
+ "backward_reduce_scatter": backward_rs,
+ "forward_prefetch": self.forward_prefetch,
+ "offload_grads": self.offload_grads
+ }
+ return kwargs
+
+ def _mp_policy(self):
+ # if self.fwd_bwd_dtype or
+ param_dtype = _get_dtype(
+ self.param_dtype) if self.param_dtype else None
+ reduce_dtype = _get_dtype(
+ self.reduce_dtype) if self.reduce_dtype else None
+ buffer_dtype = _get_dtype(
+ self.buffer_dtype) if self.buffer_dtype else None
+ return MixedPrecision(param_dtype=param_dtype,
+ reduce_dtype=reduce_dtype,
+ buffer_dtype=buffer_dtype)
+
+ def _wrap_policy(self):
+ if self.transformer_layers:
+ try:
+ transformer_layer_cls = set(_get_class_type(
+ m_class_name) for m_class_name in self.transformer_layers)
+ except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(f"Module {transformer_layer_cls} Not Found, \
+ check yaml config file and your model, or add it to PYTHONPATH") from e
+ else:
+ transformer_layer_cls = []
+ print_rank_0(f"Each of these layers will be wrapped as a single layer:{transformer_layer_cls}")
+ wrap_policy = ModuleWrapPolicy(transformer_layer_cls)
+ return wrap_policy
+
+ def _process_group(self):
+ if not _is_layerzero_pg_initialized():
+ raise RuntimeError("Layerzero process group is not initialized")
+ return _ZERO3_PROCESS_GROUP, _ZERO1_PROCESS_GROUP
+
+ def _tp_process_group(self):
+ return _TP_ZERO3_PROCESS_GROUP, _TP_ZERO1_PROCESS_GROUP
+
+ def _backward_prefetch(self):
+ if self.backward_prefetch not in ['BACKWARD_PRE', 'BACKWARD_POST']:
+ raise ValueError(f"{self.backward_prefetch} is not supported")
+ return BackwardPrefetch[self.backward_prefetch]
+
+ def _backward_reduce_scatter(self):
+ if self.backward_reduce_scatter not in ['BACKWARD_PRE', 'BACKWARD_POST']:
+ raise ValueError(f"{self.backward_reduce_scatter} is not supported")
+ return BackwardReduceScatter[self.backward_reduce_scatter]
+
+ def setup_cast_settings(self):
+ constants.set_auto_cast_input(self.autocast_input)
+ constants.set_auto_cast_output(self.autocast_output)
+
+
+def _get_module_attr(model: nn.Module, name: Iterable[str]):
+ if name is None:
+ return None
+ if not isinstance(name, list):
+ name = [name]
+ name = set(list(name))
+ if not all(isinstance(n, str) for n in name):
+ raise AssertionError("All name should be str")
+ results = set(getattr(model, n, None) for n in name)
+ if all([m is None for m in results]):
+ return None
+ return results
+
+
+def _get_module_and_class(name: str) -> Tuple[str, str]:
+ names = name.rsplit('.', 1)
+ if len(names) == 1:
+ raise RuntimeError(f"Please Provide a module.class name, got {name}")
+ module_name, class_name = names
+ return module_name, class_name
+
+
+def _get_class_type(name: str) -> type:
+ """
+ Args:
+ name (str): module.class
+
+ Returns:
+ type: Class Type
+ """
+ module_name, class_name = _get_module_and_class(name)
+ module = importlib.import_module(module_name)
+ class_type = getattr(module, class_name, None)
+ return class_type
+
+
+def _get_dtype(dtype: str):
+ if dtype not in {'fp16', 'bf16', 'fp32'}:
+ raise AssertionError(f"dtype {dtype} not Supported")
+ if dtype == 'fp16':
+ return torch.float16
+ elif dtype == 'bf16':
+ return torch.bfloat16
+ elif dtype == 'fp32':
+ return torch.float32
+ raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+def wrap_model_with_layerzero(model: Union[Iterable[torch.nn.Module], torch.nn.Module], lz_config: LayerzeroConfig):
+
+ kwargs = lz_config.to_dict()
+ if isinstance(model, nn.Module):
+ model = [model]
+
+ model_list = []
+ for model_chunk in model:
+ ignored_modules = _get_module_attr(
+ model_chunk, lz_config.ignored_modules)
+ kwargs["ignored_modules"] = ignored_modules
+ zero3_model = LayerZeRO3(model_chunk, **kwargs)
+ model_list.append(zero3_model)
+ return model_list
+
+
+def create_optimizer_layerzero(model,
+ no_wd_decay_cond=None,
+ scale_lr_cond=None,
+ lr_mult=1.0):
+ args = get_args()
+ timers = get_timers()
+ kwargs = {}
+ for f in dataclasses.fields(OptimizerConfig):
+ if hasattr(args, f.name):
+ kwargs[f.name] = getattr(args, f.name)
+ config = OptimizerConfig(**kwargs)
+ config.timers = timers
+ optimizer = get_optimizer(config, model[0], no_wd_decay_cond,
+ scale_lr_cond, lr_mult)
+ opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
+ return optimizer, opt_param_scheduler
+
+
+def layerzero_setup_model_and_optimizer_wrapper(setup_model_and_optimizer):
+ @wraps(setup_model_and_optimizer)
+ def wrapper(model_provider_func,
+ model_type,
+ no_wd_decay_cond=None,
+ scale_lr_cond=None,
+ lr_mult=1.0):
+ args = get_args()
+ if getattr(args, 'layerzero', False):
+ # ========================================================
+ timers = get_timers()
+ models = get_model(model_provider_func, model_type, False)
+ if args.load is not None or args.pretrained_checkpoint is not None:
+ timers('load-checkpoint', log_level=0).start(barrier=True)
+ args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(
+ models, None, None)
+ timers('load-checkpoint').stop(barrier=True)
+ timers.log(['load-checkpoint'])
+ else:
+ args.iteration = 0
+ args.num_floating_point_operations_so_far = 0
+ # ========================================================
+ config_yaml = args.layerzero_config
+ config = LayerzeroConfig.load_from_yaml(config_yaml)
+ config.setup_cast_settings()
+ zero_models = wrap_model_with_layerzero(
+ unwrap_model(models), config)
+ del models
+ gc.collect()
+
+ optimizer, opt_param_scheduler = create_optimizer_layerzero(zero_models,
+ no_wd_decay_cond=no_wd_decay_cond,
+ scale_lr_cond=scale_lr_cond,
+ lr_mult=lr_mult)
+ if config.ckpt_load_path is not None:
+ load_layerzero_checkpoint(
+ zero_models, config.ckpt_load_path, optimizer, opt_param_scheduler)
+ torch.cuda.empty_cache()
+ print_rank_0(f"{zero_models[0]=}")
+
+ model_config = get_model_config(zero_models[0])
+ if len(zero_models) == 1:
+ model_config.no_sync_func = zero_models[0].no_sync
+ else:
+ model_config.no_sync_func = [m.no_sync for m in zero_models]
+ return zero_models, optimizer, opt_param_scheduler
+ else:
+ return setup_model_and_optimizer(model_provider_func,
+ model_type,
+ no_wd_decay_cond,
+ scale_lr_cond,
+ lr_mult)
+
+ return wrapper
+
+
+def initialize_zero_process_group_with_pp(pp_size, zero3_size):
+ global _ZERO1_PROCESS_GROUP
+ global _ZERO1_PROCESS_GROUP_RANKS
+ global _ZERO3_PROCESS_GROUP
+ global _ZERO3_PROCESS_GROUP_RANKS
+
+ world_size = dist.get_world_size()
+ global_rank = dist.get_rank()
+ zero1_size = world_size // pp_size
+ zero3_size = min(zero3_size, zero1_size)
+ ensure_divisibility(zero1_size, zero3_size)
+ num_zero3_groups = zero1_size // zero3_size
+
+ for zero1_idx in range(pp_size):
+ cur_zero1_ranks = list(
+ range(zero1_idx * zero1_size, (zero1_idx + 1) * zero1_size))
+ zero1_group = dist.new_group(ranks=cur_zero1_ranks, backend="hccl")
+ if global_rank in cur_zero1_ranks:
+ _ZERO1_PROCESS_GROUP = zero1_group
+ _ZERO1_PROCESS_GROUP_RANKS = cur_zero1_ranks
+
+ for zero3_idx in range(num_zero3_groups):
+ cur_zero3_ranks = cur_zero1_ranks[zero3_idx *
+ zero3_size: (zero3_idx + 1) * zero3_size]
+ zero3_group = dist.new_group(ranks=cur_zero3_ranks, backend="hccl")
+ if global_rank in cur_zero3_ranks:
+ _ZERO3_PROCESS_GROUP = zero3_group
+ _ZERO3_PROCESS_GROUP_RANKS = cur_zero3_ranks
+ return
+
+
+def initialize_tp_zero_process_group(tp_zero3_size: int):
+ if not mpu.is_initialized() or not _is_layerzero_pg_initialized():
+ raise RuntimeError("Mpu or ZeRO process group is not initialized")
+
+ global _TP_ZERO1_PROCESS_GROUP
+ global _TP_ZERO1_PROCESS_GROUP_RANKS
+ global _TP_ZERO3_PROCESS_GROUP
+ global _TP_ZERO3_PROCESS_GROUP_RANKS
+
+ _TP_ZERO1_PROCESS_GROUP = mpu.get_data_parallel_group(
+ with_context_parallel=True)
+ _TP_ZERO1_PROCESS_GROUP_RANKS = list(
+ mpu._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP)
+ tp_zero1_size = len(_TP_ZERO1_PROCESS_GROUP_RANKS)
+ tp_zero3_size = min(tp_zero1_size, tp_zero3_size)
+ ensure_divisibility(tp_zero1_size, tp_zero3_size)
+
+ world_size = dist.get_world_size()
+ global_rank = dist.get_rank()
+ num_zero1_groups = world_size // tp_zero1_size
+ num_zero3_groups = tp_zero1_size // tp_zero3_size
+ for zero1_idx in range(num_zero1_groups):
+ for zero3_idx in range(num_zero3_groups):
+ cur_zero1_ranks = list(
+ range(zero1_idx, world_size, num_zero1_groups))
+ group_ranks = cur_zero1_ranks[zero3_idx *
+ tp_zero3_size: (zero3_idx + 1) * tp_zero3_size]
+ group = dist.new_group(ranks=group_ranks, backend="hccl")
+ if global_rank in group_ranks:
+ _TP_ZERO3_PROCESS_GROUP = group
+ _TP_ZERO3_PROCESS_GROUP_RANKS = group_ranks
+ return
+
+
+def initialized_zero_process_group(zero3_size):
+ '''
+ For TP > 1 or PP > 1 or TP + PP situation, the process group needs to be taken care of.
+ '''
+ if not mpu.is_initialized():
+ raise AssertionError(f"mpu is not initialized")
+ args = get_args()
+ global _ZERO1_PROCESS_GROUP
+ global _ZERO1_PROCESS_GROUP_RANKS
+ global _ZERO3_PROCESS_GROUP
+ global _ZERO3_PROCESS_GROUP_RANKS
+ global _TP_ZERO1_PROCESS_GROUP
+ global _TP_ZERO1_PROCESS_GROUP_RANKS
+ global _TP_ZERO3_PROCESS_GROUP
+ global _TP_ZERO3_PROCESS_GROUP_RANKS
+
+ initialize_zero_process_group_with_pp(
+ args.pipeline_model_parallel_size, zero3_size)
+ #! process TP process groups
+ if args.tensor_model_parallel_size > 1:
+ ensure_divisibility(zero3_size, args.tensor_model_parallel_size)
+ tp_zero3_size = max(1, zero3_size // args.tensor_model_parallel_size)
+ initialize_tp_zero_process_group(tp_zero3_size)
+ else:
+ _TP_ZERO1_PROCESS_GROUP = _ZERO1_PROCESS_GROUP
+ _TP_ZERO1_PROCESS_GROUP_RANKS = _ZERO1_PROCESS_GROUP_RANKS
+ _TP_ZERO3_PROCESS_GROUP = _ZERO3_PROCESS_GROUP
+ _TP_ZERO3_PROCESS_GROUP_RANKS = _ZERO3_PROCESS_GROUP_RANKS
+
+ print(f"Layerzero with zero1 process group: {_ZERO1_PROCESS_GROUP_RANKS}, \
+ zero3 process group: {_ZERO3_PROCESS_GROUP_RANKS}, \
+ TP zero1 process group: {_TP_ZERO1_PROCESS_GROUP_RANKS}, \
+ TP zero3 process group: {_TP_ZERO3_PROCESS_GROUP_RANKS}, \
+ global rank: {dist.get_rank()}")
+ return
+
+
+def _is_layerzero_pg_initialized():
+ return _ZERO1_PROCESS_GROUP is not None and _ZERO3_PROCESS_GROUP is not None
+
+
+def layerzero_initialize_model_parallel_wrapper(initialize_model_parallel):
+ @wraps(initialize_model_parallel)
+ def wrapper(*args, **kargs):
+ results = initialize_model_parallel(*args, **kargs)
+ global_args = get_args()
+ if getattr(global_args, 'layerzero', False):
+ print_rank_0(
+ f"Entering initialize_model_parallel to create layerzero process groups")
+ config_yaml = global_args.layerzero_config
+ config = LayerzeroConfig.load_from_yaml(config_yaml)
+ zero3_size = config.zero3_size
+ initialized_zero_process_group(zero3_size)
+ return results
+
+ return wrapper
+
+
+def ensure_divisibility(a: int, b: int):
+ """Ensure that 'a' is divisible by 'b'. If not, raise an AssertionError with a custom or default message."""
+ if b == 0:
+ raise ValueError("The divisor (b) must not be zero.")
+ if a % b != 0:
+ raise ValueError(f"{a} is not divisible by {b}")
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/constants.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a915ec51b4cb631c763fa1cb59bc089b9cb1665
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/constants.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+AUTO_CAST_INPUT = True
+AUTO_CAST_OUTPUT = True
+
+
+def set_auto_cast_input(state: bool):
+ global AUTO_CAST_INPUT
+ if not isinstance(state, bool):
+ raise AssertionError("state must be a boolean")
+ AUTO_CAST_INPUT = state
+
+
+def set_auto_cast_output(state: bool):
+ global AUTO_CAST_OUTPUT
+ if not isinstance(state, bool):
+ raise AssertionError("state must be a boolean")
+ AUTO_CAST_OUTPUT = state
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/debug/sum.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/debug/sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..7aa4c0046c61ae7dd55ac415cf2417056d73e15e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/debug/sum.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+import torch.distributed as dist
+from megatron.training.utils import print_rank_0
+
+
+@torch.no_grad()
+def print_total_grad_sum(params):
+ for param in params:
+ print_grad_sum(param)
+
+
+@torch.no_grad()
+def print_grad_sum(param):
+ from megatron.core import mpu
+ if getattr(param, "tensor_model_parallel", False):
+ print_grad_sum_helper(param, mpu.get_data_parallel_group(with_context_parallel=True), "TP_shard")
+ else:
+ print_grad_sum_helper(param, dist.group.WORLD, "None TP")
+
+
+@torch.no_grad()
+def print_grad_sum_helper(param, group, msg):
+ if param.grad is not None:
+ g_sum = param.grad.contiguous().float().sum()
+ p_sum = param.contiguous().float().sum()
+ else:
+ g_sum = torch.zeros([1]).float().to(param.device)
+ p_sum = torch.zeros([1]).float().to(param.device)
+
+ dist.all_reduce(g_sum, group=group)
+ dist.all_reduce(p_sum, group=group)
+ print_rank_0(f"{msg} Psum {p_sum.item()}, Gsum {g_sum.item()}")
+
+
+def all_gather_into_flat_tensor(tensor: torch.Tensor, process_group):
+ '''这个函数用于将不同rank上不同大小的tensor 聚合成一个大的flatTensor'''
+ world_size = process_group.size()
+ rank = dist.get_rank(process_group)
+
+ # 如果tensor为None或没有元素,使用一个空 tensor
+ if tensor is None or tensor.numel() == 0:
+ local_tensor = torch.empty([0]).float().cuda()
+ else:
+ local_tensor = tensor.contiguous().flatten().float()
+
+ # 获取所有进程中的 tensor 大小
+ tensor_sizes = [torch.zeros(1, dtype=torch.int64).cuda() for _ in range(world_size)]
+ if local_tensor.numel() > 0:
+ tensor_sizes[rank] = torch.tensor([local_tensor.numel()], dtype=torch.int64).cuda()
+ else:
+ tensor_sizes[rank] = torch.tensor([0], dtype=torch.int64).cuda()
+ dist.all_gather(tensor_sizes, tensor_sizes[rank], group=process_group)
+ tensor_sizes = [int(size.item()) for size in tensor_sizes]
+
+ # 找到最大 tensor 大小
+ max_size = max(tensor_sizes)
+
+ # 创建填充 tensor
+ if max_size > 0:
+ padding_tensor = torch.zeros(max_size, dtype=torch.float32, device=local_tensor.device).cuda()
+ else:
+ padding_tensor = torch.tensor([], dtype=torch.float32, device=local_tensor.device).cuda()
+
+ # 将 local_tensor 填充到 padding_tensor
+ if local_tensor.numel() > 0:
+ padding_tensor[:local_tensor.numel()] = local_tensor
+
+ # 创建列表来存储所有填充后的 tensor
+ all_padding_tensors = [torch.zeros_like(padding_tensor).cuda() for _ in range(world_size)]
+
+ # 收集所有填充后的 tensor
+ dist.all_gather(all_padding_tensors, padding_tensor, group=process_group)
+
+ # 拼接所有 tensor,去除填充部分
+ flatten_tensor = torch.cat([t[:size] for t, size in zip(all_padding_tensors, tensor_sizes)], dim=0)
+ return flatten_tensor
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/__init__.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..84e9a82743f3def1f6fb6becdfcee242d2d17bc4
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/__init__.py
@@ -0,0 +1,2 @@
+from .optimizer.zero import LayerZeROptimizer, get_optimizer
+from .optimizer.misc import mga_finalize_model_grads_wrapper
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/optimizer/clip.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/optimizer/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..13de37606ff69a3dc665ecfe72f1bed0f9f3831b
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/optimizer/clip.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+from typing import Iterable
+import math
+import amp_C
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from apex.multi_tensor_apply import multi_tensor_applier
+from mindspeed.core.distributed.layerzero.zero3._common_utils import _is_zero3_flattened
+
+
+@torch.no_grad()
+def _get_grad_norm(
+ params: Iterable[nn.Parameter],
+ norm_type: float,
+) -> torch.Tensor:
+ """
+ Return the gradient norm of parameters ``param`` s, where the gradients are viewed as a single vector.
+
+ The returned norm is in FP32 even if parameters/gradients are in a low precision. This is because the downstream
+ use of this return value is a reduction across ranks.
+ """
+ params_with_grad = [param for param in params if param.grad is not None]
+ if len(params_with_grad) == 0:
+ return torch.tensor(0.0)
+ grads = [param.grad for param in params_with_grad]
+ grad_dtypes = {grad.dtype for grad in grads}
+ if len(grad_dtypes) != 1:
+ raise ValueError(
+ f"Requires uniform dtype across all gradients but got {grad_dtypes}"
+ )
+ # Compute the gradient norm in FP32, where we treat the gradients as a
+ # single vector
+ grad_norm = torch.linalg.vector_norm(
+ torch.stack(
+ [
+ torch.linalg.vector_norm(
+ grad.detach(), norm_type, dtype=torch.float32)
+ for grad in grads
+ ],
+ ),
+ norm_type,
+ dtype=torch.float32,
+ )
+ return grad_norm
+
+
+def clip_grad_norm(params, max_norm, norm_type=2, process_group=dist.group.WORLD):
+ '''
+ For distributed ZERO optimizers, the gradient norm is calculated since the parameter/gradient
+ is distributed across the individual ranks, Additional communication is required
+ It is worth noting here that the grad_norm is divided by world_size approximate DDP
+ #! ZeRO-managed parameters and non-ZeRO-managed parameters are handled separately
+ '''
+ if not max_norm > 0.:
+ raise ValueError("clip_grad should be a number greater than 0.0")
+
+ if isinstance(params, torch.Tensor):
+ params = [params]
+ norm_type = float(norm_type)
+ device = params[0].device
+ sharded_params = set(p for p in params if _is_zero3_flattened(p))
+ non_sharded_params = set(p for p in params if p not in sharded_params)
+
+ local_sharded_norm = _get_grad_norm(sharded_params, norm_type).to(device)
+ local_nonsharded_norm = _get_grad_norm(
+ non_sharded_params, norm_type).to(device)
+ if norm_type == math.inf:
+ total_norm = (
+ torch.maximum(local_sharded_norm, local_nonsharded_norm)
+ if local_nonsharded_norm is not None
+ else local_sharded_norm
+ )
+ dist.all_reduce(
+ total_norm, op=torch.distributed.ReduceOp.MAX, group=process_group
+ )
+ else:
+ total_norm = local_sharded_norm**norm_type
+ dist.all_reduce(total_norm, group=process_group)
+ # All-reducing the local non-sharded norm would count it an extra
+ # world-size-many times
+ if local_nonsharded_norm is not None:
+ total_norm += local_nonsharded_norm**norm_type
+ total_norm = total_norm ** (1.0 / norm_type)
+
+ clip_coef = max_norm / (total_norm + 1e-6)
+ grads = list(set(param.grad for param in params if param.grad is not None))
+ if clip_coef < 1.0:
+ dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
+ multi_tensor_applier(
+ amp_C.multi_tensor_scale, dummy_overflow_buf, [
+ grads, grads], clip_coef
+ )
+ return total_norm
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/optimizer/misc.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/optimizer/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fb02ae82efb93e7d4b7a4dd44b9a7bcc22dac30
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/optimizer/misc.py
@@ -0,0 +1,47 @@
+from typing import List, Optional
+from functools import wraps
+
+import torch
+from megatron.core import parallel_state
+from megatron.training.global_vars import get_args
+from mindspeed.core.distributed.layerzero.zero3 import LayerZeRO3
+
+
+def scale_gradients(model, scaling_factor: float):
+ if not (isinstance(model, LayerZeRO3) and model._is_root):
+ raise ValueError(f"This func expects to be called on a LayerZeRO3 root instance, got {type(model)}")
+
+ for param in model.parameters():
+ if param.requires_grad and param.grad is not None:
+ param.grad.data *= scaling_factor
+
+
+def mga_finalize_model_grads_wrapper(finalize_model_grads):
+ @wraps(finalize_model_grads)
+ def wrapper(*args, **kwargs):
+ global_args = get_args()
+ if getattr(global_args, 'layerzero', False):
+ num_tokens = kwargs.get('num_tokens', None)
+ if num_tokens is not None:
+ # the number of tokens is only present on the last stage, so broadcast it
+ # to the other ranks in the pipeline parallel group.
+ torch.distributed.broadcast(
+ num_tokens,
+ src=parallel_state.get_pipeline_model_parallel_last_rank(),
+ group=parallel_state.get_pipeline_model_parallel_group(),
+ )
+ # all-reduce across DP ranks.
+ torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group())
+ model = kwargs.get('model', None)
+ if model is None and args:
+ model = args[0]
+ for model_chunk in model:
+ if num_tokens > 0:
+ scaling = 1.0 / num_tokens
+ scale_gradients(model_chunk, scaling)
+ return None
+ else:
+ return finalize_model_grads(*args, **kwargs)
+
+ return wrapper
+
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/optimizer/sharded_grad_scaler.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/optimizer/sharded_grad_scaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b0926fa740cb24fd79a4ee1fbffc333bd36b0d8
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/optimizer/sharded_grad_scaler.py
@@ -0,0 +1,387 @@
+import logging
+from collections import abc, defaultdict
+from typing import Dict, List, Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.cuda import FloatTensor
+from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
+from torch.distributed.distributed_c10d import ProcessGroup
+
+
+log = logging.getLogger(__name__)
+
+
+def _refresh_per_optimizer_state():
+ return {"stage": OptState.READY, "found_inf_per_device": {}}
+
+
+def _check_supported_device(tensor: torch.Tensor):
+ if tensor.is_cuda or tensor.device.type in ("xla", "cpu", "npu"):
+ return
+ raise AssertionError(f"Unsupported device: {tensor.device}")
+
+
+class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
+ """
+ Lazily serves tensor to request device. This class extends
+ _MultiDeviceReplicator to allow support for "cpu" as a device.
+ """
+
+ def __init__(self, master_tensor: torch.Tensor) -> None:
+ _check_supported_device(master_tensor)
+ self.master = master_tensor
+ self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
+
+
+class ShardedGradScaler(GradScaler):
+
+ def __init__(
+ self,
+ init_scale: float = 2.0**16,
+ min_scale: float = 1.,
+ backoff_factor: float = 0.5,
+ growth_factor: float = 2.0,
+ growth_interval: int = 2000,
+ hysteresis: int = 2,
+ enabled: bool = True,
+ process_group: Optional[ProcessGroup] = dist.group.WORLD,
+ ):
+ if init_scale is None:
+ init_scale = 1.0
+ super().__init__(
+ init_scale=init_scale,
+ backoff_factor=backoff_factor,
+ growth_factor=growth_factor,
+ growth_interval=growth_interval,
+ enabled=enabled,
+ )
+ if self._enabled:
+ self.process_group = process_group
+ self._per_optimizer_states = defaultdict(
+ _refresh_per_optimizer_state)
+ self.device = torch.device("cuda")
+ self.hysteresis = hysteresis
+ self._hysteresis_tracker = self.hysteresis
+
+ @property
+ def loss_scale(self) -> torch.Tensor:
+ '''
+ The scaler's scale is lazily initialized, or None if _lazy_init_scale_growth_tracker is not used
+ Initialization is only done when scale() is called for the first time
+
+ But megatronOptimizer doesn't scale directly, but manually scales loss
+ '''
+ if not self._enabled:
+ return torch.tensor([1.0], dtype=torch.float32, device=self.device)
+ elif self._scale is None:
+ self._lazy_init_scale_growth_tracker(self.device)
+ self._check_none_scale()
+ return self._scale
+
+ def scale(
+ self, outputs: Union[torch.Tensor, List[torch.Tensor]]
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
+ if not self._enabled:
+ return outputs
+
+ if isinstance(outputs, torch.Tensor):
+ _check_supported_device(outputs)
+ if self._scale is None:
+ self._lazy_init_scale_growth_tracker(outputs.device)
+ self._check_none_scale()
+ scaled_output = outputs * self._scale.to(
+ device=outputs.device, non_blocking=True
+ )
+ # Here we ensure the return dtype is the same as the outputs dtype.
+ # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
+ # format (fp16, bf16) and so the scaled loss should be of the same dtype.
+ return scaled_output.type(outputs.dtype)
+
+ stash: List[_GeneralMultiDeviceReplicator] = []
+
+ def apply_scale(
+ val: Union[torch.Tensor, abc.Iterable]
+ ) -> Union[torch.Tensor, abc.Iterable]:
+ if isinstance(val, torch.Tensor):
+ _check_supported_device(val)
+ if len(stash) == 0:
+ if self._scale is None:
+ self._lazy_init_scale_growth_tracker(val.device)
+ self._check_none_scale()
+ stash.append(_GeneralMultiDeviceReplicator(self._scale))
+ scaled_val = val * stash[0].get(val.device)
+
+ return scaled_val.type(val.dtype)
+ elif isinstance(val, abc.Iterable):
+ iterator = map(apply_scale, val)
+ if isinstance(val, (list, tuple)):
+ return type(val)(iterator)
+ else:
+ return iterator
+ else:
+ raise ValueError(
+ "outputs must be a Tensor or an iterable of Tensors")
+
+ return apply_scale(outputs) # type: ignore[return-value]
+
+ def _foreach_non_finite_check_and_unscale_cpu_(
+ self, grads: List, found_inf: torch.Tensor, inv_scale: torch.Tensor
+ ) -> None:
+ if len(grads) == 0:
+ return
+ if inv_scale.numel() != 1:
+ raise ValueError("inv_scale must be a 1-element tensor.")
+ if found_inf.numel() != 1:
+ raise ValueError("found_inf must be a 1-element tensor.")
+
+ for grad in grads:
+ if grad.device.type != "cpu":
+ log.error(
+ "tensor device is %s but was expected to be ``cpu``",
+ grad.device,
+ )
+ raise ValueError(
+ "Gradients were found on a non-CPU device when"
+ " expected to be on CPU."
+ )
+ if (
+ torch.isinf(grad).any().item() is True
+ or torch.isnan(grad).any().item() is True
+ ):
+ found_inf.data = torch.tensor([1.0])
+ break
+ else:
+ grad.data *= inv_scale.item()
+
+ def _unscale_grads_(
+ self,
+ optimizer: torch.optim.Optimizer,
+ inv_scale: torch.Tensor,
+ found_inf: torch.Tensor,
+ allow_fp16: bool = True,
+ ) -> Dict[torch.device, torch.Tensor]:
+ per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale)
+ per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf)
+
+ per_device_and_dtype_grads = defaultdict(
+ lambda: defaultdict(list))
+ with torch.no_grad():
+ for group in optimizer.param_groups:
+ for param in group['params']:
+ if param.grad is None:
+ continue
+ if (not allow_fp16) and param.grad.dtype == torch.float16:
+ raise ValueError(
+ "Attempting to unscale FP16 gradients.")
+ if param.grad.is_sparse:
+ if param.grad.dtype is torch.float16:
+ # coalesce is not supported in torch.float16
+ param_grad_fp32 = param.grad.type(
+ torch.float32).coalesce()
+ param.grad = param_grad_fp32.type(torch.float16)
+ to_unscale = param.grad._values()
+ else:
+ to_unscale = param.grad
+
+ per_device_and_dtype_grads[to_unscale.device][
+ to_unscale.dtype
+ ].append(to_unscale)
+
+ for device, per_dtype_grads in per_device_and_dtype_grads.items():
+ for grads in per_dtype_grads.values():
+ if grads[0].device.type == "cpu":
+ self._foreach_non_finite_check_and_unscale_cpu_(
+ grads,
+ per_device_found_inf.get(device),
+ per_device_inv_scale.get(device),
+ )
+ else:
+ torch._amp_foreach_non_finite_check_and_unscale_(
+ grads,
+ per_device_found_inf.get(device),
+ per_device_inv_scale.get(device),
+ )
+ # There exist contexts (e.g. w/ `use_orig_params=True`) wherein some
+ # ranks may have no (non-zero sized) parameter shards, necessitating the
+ # initialization of `per_device_found_inf._per_device_tensors` here
+ if not per_device_found_inf._per_device_tensors:
+ self._check_none_scale()
+ per_device_found_inf.get(self._scale.device)
+ return per_device_found_inf._per_device_tensors
+
+ def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
+ if not self._enabled:
+ return False
+
+ self._check_scale_growth_tracker("unscale_")
+
+ optimizer_state = self._per_optimizer_states[id(optimizer)]
+
+ if optimizer_state["stage"] is OptState.UNSCALED:
+ raise RuntimeError(
+ "unscale_() has already been called on this optimizer since the last update()."
+ )
+ elif optimizer_state["stage"] is OptState.STEPPED:
+ raise RuntimeError("unscale_() is being called after step().")
+
+ # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
+ self._check_none_scale()
+ inv_scale = self._scale.double().reciprocal().float()
+ found_inf = torch.full(
+ (1,), 0.0, dtype=torch.float32, device=self._scale.device
+ )
+
+ optimizer_state["found_inf_per_device"] = self._unscale_grads_(
+ optimizer, inv_scale, found_inf, True
+ )
+ optimizer_state["stage"] = OptState.UNSCALED
+
+ # Synchronize the detected inf across the ranks
+ optimizer_state = self._per_optimizer_states[id(optimizer)]
+ future_handles = []
+
+ for v in optimizer_state["found_inf_per_device"].values():
+ if v.device.type == "cpu":
+ v_on_cuda = v.cuda()
+ future_handles.append(
+ dist.all_reduce(
+ v_on_cuda, async_op=True, group=self.process_group
+ ).get_future()
+ )
+ v.copy_(v_on_cuda.cpu())
+ else:
+ future_handles.append(
+ dist.all_reduce(
+ v, async_op=True, group=self.process_group
+ ).get_future()
+ )
+
+ # Make sure that the calls are done before moving out.
+ if future_handles:
+ torch.futures.wait_all(future_handles)
+
+ if (
+ len(optimizer_state["found_inf_per_device"]) == 0
+ ):
+ raise AssertionError("No inf checks were recorded for this optimizer.")
+
+ found_inf = sum(v.item()
+ for v in optimizer_state["found_inf_per_device"].values())
+ return found_inf > 0.
+
+ def step(
+ self, optimizer: torch.optim.Optimizer, *args, **kwargs
+ ) -> Optional[float]:
+ return super().step(optimizer, *args, **kwargs)
+
+ def _update_scale(self, found_inf) -> None:
+ """
+ If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero.
+ Otherwise, scale is multiplied by the growth factor when the growth interval is reached.
+ """
+ if found_inf.item() >= 1.0:
+ self._scale *= self._backoff_factor # type: ignore[arg-type]
+ self._growth_tracker = 0
+ self._hysteresis_tracker -= 1
+ if self._hysteresis_tracker <= 0:
+ self._scale = torch.max(
+ self._scale * self.backoff_factor, self.min_scale)
+ else:
+ successful = self._growth_tracker + 1 # type: ignore[operator]
+ if successful == self._growth_interval: # type: ignore[arg-type]
+ self._scale *= self._growth_factor # type: ignore[arg-type]
+ self._growth_tracker = 0
+ self._hysteresis_tracker = self.hysteresis
+ else:
+ self._growth_tracker = successful
+
+ def update(self, new_scale: Optional[Union[float, FloatTensor]] = None) -> None:
+ """
+ Updates the scale factor.
+ If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
+ to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
+ the scale is multiplied by ``growth_factor`` to increase it.
+ Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
+ used directly, it's used to fill GradScaler's internal scale tensor. So if
+ ``new_scale`` was a tensor, later in-place changes to that tensor will not further
+ affect the scale GradScaler uses internally.)
+ Args:
+ new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.
+ .. warning::
+ :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
+ been invoked for all optimizers used this iteration.
+ """
+
+ if not self._enabled:
+ return
+
+ _scale, _growth_tracker = self._check_scale_growth_tracker(
+ "update") # type: ignore[var-annotated]
+
+ if new_scale is not None:
+ # Accept a new user-defined scale.
+ if isinstance(new_scale, float):
+ self._scale.fill_(new_scale) # type: ignore[union-attr]
+ else:
+ if not (isinstance(new_scale, torch.cuda.FloatTensor) and (new_scale.numel() == 1) and not new_scale.requires_grad):
+ raise AssertionError("new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False.")
+ self._scale.copy_(new_scale) # type: ignore[union-attr]
+ else:
+ # Consume shared inf/nan data collected from optimizers to update the scale.
+ # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
+ found_infs = [
+ found_inf.to(device=_scale.device, non_blocking=True)
+ for state in self._per_optimizer_states.values()
+ for found_inf in state["found_inf_per_device"].values()
+ ]
+
+ if len(found_infs) == 0:
+ raise AssertionError("No inf checks were recorded prior to update.")
+
+ found_inf_combined = found_infs[0]
+ if len(found_infs) > 1:
+ for i in range(1, len(found_infs)):
+ found_inf_combined += found_infs[i]
+
+ self._update_scale(found_inf_combined)
+
+ # To prepare for next iteration, clear the data collected from optimizers this iteration.
+ self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
+
+ def _meg_step(self, optimizer, *args, **kwargs):
+ '''Split the optional step with unscale for adapted with megatron
+ In between we can insert other operations like clip grad
+ '''
+ if not self._enabled:
+ return optimizer.step(*args, **kwargs)
+
+ self._check_scale_growth_tracker("step")
+
+ optimizer_state = self._per_optimizer_states[id(optimizer)]
+
+ if optimizer_state["stage"] is OptState.STEPPED:
+ raise RuntimeError(
+ "step() has already been called since the last update()."
+ )
+
+ retval = self._maybe_opt_step(
+ optimizer, optimizer_state, *args, **kwargs)
+ optimizer_state["stage"] = OptState.STEPPED
+ return retval
+
+ def state_dict(self):
+ state_dict = {}
+ state_dict['scale'] = self._scale
+ state_dict['growth_tracker'] = self._growth_tracker
+ state_dict['hysteresis_tracker'] = self._hysteresis_tracker
+ return state_dict
+
+ def load_state_dict(self, state_dict: Dict):
+ self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
+ self._growth_tracker = state_dict['growth_tracker']
+ self._hysteresis_tracker = state_dict['hysteresis_tracker']
+
+ def _check_none_scale(self):
+ if self._scale is None:
+ raise AssertionError("Got none scale")
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/optimizer/zero.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/optimizer/zero.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a1418442e8993f03ac6d8a5648c551b8e2df4c4
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/megatron_adaptor/optimizer/zero.py
@@ -0,0 +1,277 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import warnings
+
+from typing import Callable, Optional, List, Tuple, Dict
+import torch
+import torch.distributed as dist
+from torch.distributed.distributed_c10d import ProcessGroup
+
+from apex.optimizers import FusedAdam as Adam
+from apex.optimizers import FusedSGD as SGD
+from megatron.core.optimizer import MegatronOptimizer, OptimizerConfig
+from megatron.training.utils import print_rank_0
+from megatron.core import mpu
+from mindspeed.core.distributed.layerzero.debug.sum import all_gather_into_flat_tensor, print_total_grad_sum
+from .sharded_grad_scaler import ShardedGradScaler
+from .clip import clip_grad_norm
+
+
+def _get_param_groups(
+ model_chunks: List,
+ no_weight_decay_cond: Callable,
+ scale_lr_cond: Callable,
+ lr_mult: float,
+) -> List[Dict]:
+ """Create parameter groups for optimizer.
+
+ Creates parameter groups based on weight decay condition (regularized vs
+ non regularized), learning rate scale condition (lr vs lr_mult * lr),
+ and whether it is expert parameters. scale_lr_cond is used during finetuning
+ where head of the network requires a scaled version of the base learning rate.
+
+ Args:
+ model_chunks (List[MegatronModule]): model chunks to create parameter
+ groups for.
+ no_weight_decay_cond (func): function to determine whether a parameter
+ should not perform weight decay.
+ scale_lr_cond (func): function to determine whether a parameter
+ should have a scaled learning rate.
+ lr_mult (float): learning rate multiplier for parameters that
+ satisfy scale_lr_cond.
+
+ Returns:
+ List of parameter groups.
+ """
+ if not isinstance(model_chunks, list):
+ model_chunks = [model_chunks]
+ # Map (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr) to params.
+ params_map = {}
+ for model_chunk in model_chunks:
+ for name, param in model_chunk.named_parameters():
+ if not param.requires_grad:
+ continue
+ if no_weight_decay_cond is not None:
+ no_wd = no_weight_decay_cond(name, param)
+ else:
+ # Do not regularize biases and norm parameters.
+ #! currently do not support norm parameters, case all zero1 param has len(param.shape) == 1
+ no_wd = name.endswith(".bias") or getattr(param, "_is_1D_param", False)
+
+ if scale_lr_cond is not None:
+ scale_lr = scale_lr_cond(name, param)
+ else:
+ scale_lr = False
+
+ if not no_wd and not scale_lr:
+ wd_mult, lr_mult = 1.0, 1.0
+ elif not no_wd and scale_lr:
+ wd_mult, lr_mult = 1.0, lr_mult
+ elif no_wd and not scale_lr:
+ wd_mult, lr_mult = 0.0, 1.0
+ else:
+ wd_mult, lr_mult = 0.0, lr_mult
+
+ key = (wd_mult, lr_mult)
+ if key not in params_map:
+ params_map[key] = []
+ params_map[key].append(param)
+
+ param_groups = []
+ for (wd_mult, lr_mult), params in params_map.items():
+ if len(params) == 0:
+ raise ValueError(f"Empty params list")
+ param_groups.append(
+ {
+ 'params': params,
+ 'wd_mult': wd_mult,
+ 'lr_mult': lr_mult,
+ 'is_decoupled_lr' : False
+ }
+ )
+ return param_groups
+
+
+def get_optimizer(
+ config: OptimizerConfig,
+ model: List,
+ no_weight_decay_cond: Callable = None,
+ scale_lr_cond: Callable = None,
+ lr_mult: float = 1.0
+) -> "MegatronOptimizer":
+ param_groups = _get_param_groups(model, no_weight_decay_cond, scale_lr_cond, lr_mult)
+ optimizer = _get_zero_optimizer(config, param_groups)
+ return optimizer
+
+
+def _get_zero_optimizer(
+ config,
+ param_groups
+):
+ print(f"{config.weight_decay=}")
+ if config.optimizer == 'adam':
+ optimizer = Adam(
+ param_groups,
+ lr=config.lr,
+ weight_decay=config.weight_decay,
+ betas=(config.adam_beta1, config.adam_beta2),
+ eps=config.adam_eps,
+ )
+ init_state_fn = None
+
+ elif config.optimizer == 'sgd':
+ optimizer = SGD(
+ param_groups,
+ lr=config.lr,
+ weight_decay=config.weight_decay,
+ momentum=config.sgd_momentum,
+ )
+ init_state_fn = None
+ else:
+ raise Exception('{} optimizer is not supported.'.format(config.optimizer))
+
+ grad_scaler = None
+ if config.fp16:
+ grad_scaler = ShardedGradScaler(
+ init_scale=config.initial_loss_scale,
+ min_scale=config.min_loss_scale,
+ growth_factor=2.0,
+ backoff_factor=0.5,
+ growth_interval=config.loss_scale_window,
+ hysteresis=config.hysteresis,
+ )
+
+ optimizer_args = [optimizer, config, grad_scaler, init_state_fn]
+ optimizer = LayerZeROptimizer(*optimizer_args)
+ return optimizer
+
+
+def pp_stages():
+ if not mpu.is_initialized():
+ return 1
+ world_size = dist.get_world_size()
+ return world_size // len(mpu.get_pipeline_model_parallel_group())
+
+
+def pp_broadcast_grad_scale(grad_scale, device):
+ if pp_stages() == 1:
+ return grad_scale
+ pp_world_size = mpu.get_pipeline_model_parallel_world_size()
+ world_size = dist.get_world_size()
+ last_stage_rank0 = world_size - pp_world_size
+ if not isinstance(grad_scale, torch.Tensor):
+ grad_scale = torch.tensor(grad_scale, dtype=torch.float32).to(device)
+ dist.broadcast(grad_scale, src=last_stage_rank0)
+ return grad_scale
+
+
+class LayerZeROptimizer(MegatronOptimizer):
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ config: OptimizerConfig,
+ grad_scaler: Optional[ShardedGradScaler],
+ init_state_fn: Callable = lambda x: None,
+ process_group: Optional[ProcessGroup] = dist.group.WORLD,
+ ):
+ super().__init__(optimizer, config, lambda x: None)
+ self.grad_scaler = grad_scaler
+ self.process_group = process_group or dist.group.WORLD
+ self.device = torch.device('cuda')
+
+
+ def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
+ """Simple scaling."""
+ return self.get_loss_scale() * loss
+
+ def get_loss_scale(self) -> torch.Tensor:
+ '''if PP enabled, broadcast scale'''
+ if self.grad_scaler is None:
+ return torch.tensor([1.], dtype=torch.float32, device=self.device)
+ return self.grad_scaler.loss_scale.to(self.device)
+
+ @torch.no_grad()
+ def step(self) -> Tuple[bool, torch.Tensor, torch.Tensor]:
+ if self.grad_scaler:
+ self.grad_scaler._scale = pp_broadcast_grad_scale(self.get_loss_scale(), self.device)
+ found_inf = self.grad_scaler.unscale_(self.optimizer)
+ else:
+ found_inf = False
+
+ grad_norm = None
+ if self.config.clip_grad > 0.0:
+ if self.process_group is None:
+ raise RuntimeError(f"{self.process_group=} is None")
+ grad_norm = clip_grad_norm(self.get_parameters(), self.config.clip_grad, norm_type=2, process_group=self.process_group)
+
+ num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
+
+ if self.grad_scaler:
+ self.grad_scaler._meg_step(self.optimizer)
+ self.grad_scaler.update()
+ else:
+ self.optimizer.step()
+
+ return not found_inf, grad_norm, num_zeros_in_grad
+
+ def prepare_grads(self) -> bool:
+ raise RuntimeError("This function should not be explicitly called by user")
+
+ def step_with_ready_grads(self) -> bool:
+ raise RuntimeError("This function should not be explicitly called by user")
+
+ def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]:
+ raise RuntimeError("This function should not be explicitly called by user")
+
+ def count_zeros(self):
+ num_zeros = sum(param.grad.numel() - torch.count_nonzero(param.grad) \
+ for param in self.get_parameters() if param.grad is not None)
+ dist.all_reduce(num_zeros, group=self.process_group)
+ return num_zeros
+
+ def reload_model_params(self):
+ '''Megatron optimizer api'''
+ pass
+
+ def state_dict(self):
+ state_dict = {}
+ state_dict['optimizer'] = self.optimizer.state_dict()
+ if self.grad_scaler:
+ state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+ return state_dict
+
+ def load_state_dict(self, state_dict):
+ # Optimizer.
+ optimizer_key = 'optimizer'
+ if optimizer_key not in state_dict:
+ optimizer_key = 'optimizer_state_dict'
+ self.optimizer.load_state_dict(state_dict[optimizer_key])
+ # Grad scaler.
+ if self.grad_scaler:
+ if "grad_scaler" not in state_dict:
+ warnings.warn(f"grad scaler state dict missing")
+ else:
+ self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+
+ def sharded_state_dict(
+ self, model_sharded_state_dict, is_loading: bool = False
+ ):
+ """ Builds sharded state dict for the optimizer, based on model's sharded state dict.
+
+ Args:
+ model_sharded_state_dict (ShardedStateDict): sharded state dict of the model
+ is_loading (bool, optional): flag indicating whether the state dict will be used to save or load the optimizer state.
+ Defaults to False.
+
+ Returns: optimizer sharded state dict
+ """
+ raise NotImplementedError("This api should not be called")
+
+ def zero_grad(self, set_to_none: bool = True):
+ self.optimizer.zero_grad()
+
+ def disable_pre_hook(self):
+ return
+
+ def enable_pre_hook(self):
+ return
+
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_forward.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_forward.py
new file mode 100644
index 0000000000000000000000000000000000000000..500756cad415bd723ee17c3f8f0316c495bfcf03
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_forward.py
@@ -0,0 +1,566 @@
+import functools
+from itertools import chain
+from collections import deque
+import logging
+from typing import Any, Callable, Dict, List, no_type_check, Optional, Set, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from torch.distributed.utils import (
+ _cast_forward_inputs,
+ _p_assert,
+ _apply_to_tensors
+)
+from torch.utils._pytree import tree_flatten
+from torch.autograd.graph import register_multi_grad_hook
+from mindspeed.core.distributed.layerzero.zero3.api import BackwardReduceScatter
+from mindspeed.core.distributed.layerzero.zero3._common_utils import (
+ _assert_in_training_states,
+ _is_composable,
+ _ZeRO3State,
+ TrainingState,
+)
+
+from mindspeed.core.distributed.layerzero.zero3.flat_param import FlatParamHandle, HandleTrainingState
+from mindspeed.core.distributed.layerzero import constants
+from ._shard import _unshard, _reshard, _pre_forward_backward_unshard, _post_forward_reshard, _post_backward_reshard, _get_handle_to_post_backward
+from ._grad import _reduce_grad, _accumulate_grad, _pre_bwd_reload_full_prec_grad
+from ._utils import _reset_flat_param_grad_info_if_needed
+from .hook import register_multi_post_grad_hook
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+BACKWARD_POST_QUEUE = deque()
+
+
+@no_type_check
+def _register_pre_backward_hooks(
+ state: _ZeRO3State,
+ module: nn.Module,
+ outputs: Any,
+ handle: FlatParamHandle,
+) -> None:
+ """
+ Registers pre-backward hooks on the tensors that require gradients in the
+ forward pass outputs ``outputs``, which were computed using the
+ ``FlatParameter`` s of ``handles``.
+
+ Args:
+ module (nn.Module): Fully sharded module (see [Note: Fully Sharded
+ Module]).
+
+ Returns:
+ Forward pass outputs with pre-backward hooks registered to tensors that
+ require gradients.
+ """
+ # If there is no gradient computation, then there is no need for
+ # pre-backward logic
+ if not torch.is_grad_enabled():
+ return outputs
+ if state._is_root:
+ state._post_backward_callback_queued = False # only defined on the root
+
+ if handle:
+ handle._needs_pre_backward_unshard = False
+ handle._ran_pre_backward_hook = False
+ # Since these handles' `FlatParameter`s participated in a forward, we
+ # conservatively assume that they will be used in the backward
+
+ def _register_hook(t: torch.Tensor) -> torch.Tensor:
+ if t.requires_grad:
+ t.register_hook(
+ functools.partial(_pre_backward_hook, state, module, handle)
+ )
+ if handle:
+ handle._needs_pre_backward_unshard = True
+ return t
+
+ return _apply_to_tensors(_register_hook, outputs)
+
+
+def _register_post_backward_hook(
+ state: _ZeRO3State,
+ handle: Optional[FlatParamHandle],
+) -> None:
+ # If there is no gradient computation, then there is no need for
+ # post-backward logic
+ if not handle:
+ return
+ flat_param = handle.flat_param
+ inp_tensors = [p for p in flat_param._tensors if p.requires_grad]
+ hook_handle = register_multi_post_grad_hook(
+ inp_tensors, functools.partial(_post_backward_ready_hook, state, handle)
+ )
+ flat_param._post_backward_hook_state = (
+ None, hook_handle) # type: ignore[attr-defined]
+
+
+def _register_post_backward_reshard_only_hook(
+ state: _ZeRO3State,
+ handle: Optional[FlatParamHandle],
+ args: Tuple[Any, ...],
+ kwargs: Dict[str, Any],
+) -> None:
+ """
+ Registers post-backward hooks to reshard flat parameters that do not
+ require gradient. We register these using multi-post-grad hooks on the
+ input activations to ensure that all gradients that may depend on the
+ parameters have been computed before resharding.
+ """
+ # If there is no gradient computation, then there is no need for
+ # post-backward logic
+ if not torch.is_grad_enabled():
+ return
+ # Construct `inp_tensors` lazily to avoid CPU overhead in typical case
+ # where each flat parameter requires gradient
+ inp_tensors: Optional[List[torch.Tensor]] = None
+ if not handle:
+ return
+ if handle.flat_param.requires_grad:
+ return
+ if inp_tensors is None:
+ args_list, _ = tree_flatten(args)
+ kwargs_list, _ = tree_flatten(kwargs)
+ inp_tensors = [
+ obj
+ for obj in chain(args_list, kwargs_list)
+ if torch.is_tensor(obj) and obj.requires_grad
+ ]
+ _p_assert(inp_tensors is not None, "Got None inp_tensor")
+ hook_handle = register_multi_grad_hook(
+ inp_tensors, functools.partial(_post_backward_reshard, state, handle)
+ )
+ handle.flat_param._post_backward_hook_state = (
+ hook_handle,)
+
+
+@no_type_check
+def _register_post_backward_final_callback(
+ state: _ZeRO3State, module: nn.Module
+) -> None:
+ """
+ Registers the post-backward final callback that runs at the end of the
+ backward pass. This should be called from the root FSDP instance at the
+ beginning of the pre-backward.
+ """
+ _p_assert(
+ state._is_root,
+ "Only the root ZeRo3 instance should register the post-backward callback",
+ )
+ if state._post_backward_callback_queued:
+ return
+ _assert_in_training_states(state, [TrainingState.IDLE])
+ state._post_backward_callback_queued = True
+ Variable._execution_engine.queue_callback(
+ functools.partial(_post_backward_final_callback, state, module)
+ )
+
+
+@no_type_check
+def _pre_forward(
+ state: _ZeRO3State,
+ handle: Optional[FlatParamHandle],
+ unshard_fn: Callable,
+ module: nn.Module,
+ args: Tuple[Any, ...],
+ kwargs: Dict[str, Any],
+) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+ """
+ Runs the pre-forward logic. This includes an opportunity to unshard
+ currently sharded parameters such as those for the current forward and
+ registering post-backward hooks for these current parameters. This function
+ also converts forward ``args`` and ``kwargs`` to the given precision.
+
+ Args:
+ handles (List[FlatParamHandle]): Handles giving the parameters used in
+ the current forward.
+ unshard_fn (Optional[Callable]): A callable to unshard any currently
+ sharded parameters or ``None`` to not do any unsharding.
+ module (nn.Module): Module whose forward this method runs right before;
+ expected by the hook signature.
+ args (Tuple[Any, ...]): Module forward ``args``.
+ kwargs (Dict[str, Any]): Module forward ``kwargs``.
+ """
+ with torch.profiler.record_function(f"LayerZeRO3._pre_forward"):
+ # For `fully_shard` + `checkpoint`, skip pre-forward logic in the
+ # recomputed forward
+ if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
+ return args, kwargs
+ state.training_state = TrainingState.FORWARD_BACKWARD
+ state._exec_order_data.record_pre_forward(handle, module.training)
+ if handle:
+ handle._training_state = HandleTrainingState.FORWARD
+
+ with torch.autograd.profiler.record_function("Unshard Function"):
+ if unshard_fn is not None:
+ unshard_fn(state, handle)
+ if handle:
+ handle._use_unsharded_views(as_params=False)
+ if constants.AUTO_CAST_INPUT and state.mixed_precision:
+ # Recursively convert args and kwargs to specified precision.
+ input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
+ args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
+ _register_post_backward_reshard_only_hook(state, handle, args, kwargs)
+ return args, kwargs
+
+
+@no_type_check
+def _post_forward(
+ state: _ZeRO3State,
+ handle: Optional[FlatParamHandle],
+ reshard_fn: Callable,
+ module: nn.Module,
+ inputs: Any,
+ output: Any,
+) -> Any:
+ """
+ Runs the post-forward logic. This includes an opportunity to reshard
+ currently unsharded parameters such as those used in the current forward
+ and registering pre-backward hooks on the forward outputs.
+
+ Args:
+ handles (List[FlatParamHandle]): Handles giving the parameters used in
+ the current forward.
+ reshard_fn (Optional[Callable]): A callable to reshard any currently
+ unsharded parameters (e.g. from the current forward) or ``None`` to
+ not do any resharding.
+ module (nn.Module): Module whose forward just ran, which should be a
+ fully sharded module (see [Note: Fully Sharded Module]); expected
+ by the hook signature.
+ input (Any): Unused; expected by the hook signature.
+ output (Any): Forward pass output; pre-backward hooks are registered on
+ the tensors that require gradients in this output.
+
+ Postcondition: Each ``FlatParameter`` 's data points to the sharded flat
+ parameter.
+ """
+ with torch.profiler.record_function(f"LayerZeRO3._post_forward"):
+ # For `fully_shard` + `checkpoint`, skip post-forward logic in the
+ if handle and handle._training_state != HandleTrainingState.FORWARD:
+ return output
+ #! adapt megatron AC to avoid free after forward
+ if handle and not handle.enter_backward:
+ state._exec_order_data.record_post_forward(handle)
+ with torch.autograd.profiler.record_function("Reshard Function"):
+ if reshard_fn is not None:
+ reshard_fn(state, handle)
+ # Register pre-backward hooks to unshard the flat parameters for the
+ # gradient computation (if needed)
+ output = _register_pre_backward_hooks(state, module, output, handle)
+ state.training_state = TrainingState.IDLE
+ if handle:
+ handle._training_state = HandleTrainingState.IDLE
+ return output
+
+
+@no_type_check
+def _pre_backward_hook(
+ state: _ZeRO3State,
+ module: nn.Module,
+ handle: FlatParamHandle,
+ grad,
+ *unused: Any,
+) -> Any:
+ """
+ Prepares ``_handle`` 's ``FlatParameter`` s for gradient computation.
+
+ Args:
+ module (nn.Module): Fully sharded module (see [Note: Fully Sharded
+ Module]).
+ Post Condition:
+ parameter in unshard and unpadded, used for grad compute
+ grad is unshard and unpadded.
+ """
+ # Only run the pre-backward hook once per group of handles involved in the
+ # same module forward computation
+ if handle and getattr(handle, "_ran_pre_backward_hook", False):
+ return grad
+ if handle:
+ handle.enter_backward = True
+ with torch.profiler.record_function(f"LayerZeRO3._pre_backward_hook"):
+ # Queue the post-backward callback once for the root FSDP instance to
+ # attach it to the outermost backward graph task so that it is called
+ # after all backward calls complete
+ if state._is_root and not state._post_backward_callback_queued:
+ _register_post_backward_final_callback(state, module)
+ _reset_flat_param_grad_info_if_needed(state._all_handles)
+ elif handle:
+ allowed_states = [TrainingState.IDLE]
+ if _is_composable(state):
+ allowed_states.append(TrainingState.FORWARD_BACKWARD)
+ _assert_in_training_states(state, allowed_states)
+
+ state.training_state = TrainingState.FORWARD_BACKWARD
+ # Queueing the post-backward callback is the only logic that is not
+ # per-handle in the pre-backward hook, so we can return early here if
+ # there are no handles.
+ if not handle:
+ return grad
+ #! ensure that last handle has finished accumulate grad (backward) on cpu
+ if len(BACKWARD_POST_QUEUE) > 0:
+ (_last_state, _last_handle) = BACKWARD_POST_QUEUE.popleft()
+ _post_backward_hook(_last_state, _last_handle)
+ handle._training_state = HandleTrainingState.BACKWARD_PRE
+ _register_post_backward_hook(state, handle)
+ _pre_forward_backward_unshard(state, handle)
+ _pre_bwd_reload_full_prec_grad(state, handle)
+ #! alloc memory on default stream if not allocated
+ handle.prepare_gradient_for_backward()
+ handle._ran_pre_backward_hook = True
+ return grad
+
+
+@no_type_check
+@torch.no_grad()
+def _post_backward_ready_hook(
+ state: _ZeRO3State,
+ handle: FlatParamHandle,
+ *unused: Any,
+):
+ if not handle:
+ return
+ BACKWARD_POST_QUEUE.append((state, handle))
+
+
+@no_type_check
+@torch.no_grad()
+def _post_backward_hook(
+ state: _ZeRO3State,
+ handle: FlatParamHandle,
+ *unused: Any,
+):
+ """
+ Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.
+
+ Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
+ unsharded gradient for the local batch.
+
+ Postcondition:
+ - If no sync, then the ``.grad`` attribute is the reduced
+ unsharded gradient.
+ - Otherwise, the ``_saved_grad`` attribute is the reduced sharded
+ gradient.
+ """
+ flat_param = handle.flat_param
+ handle.enter_backward = False
+
+ with torch.autograd.profiler.record_function(
+ f"LayerZeRO3._post_backward_hook"
+ ):
+ _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
+ # For multiple applications of reentrant AC across submodules sharing
+ # the same `FlatParameter`, the post-backward hook may run multiple
+ # times in one backward, in which case we permit the state to already
+ # be in `BACKWARD_POST`.
+ _p_assert(
+ handle._training_state
+ in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
+ f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
+ )
+ handle._training_state = HandleTrainingState.BACKWARD_POST
+
+ if flat_param.grad is None:
+ return
+ if flat_param.grad.requires_grad:
+ raise RuntimeError("ZeRO3 does not support gradients of gradients")
+
+ _post_backward_reshard(state, handle)
+ _accumulate_grad(state, handle)
+ reduce_scatter_sync_gradients(state, handle)
+ handle._ran_post_backward_hook = True
+
+
+def reduce_scatter_sync_gradients(
+ state: _ZeRO3State,
+ handle: FlatParamHandle):
+ '''
+ Performs a sync in zero1 process group
+ '''
+ with torch.autograd.profiler.record_function(f"Reduce Scatter Gradients"):
+ if not state._sync_gradients:
+ return
+ flat_param = handle.flat_param
+ if flat_param is not None and flat_param._post_backward_called:
+ return
+ flat_param._post_backward_called = True
+ if state.backward_reduce_scatter == BackwardReduceScatter.BACKWARD_PRE:
+ state.wait_critical_path_events()
+ _reduce_grad(state, handle)
+
+
+@no_type_check
+@torch.no_grad()
+def _post_backward_final_callback_no_sync(
+ state: _ZeRO3State,
+ module: nn.Module,
+):
+ if not state._is_root or state._sync_gradients:
+ raise RuntimeError("The post-backward no sync callback should only be called \
+ on the root FSDP instance without sync gradients")
+
+ while len(BACKWARD_POST_QUEUE) > 0:
+ (_last_state, _last_handle) = BACKWARD_POST_QUEUE.popleft()
+ _post_backward_hook(_last_state, _last_handle)
+
+ root_state: _ZeRO3State = state
+ root_state._exec_order_data.next_iter_during_accumulation()
+ for zero3_state in state._all_zero3_states:
+ zero3_state.training_state = TrainingState.IDLE
+ handle: FlatParamHandle = zero3_state._handle
+ if handle:
+ handle._ran_pre_backward_hook = False
+ handle._ran_post_backward_hook = False
+ handle._training_state = HandleTrainingState.IDLE
+ handle.prev_iter_synced = False
+ if handle._offload_grads:
+ while True:
+ offload_event = root_state._offload_event_queue._dequeue()
+ if offload_event:
+ (event, last_handle) = offload_event
+ event.wait()
+ last_handle.free_full_prec_grad()
+ else:
+ break
+ root_state._post_backward_callback_queued = False
+
+
+@no_type_check
+@torch.no_grad()
+def _post_backward_final_callback_sync_gradients(
+ state: _ZeRO3State,
+ module: nn.Module
+):
+ if not (state._is_root and state._sync_gradients):
+ raise RuntimeError("The post-backward sync callback should \
+ only be called on the root FSDP instance with sync gradients")
+
+ while len(BACKWARD_POST_QUEUE) > 0:
+ (_last_state, _last_handle) = BACKWARD_POST_QUEUE.popleft()
+ _post_backward_hook(_last_state, _last_handle)
+
+ root_state: _ZeRO3State = state
+ root_state._exec_order_data.next_iter()
+ for zero3_state in state._all_zero3_states:
+ _catch_all_reshard(zero3_state)
+ zero3_state.training_state = TrainingState.IDLE
+ handle: FlatParamHandle = zero3_state._handle
+ #! if post_backward is done, but flat_param has not reduce scatter
+ if state.backward_reduce_scatter == BackwardReduceScatter.BACKWARD_PRE:
+ if handle and handle._ran_post_backward_hook and not handle.flat_param._post_backward_called:
+ reduce_scatter_sync_gradients(zero3_state, handle)
+ if handle:
+ handle._ran_pre_backward_hook = False
+ handle._ran_post_backward_hook = False
+ handle._needs_pre_backward_unshard = False
+ handle._post_forward_index = None
+ handle._training_state = HandleTrainingState.IDLE
+ handle._prefetched = False
+ handle._needs_param_sync = root_state._sync_gradients
+ handle._param_synced = False
+ handle._grad_synced = False
+ #! free handle zero3 shard if _sync_gradients in reshard after backward cause next run we use zero1 shard
+ handle.flat_param._zero3_shard = None
+ handle.prev_iter_synced = True
+
+ _finalize_params(zero3_state)
+ while True:
+ rs_event = root_state._rs_event_queue._dequeue()
+ if rs_event:
+ (rs, last_handle) = rs_event
+ rs.wait()
+ last_handle.free_full_prec_grad()
+ else:
+ break
+
+ compute_stream = state._default_stream
+ compute_stream.wait_stream(root_state._post_backward_stream)
+ for handle in state._all_handles:
+ flat_param = handle.flat_param
+ if flat_param.requires_grad:
+ handle.prepare_gradient_for_zero1()
+ root_state._post_backward_callback_queued = False
+
+
+@no_type_check
+@torch.no_grad()
+def _post_backward_final_callback(
+ state: _ZeRO3State,
+ module: nn.Module
+):
+ """
+ This waits for the post-backward to finish and performs some final cleanup.
+ This runs at the end of the entire backward pass and should only be called
+ on the root FSDP instance.
+ """
+ if dist.get_rank() == 0:
+ logger.info(
+ f"_post_backward_final_callback Being Called and reset states")
+ if state._sync_gradients:
+ _post_backward_final_callback_sync_gradients(state, module)
+ else:
+ _post_backward_final_callback_no_sync(state, module)
+
+
+@no_type_check
+def _catch_all_reshard(
+ state: _ZeRO3State,
+) -> None:
+ """
+ Reshards the parameters that may not have been resharded in the
+ post-backward hook. This can happen when a module's output is used in the
+ forward pass, meaning that its pre-backward hook runs (unsharding the
+ parameter), but the post-backward hook does not run because the output was
+ not jused in the loss computation corresponding to this backward pass.
+ """
+ # Wrap with a try-except to provide a more informative traceback if an
+ # error is raised
+ try:
+ if state._handle:
+ already_resharded = (
+ state._handle.flat_param.data_ptr()
+ == state._handle.flat_param._zero1_shard.data_ptr()
+ # If FSDP skipped using sharded views, then the flat parameter
+ # still points to the sharded data, so we need to reshard to
+ # use sharded views
+ and not state._handle._skipped_use_sharded_views
+ )
+ if already_resharded:
+ return
+ _reshard(state, state._handle, True)
+ except Exception as e:
+ _p_assert(
+ False,
+ f"Got exception in the catch-all reshard for {state}: {str(e)}",
+ raise_assertion_error=False,
+ )
+ raise e
+
+
+@no_type_check
+def _finalize_params(
+ state: _ZeRO3State,
+) -> None:
+ """Finalizes the parameters before the next iteration.
+ """
+ handle = state._handle
+ if not handle:
+ return
+ flat_param = handle.flat_param
+ if hasattr(flat_param, "_post_backward_hook_state"):
+ post_backward_hook_state_len = len(flat_param._post_backward_hook_state)
+ expected_post_backward_hook_state_len = int(flat_param.requires_grad) + 1
+ _p_assert(
+ post_backward_hook_state_len == expected_post_backward_hook_state_len,
+ f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
+ )
+ flat_param._post_backward_hook_state[-1].remove()
+ delattr(flat_param, "_post_backward_hook_state")
+ if flat_param.requires_grad:
+ _p_assert(
+ hasattr(flat_param, "_post_backward_called"),
+ "Expects `_post_backward_called` to be set on the `FlatParameter`",
+ )
+ flat_param._post_backward_called = False
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_grad.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6f8a365a71ff73d99817a8a59ba1e12925c29fc
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_grad.py
@@ -0,0 +1,94 @@
+from typing import Any, Callable, Dict, List, no_type_check, Optional, Set, Tuple
+
+import torch.distributed as dist
+
+from mindspeed.core.distributed.layerzero.zero3._common_utils import _ZeRO3State
+from mindspeed.core.distributed.layerzero.zero3.flat_param import FlatParamHandle, HandleTrainingState
+from ._utils import _div_if_needed
+from ._shard import _is_last_order_forward
+
+
+@no_type_check
+def _reduce_grad(state: _ZeRO3State, handle: FlatParamHandle) -> None:
+ """
+ Runs gradient reduction, sharded gradient and the post-reduction callback.
+ if accumulate grad, this func will not be called cause whole param unshard
+ grad will be stored, rather than shard grad.
+ """
+ flat_param = handle.flat_param
+ rs_event = state._rs_event_queue._dequeue()
+ if rs_event:
+ rs, last_hanlde = rs_event
+ rs.wait()
+ last_hanlde.free_full_prec_grad()
+ padded_unsharded_grad, new_sharded_grad = handle._get_reduce_scatter_tensors()
+ _div_if_needed(padded_unsharded_grad, state._gradient_predivide_factor)
+ state._post_backward_stream.wait_stream(state._default_stream)
+ with state._device_handle.stream(state._post_backward_stream):
+ dist.reduce_scatter_tensor(
+ new_sharded_grad,
+ padded_unsharded_grad,
+ group=handle._get_reduce_scatter_group(),
+ )
+ reduce_scatter_event = state._device_handle.Event()
+ reduce_scatter_event.record()
+ state._rs_event_queue.enqueue((reduce_scatter_event, handle))
+ #! remove all-reduce logic and shard grad accumulation, and grad view logic
+ handle.set_shard_grad(new_sharded_grad)
+
+
+def offload_grad(
+ state: _ZeRO3State, handle: FlatParamHandle
+):
+ if not handle:
+ return
+ # do not offload the last backward cause it is needed at first
+ if _is_last_order_forward(state, handle):
+ return
+ off_event_handle = state._offload_event_queue._dequeue()
+ if off_event_handle is not None:
+ offload_event, last_handle = off_event_handle
+ offload_event.wait()
+ last_handle.free_full_prec_grad()
+ state._offload_stream.wait_stream(state._default_stream)
+ state._offload_stream.wait_stream(state._unshard_stream)
+ with state._device_handle.stream(state._offload_stream):
+ handle.offload_grad()
+ event = state._device_handle.Event()
+ event.record()
+ state._offload_event_queue.enqueue((event, handle))
+
+
+@no_type_check
+def _pre_bwd_reload_full_prec_grad(
+ state: "_ZeRO3State",
+ handle: Optional["FlatParamHandle"],
+) -> None:
+ if not handle or handle._training_state != HandleTrainingState.BACKWARD_PRE:
+ return
+
+ if state._offload_grads:
+ if not handle.already_load_full_prec_grad():
+ handle.alloc_full_prec_grad()
+ with state._device_handle.stream(state._offload_stream):
+ handle.reload_full_prec_grad()
+ handle._check_padded_unsharded(
+ handle.flat_param._full_prec_grad_padded)
+
+
+def _accumulate_grad(
+ state: "_ZeRO3State",
+ handle: Optional["FlatParamHandle"],
+):
+ if not handle or handle._training_state != HandleTrainingState.BACKWARD_POST:
+ return
+ if not handle.already_load_full_prec_grad():
+ handle.alloc_full_prec_grad()
+ if state._offload_grads:
+ state._default_stream.wait_stream(state._offload_stream)
+ #! accumulate grad on compute stream
+ handle.accumulate_grad()
+ handle.free_runtime_unshard_grad()
+
+ if state._offload_grads and not state._sync_gradients:
+ offload_grad(state, handle)
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_initialize.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_initialize.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa2b2970e49b70c61301d257c8e8412021c7b036
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_initialize.py
@@ -0,0 +1,134 @@
+from typing import Any, Callable, Dict, List, no_type_check, Optional, Set, Tuple
+import logging
+import torch.nn as nn
+from torch.distributed.utils import _p_assert
+import torch.distributed as dist
+
+import mindspeed.core.distributed.layerzero.zero3._traversal_utils as traversal_utils
+from mindspeed.core.distributed.layerzero.zero3._common_utils import (
+ _assert_in_training_states,
+ _ZeRO3State,
+ TrainingState,
+)
+from ._utils import (
+ _get_buffers_and_dtypes_for_computation,
+ _cast_buffers_to_dtype_and_device,
+)
+
+
+@no_type_check
+def _lazy_init(
+ state: _ZeRO3State,
+ root_module: nn.Module,
+) -> _ZeRO3State:
+ """
+ Performs initialization lazily, typically right before the first forward
+ pass. The laziness is needed to ensure that the parameter device/dtype and
+ the FSDP hierarchy have finalized. This method's actual logic only runs on
+ the root FSDP instance, which performs initialization for all non-root FSDP
+ instances to avoid partial initialization.
+
+ For the non-composable code path, ``state`` and ``root_module`` should be
+ the same, namely the zero3 instance itself.
+ """
+ if state._is_root is not None:
+ return None
+ if not state._device_handle.is_available():
+ # Allow the FSDP constructor to run even without CUDA but check this
+ # once we start real execution
+ raise RuntimeError("ZeRO3 does not support CPU only execution")
+ # The following logic is only run on the root FSDP instance since it will
+ # set `_is_root=False` for the non-root instances
+ state._is_root = True
+ _assert_in_training_states(state, [TrainingState.IDLE])
+ _check_flat_params_on_expected_device(state, root_module)
+ state._all_zero3_states = traversal_utils._get_zero3_states(root_module)
+ _init_streams(state)
+ buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation(state, root_module)
+ _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, state.compute_device)
+ state._exec_order_data.init(state, root_module, state.zero1_process_group)
+ _share_state_and_init_handle_attrs(state, root_module)
+ if dist.get_rank() == 0:
+ logging.info(f"Root Layezero Contains {len(state._all_handles)} non-None handles")
+ return state
+
+
+def _check_flat_params_on_expected_device(state: _ZeRO3State, module: nn.Module):
+ """
+ Checks that all ``FlatParameter``s in ``module`` 's tree managed by
+ ``state`` are on the expected device for *lazy initialization*.
+ """
+ for handle in traversal_utils._get_zero3_handles(module):
+ if handle.flat_param.device != state.compute_device:
+ raise RuntimeError(
+ "An ZeRO3-managed module unexpectedly has parameters on "
+ f"{handle.flat_param.device}. Make sure to move the module to "
+ f"{state.compute_device} before training."
+ )
+
+
+@no_type_check
+def _share_state_and_init_handle_attrs(
+ root_state: _ZeRO3State,
+ root_module: nn.Module,
+) -> None:
+ """
+ Shares data structure state from the ``root_state`` to all zero3 states in
+ ``root_module`` 's module tree, and initializes handle attributes. These
+ are done together to require a single loop over the states.
+ """
+ handle = root_state._handle
+ if handle:
+ handle.init_flat_param_attributes()
+ root_state._all_handles = root_state._exec_order_data.all_handles # share reference
+ for zero3_state in root_state._all_zero3_states:
+ if zero3_state is root_state:
+ continue
+ _p_assert(
+ zero3_state._is_root is None or not zero3_state._is_root,
+ "Non-root FSDP instance's `_is_root` should not have been "
+ "set yet or should have been set to `False`",
+ )
+ zero3_state._is_root = False
+ zero3_state._unshard_stream = root_state._unshard_stream
+ zero3_state._post_backward_stream = root_state._post_backward_stream
+ zero3_state._pre_unshard_stream = root_state._pre_unshard_stream
+ zero3_state._default_stream = root_state._default_stream
+ zero3_state._offload_stream = root_state._offload_stream
+
+ zero3_state._exec_order_data = root_state._exec_order_data
+ zero3_state._free_event_queue = root_state._free_event_queue
+ zero3_state._rs_event_queue = root_state._rs_event_queue
+ zero3_state._offload_event_queue = root_state._offload_event_queue
+ handle = zero3_state._handle
+ if handle:
+ handle.init_flat_param_attributes()
+
+
+@no_type_check
+def _init_streams(
+ state: _ZeRO3State,
+) -> None:
+ """
+ Initializes streams for overlapping communication, computation, and
+ data transfers. The streams should be shared across zero3 instances.
+ """
+ if not (state._is_root and state._device_handle.is_available()):
+ raise RuntimeError(f"state is not initialized or device not available")
+ # Prioritize all-gathers/reduce-scatters over async all-reduce for HSDP and
+ # preserve the default priority of 0 otherwise
+ high_priority = 1
+ mid_priority = 2
+ low_priority = 3
+ # Default stream for computation
+ state._default_stream = state._device_handle.current_stream()
+ # Stream for unshard logic, including allocating the all-gather destination
+ # tensors and the all-gathers themselves
+ state._unshard_stream = state._device_handle.Stream(priority=mid_priority)
+ # Stream for overlapping gradient reduction with the backward pass gradient
+ # computation
+ state._post_backward_stream = state._device_handle.Stream(priority=low_priority)
+ # Stream for pre-unshard logic, namely allocations and writes for CPU
+ # offloading (H2D copy) and mixed precision (low precision cast)
+ state._offload_stream = state._device_handle.Stream(priority=low_priority)
+ state._pre_unshard_stream = state._device_handle.current_stream()
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_root_forward.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_root_forward.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d39d1b62a06e49de1cf301d9346b98dd5c8ef6e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_root_forward.py
@@ -0,0 +1,69 @@
+from typing import Any, Callable, Dict, List, no_type_check, Optional, Set, Tuple
+
+import torch
+import torch.nn as nn
+from torch.distributed.utils import (
+ _cast_forward_inputs,
+ _p_assert,
+ _to_kwargs,
+)
+from mindspeed.core.distributed.layerzero import constants
+from mindspeed.core.distributed.layerzero.zero3._common_utils import _ZeRO3State, _is_composable
+from mindspeed.core.distributed.layerzero.zero3.flat_param import FlatParamHandle
+
+from ._utils import (
+ _reset_flat_param_grad_info_if_needed,
+ _wait_for_computation_stream
+)
+from ._initialize import _lazy_init
+
+
+@no_type_check
+def _zero3_root_pre_forward(
+ state: _ZeRO3State,
+ module: nn.Module,
+ args,
+ kwargs,
+) -> None:
+ with torch.profiler.record_function("LayerZeRO3._root_pre_forward_check"):
+ _lazy_init(state, module)
+ _p_assert(state._is_root is not None,
+ "Expects a root ZeRO3 to have been set")
+ if not state._is_root:
+ if constants.AUTO_CAST_INPUT and _is_composable(state):
+ return _root_cast_forward_input(state, module, args, kwargs)
+ return args, kwargs
+
+ with torch.profiler.record_function("LayerZeRO3._root_pre_forward"):
+ if state.forward_prefetch:
+ handles: List[FlatParamHandle] = []
+ for zero3_state in state._all_zero3_states:
+ if zero3_state._handle:
+ handles.append(zero3_state._handle)
+ for handle in handles:
+ handle._needs_pre_forward_unshard = True
+
+ _wait_for_computation_stream(
+ state._default_stream, state._unshard_stream, state._pre_unshard_stream)
+ _reset_flat_param_grad_info_if_needed(state._all_handles)
+
+ # Prepares the forward inputs by moving them to ``compute_device``
+ # the perf with/without it.
+ with torch.profiler.record_function("LayerZeRO3._to_kwargs"):
+ args_tuple, kwargs_tuple = _to_kwargs(
+ args, kwargs, state.compute_device, False
+ )
+ args = args_tuple[0]
+ kwargs = kwargs_tuple[0]
+ return args, kwargs
+
+
+@no_type_check
+def _root_cast_forward_input(
+ state: _ZeRO3State, module: torch.nn.Module, args, kwargs
+) -> Tuple[Any, Any]:
+
+ if module.training and state.mixed_precision is not None:
+ input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
+ args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
+ return args, kwargs
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_shard.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_shard.py
new file mode 100644
index 0000000000000000000000000000000000000000..627e280a3c1e50e6f3e69e3edc1364af0420a32e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_shard.py
@@ -0,0 +1,277 @@
+import logging
+
+from enum import auto, Enum
+from typing import Any, no_type_check, Optional, Set, Tuple, TYPE_CHECKING
+
+import torch
+from torch.distributed.utils import _p_assert
+import torch.distributed as dist
+from mindspeed.core.distributed.layerzero.zero3.api import BackwardPrefetch
+from mindspeed.core.distributed.layerzero.zero3.flat_param import HandleTrainingState
+if TYPE_CHECKING:
+ from mindspeed.core.distributed.layerzero.zero3._common_utils import _ZeRO3State
+ from mindspeed.core.distributed.layerzero.zero3.flat_param import FlatParamHandle
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+
+class _PrefetchMode(Enum):
+ BACKWARD = auto()
+ FORWARD = auto()
+
+
+@no_type_check
+def _unshard(
+ state: "_ZeRO3State",
+ handle: "FlatParamHandle",
+ unshard_stream: torch.Stream,
+ pre_unshard_stream: torch.Stream,
+) -> None:
+ """
+ Unshards the handles in ``handles``. If the handles are in
+ :meth:`summon_full_params` and are using mixed precision, then they are
+ forced to full precision.
+
+ Postcondition: handle's ``FlatParameter`` 's data is the padded
+ unsharded flat parameter on the compute device.
+ """
+ if not handle or not handle.needs_unshard():
+ return
+
+ with state._device_handle.stream(pre_unshard_stream):
+ handle.pre_unshard()
+
+ unshard_stream.wait_stream(pre_unshard_stream)
+ if state.limit_all_gathers:
+ event = state._free_event_queue.dequeue_if_needed()
+ if event:
+ with torch.profiler.record_function(
+ "LayerZeRO3.rate_limiter"
+ ):
+ event.synchronize()
+ with state._device_handle.stream(unshard_stream):
+ handle.unshard()
+ handle.post_unshard()
+
+
+@no_type_check
+def _reshard(
+ state: "_ZeRO3State",
+ handle: "FlatParamHandle",
+ free_unsharded_flat_param: bool,
+):
+ """
+ Reshards the handle. ``free_unsharded_flat_param`` indicates whether to
+ free the handle's padded unsharded flat parameter.
+ """
+ handle.reshard(free_unsharded_flat_param)
+ if state.limit_all_gathers and free_unsharded_flat_param:
+ free_event = state._device_handle.Event()
+ free_event.record()
+ state._free_event_queue.enqueue(free_event)
+ # Since we prefetch entire handles keys at a time, conservatively mark
+ # the entire key as no longer prefetched once we free at least one
+ if free_unsharded_flat_param:
+ handle._prefetched = False
+ else:
+ handle._prefetched = True
+
+
+@no_type_check
+def _pre_forward_backward_unshard(
+ state: "_ZeRO3State",
+ handle: Optional["FlatParamHandle"],
+) -> None:
+ """Unshards parameters in the pre-forward.
+ 1. check handle exists
+ 2. check zero1 synced params to zero3
+ 3. check zero3 prefetched
+ 4. prefetch next layer
+ modified _unshard func, which is called at each all-gather
+
+ """
+ if not handle:
+ return
+ # If the handles have been prefetched, then there is no need to call
+ # `_unshard()` again
+ if handle._training_state not in [HandleTrainingState.FORWARD, HandleTrainingState.BACKWARD_PRE]:
+ return
+
+ in_forward = handle._training_state == HandleTrainingState.FORWARD
+ stage = "forward" if in_forward else "backward"
+ guard_state = f"_needs_pre_{stage}_unshard"
+ if in_forward or getattr(handle, guard_state):
+ _unshard(
+ state,
+ handle,
+ state._unshard_stream,
+ state._pre_unshard_stream
+ )
+ setattr(handle, guard_state, False)
+ state._default_stream.wait_stream(state._unshard_stream)
+ handle._check_unsharded(handle.flat_param.data)
+
+ _prefetch_mode = _PrefetchMode.FORWARD if handle._training_state == HandleTrainingState.FORWARD else _PrefetchMode.BACKWARD
+ with torch.profiler.record_function(
+ f"LayerZeRO3._pre_{stage}_prefetch"
+ ):
+ _prefetch_handle(state, handle, _prefetch_mode)
+
+
+def _is_last_order_forward(
+ state: "_ZeRO3State",
+ handle: "FlatParamHandle"
+) -> bool:
+ return handle._post_forward_index == len(state._exec_order_data.all_handles) - 1
+
+
+@no_type_check
+def _post_forward_reshard(
+ state: "_ZeRO3State",
+ handle: "FlatParamHandle",
+) -> None:
+ """Reshards parameters in the post-forward.
+ """
+ if not handle:
+ return
+ free_unsharded_flat_param = not _is_last_order_forward(state, handle)
+ with torch.profiler.record_function(
+ "LayerZeRO3._post_forward_reshard"
+ ):
+ _reshard(state, handle, free_unsharded_flat_param)
+
+
+def _post_backward_reshard(
+ state: "_ZeRO3State",
+ handle: "FlatParamHandle",
+ *unused: Any,
+) -> None:
+ free_unsharded_flat_param = not (
+ handle._pre_forward_order_index == 0 and not state._sync_gradients)
+ with torch.profiler.record_function(
+ "LayerZeRO3._post_backward_reshard"
+ ):
+ _reshard(state, handle, free_unsharded_flat_param)
+
+ with torch.profiler.record_function(
+ "LayerZeRO3._post_backward_prefetch"
+ ):
+ _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
+
+
+@no_type_check
+def _prefetch_handle(
+ state: "_ZeRO3State",
+ current_handle: Optional["FlatParamHandle"],
+ prefetch_mode: _PrefetchMode,
+) -> None:
+ """
+ Prefetches the next handles if needed (without synchronization). An empty
+ handles key cannot prefetch.
+ """
+ if not current_handle:
+ return
+ handle = _get_handle_to_prefetch(state, current_handle)
+ if not handle:
+ return
+ # Temporarily emulate the training state while calling `_unshard` to
+ # ensure the correct `as_params` for `_use_unsharded_views()`
+ prev_training_state = handle._training_state
+ if prefetch_mode == _PrefetchMode.BACKWARD:
+ handle._training_state = HandleTrainingState.BACKWARD_PRE
+ elif prefetch_mode == _PrefetchMode.FORWARD:
+ if handle.enter_backward:
+ return
+ handle._training_state = HandleTrainingState.FORWARD
+ else:
+ raise ValueError(f"Invalid prefetch mode on rank {state.zero3_rank}: {prefetch_mode}")
+ # Prefetch the next set of handles without synchronizing to allow
+ # the sync to happen as late as possible to maximize overlap
+ _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
+ handle._training_state = prev_training_state
+ handle._prefetched = True
+
+
+@no_type_check
+def _get_handle_to_prefetch(
+ state: "_ZeRO3State",
+ current_handle: "FlatParamHandle",
+) -> "FlatParamHandle":
+ """
+ Returns a :class:`list` of the handles keys to prefetch for the next
+ module(s), where ``current_handle`` represents the current module.
+
+ "Prefetching" refers to running the unshard logic early (without
+ synchronization), and the "next" modules depend on the recorded execution
+ order and the current training state.
+ """
+ training_state = _get_training_state(current_handle)
+ valid_training_states = (
+ HandleTrainingState.BACKWARD_PRE,
+ HandleTrainingState.BACKWARD_POST,
+ HandleTrainingState.FORWARD,
+ )
+ _p_assert(
+ training_state in valid_training_states,
+ f"Prefetching is only supported in {valid_training_states} but "
+ f"currently in {training_state}",
+ )
+ eod = state._exec_order_data
+ target_handle: Optional["FlatParamHandle"] = None
+ if (
+ training_state == HandleTrainingState.BACKWARD_PRE
+ and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
+ ) or (
+ training_state == HandleTrainingState.BACKWARD_POST
+ and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
+ ):
+ target_handle_candidate = eod.get_handle_to_backward_prefetch(
+ current_handle)
+ if (
+ target_handle_candidate
+ # and target_handle_candidate._needs_pre_backward_unshard
+ and not target_handle_candidate._prefetched
+ ):
+ target_handle = target_handle_candidate
+ else:
+ target_handle = None
+ elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch:
+ target_handle_candidate = eod.get_handle_to_forward_prefetch(
+ current_handle)
+ if (
+ target_handle_candidate
+ # and target_handle_candidate._needs_pre_forward_unshard
+ and not target_handle_candidate._prefetched
+ ):
+ target_handle = target_handle_candidate
+ else:
+ target_handle = None
+
+ return target_handle
+
+
+def _get_training_state(
+ handle: "FlatParamHandle",
+) -> HandleTrainingState:
+ """Returns the training state of the handles in ``handle``."""
+ _p_assert(handle, "Expects a non-empty handle")
+ return handle._training_state
+
+
+@no_type_check
+def _get_handle_to_post_backward(
+ state: "_ZeRO3State",
+ current_handle: "FlatParamHandle",
+) -> "FlatParamHandle":
+ """
+ Returns the last handle to do post_backward reduce scatter, where ``current_handle`` represents the current module.
+ """
+ eod = state._exec_order_data
+ target_handle: Optional["FlatParamHandle"] = None
+ target_handle = eod.get_handle_to_post_backward(current_handle)
+ if target_handle:
+ return [handle for handle in target_handle
+ if (_get_training_state(handle) == HandleTrainingState.BACKWARD_POST) and not handle.flat_param._post_backward_called]
+ else:
+ return None
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_utils.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..679d91d6b7a973a284dcef23239d708d683e26aa
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/_utils.py
@@ -0,0 +1,194 @@
+from typing import Any, Callable, Dict, List, no_type_check, Optional, Set, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.distributed.utils import (
+ _p_assert,
+ _apply_to_tensors
+)
+
+import mindspeed.core.distributed.layerzero.zero3._traversal_utils as traversal_utils
+from mindspeed.core.distributed.layerzero.zero3._common_utils import (
+ _assert_in_training_states,
+ _get_module_zero3_state,
+ _no_dispatch_record_stream,
+ clean_tensor_name,
+ _ZeRO3State,
+ TrainingState,
+)
+from mindspeed.core.distributed.layerzero.zero3.flat_param import (
+ FlatParameter,
+ FlatParamHandle,
+)
+
+
+def print0(msg):
+ if dist.get_rank() == 0:
+ print(msg)
+
+
+def _get_zero3_root_states_with_modules(
+ module: nn.Module,
+) -> Tuple[List[_ZeRO3State], List[nn.Module]]:
+ """
+ Returns a tuple containing:
+ 1. A list of the root ``_FSDPState`` instances in the module tree rooted at
+ ``module`` without any duplicates and following the ``module.modules()``
+ traversal order (which is assumed to be depth-first).
+ 2. A corresponding list of the root modules owning the states in the first
+ list.
+
+ This is similar to :func:`_get_zero3_states_with_modules` except that we
+ must call :func:`_is_fsdp_root` to force a lazy initialization to determine
+ the FSDP root in case lazy initialization has not yet happened.
+ """
+ zero3_root_states: List[_ZeRO3State] = []
+ zero3_root_modules: List[nn.Module] = []
+ visited_zero3_states: Set[_ZeRO3State] = set()
+ # NOTE: This function assumes that `module.modules()` proceeds top-down.
+ for submodule in module.modules():
+ optional_state = _get_module_zero3_state(submodule)
+ if (
+ optional_state is not None
+ and optional_state not in visited_zero3_states
+ and _is_zero3_root(optional_state, submodule)
+ ):
+ visited_zero3_states.add(optional_state)
+ zero3_root_states.append(optional_state)
+ zero3_root_modules.append(submodule)
+ return zero3_root_states, zero3_root_modules
+
+
+def _get_zero3_root_states(module: nn.Module) -> List[_ZeRO3State]:
+ """See :func:`_get_zero3_root_states_with_modules`."""
+ zero3_root_states, _ = _get_zero3_root_states_with_modules(module)
+ return zero3_root_states
+
+
+def _is_zero3_root(state: _ZeRO3State, module: nn.Module) -> bool:
+ """
+ Returns if ``state`` corresponds to that of an zero3 root.
+
+ For the wrapper code path, ``state`` and ``module`` should be the same. For
+ the non-wrapper code path, ``state`` should be ``module`` 's state.
+ """
+ if state._is_root is None:
+ raise ValueError(f"state is not initialized")
+ return state._is_root
+
+
+def _div_if_needed(tensor: torch.Tensor, div_factor: float) -> None:
+ if div_factor > 1:
+ tensor.div_(div_factor)
+
+
+def _wait_for_computation_stream(
+ computation_stream: torch.Stream,
+ unshard_stream: torch.Stream,
+ pre_unshard_stream: torch.Stream,
+):
+ """
+ Has the unshard and pre-unshard streams wait for the computation stream.
+ For example, this should be called in the zero3 root's pre-forward to
+ respect optimizer step computation.
+ """
+ unshard_stream.wait_stream(
+ computation_stream) # type: ignore[attr-defined]
+ # Having the pre-all-gather stream wait for the current stream even if we
+ # do not leverage the pre-all-gather stream is tolerable since this only
+ # runs once per iteration
+ # type: ignore[attr-defined]
+ pre_unshard_stream.wait_stream(computation_stream)
+
+
+@no_type_check
+def _get_buffers_and_dtypes_for_computation(
+ state: _ZeRO3State,
+ root_module: nn.Module,
+) -> Tuple[List[torch.Tensor], List[Optional[torch.dtype]]]:
+ """
+ Returns all buffers in the module tree rooted at ``root_module`` and a
+ corresponding list of the buffer dtypes for computation. Each buffer dtype
+ is either ``None`` if buffer mixed precision is not enabled or the buffer
+ low precision dtype otherwise.
+ """
+ _p_assert(state._is_root, "Expects the root to cast buffers")
+ buffers: List[torch.Tensor] = []
+ buffer_dtypes: List[Optional[torch.dtype]] = []
+ visited_buffers: Set[torch.Tensor] = set()
+ # Traverse the FSDP states bottom-up so that we prefer the owning FSDP
+ # instance's mixed precision setting for each buffer
+ zero3_states, zero3_modules = traversal_utils._get_zero3_states_with_modules(
+ root_module
+ )
+ for zero3_state, zero3_module in zip(reversed(zero3_states), reversed(zero3_modules)):
+ for buffer_name, buffer in zero3_module.named_buffers():
+ if buffer in visited_buffers:
+ continue
+ visited_buffers.add(buffer)
+ if clean_tensor_name(buffer_name) in zero3_state._ignored_buffer_names:
+ continue
+ buffers.append(buffer)
+ buffer_dtypes.append(zero3_state.mixed_precision.buffer_dtype)
+ _p_assert(len(buffers) == len(buffer_dtypes), f"{len(buffers)} {len(buffer_dtypes)}")
+ return buffers, buffer_dtypes
+
+
+def _cast_buffers_to_dtype_and_device(
+ buffers: List[torch.Tensor],
+ buffer_dtypes: List[Optional[torch.dtype]],
+ device: torch.device,
+) -> None:
+ """
+ Casts ``buffers`` to the dtypes given by ``buffer_dtypes`` and moves them
+ to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
+ corresponding buffer is only moved to ``device``.
+ """
+ _p_assert(
+ buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
+ f"Expects `buffers` and `buffer_dtypes` to have the same length if "
+ f"`buffer_dtypes` is specified but got {len(buffers)} and "
+ f"{len(buffer_dtypes)}",
+ )
+ for buffer, buffer_dtype in zip(buffers, buffer_dtypes):
+ if not torch.is_floating_point(buffer) or buffer_dtype is None:
+ buffer.data = buffer.to(device=device)
+ else:
+ buffer.data = buffer.to(device=device, dtype=buffer_dtype)
+
+
+#!===================== grad==================================================
+def _reset_flat_param_grad_info_if_needed(
+ handles: List[FlatParamHandle],
+):
+ """
+ Clears the original parameters' gradients if needed. This method's CPU
+ overhead is minimal, so we may call it throughout ZeRO3 methods, which serve
+ as callsites to free the gradient memory earlier.
+ """
+ if not isinstance(handles, list):
+ handles = [handles]
+ for handle in handles:
+ handle._reset_flat_param_grad_info_if_needed()
+
+
+def _cast_forward_outputs(
+ dtype: Optional[torch.dtype],
+ output
+) -> Tuple[Any, Any]:
+ """
+ Cast floating point tensors in ``args`` and ``kwargs`` to ``input_dtype``.
+
+ This respects the existing ``requires_grad`` on the tensors.
+ """
+ if dtype is None:
+ return output
+
+ def cast_fn(x: torch.Tensor) -> torch.Tensor:
+ if not torch.is_floating_point(x) or x.dtype == dtype:
+ return x
+ return x.to(dtype)
+
+ return _apply_to_tensors(cast_fn, output)
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/hook.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d710ce4f5ddcf5b319b12e1496777d1703f619d
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/runtime/hook.py
@@ -0,0 +1,78 @@
+import abc
+import threading
+from typing import (
+ Callable,
+ Literal,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
+
+import torch
+from torch.utils.hooks import RemovableHandle
+from torch.autograd.graph import Node
+
+
+class _MultiHandle(RemovableHandle):
+ handles: Tuple[RemovableHandle, ...]
+
+ def __init__(self, handles: Tuple[RemovableHandle, ...]) -> None:
+ self.handles = handles
+
+ def remove(self) -> None:
+ for handle in self.handles:
+ handle.remove()
+
+ def __getstate__(self) -> Tuple[RemovableHandle, ...]:
+ return self.handles
+
+ def __setstate__(self, state: Tuple[RemovableHandle, ...]) -> None:
+ self.handles = state
+
+
+def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, None]) -> Node:
+
+ if not (isinstance(t, torch.Tensor) and t.requires_grad):
+ raise ValueError(
+ f"Expects torch.Tensor with requires_grad=True, got {type(t)}")
+ if t.requires_grad and t.grad_fn is not None:
+ node = t.grad_fn
+ else:
+ with torch.enable_grad():
+ node = t.grad_fn.next_functions[0][0]
+ if node is None:
+ raise AssertionError(
+ f"No graph.Node object returned from tensor.grad_fn")
+ return node
+
+
+def register_multi_post_grad_hook(
+ tensors: Sequence[torch.Tensor],
+ fn: Union[
+ Callable[[Sequence[Optional[torch.Tensor]]], None],
+ Callable[[torch.Tensor], None],
+ ],
+) -> RemovableHandle:
+ """Note:
+ 1. This hook is only called once, so it needs to be re-registered.
+ 2. This hook is called only when all grad_fn or acc node is triggered
+ """
+ lock = threading.Lock()
+ nb_calls = 0
+ grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
+ len_tensors = len(tensors)
+
+ def get_inner_hook() -> Callable[[torch.Tensor], None]:
+ def inner_hook(*grad: torch.Tensor) -> None:
+ nonlocal len_tensors, nb_calls, fn
+ with lock:
+ nb_calls += 1
+ if len_tensors == nb_calls:
+ fn()
+ return inner_hook
+
+ handles = tuple(
+ t.register_hook(get_inner_hook()) for t in grad_fns
+ )
+ return _MultiHandle(handles)
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/__init__.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/fqn.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/fqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..090d9b9c33290aac743baad789495016a56acb4c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/fqn.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+from dataclasses import dataclass
+import torch
+
+ITERATION_KEY = "iteration"
+ARGS_KEY = "args"
+LOCAL_NAME_TO_FQN_KEY = "shard_state_dict"
+D3PARALLEL_KEY = ""
+MODEL_SD_KEY = "model"
+
+
+@dataclass
+class ShardFlattenInfo:
+ '''
+ This class is unsed for saving flatten shard parameter global info
+ and helps to convert full param into shard param
+
+ [offset, offset+numel]
+ '''
+ in_shard: bool
+ numel: int
+ intra_param_start_idx: int
+ intra_param_end_idx: int # inclusive
+ shape: torch.Size
+ tensor_model_parallel: bool
+ partition_dim : int
+ partition_stride : int
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/mga_checkpoint.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/mga_checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..200937c74d1b288baea950474d0b594d25de0ea7
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/mga_checkpoint.py
@@ -0,0 +1,293 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+import os
+import random
+import warnings
+import sys
+import numpy as np
+
+import torch
+import torch.distributed as dist
+from megatron.training.checkpointing import get_rng_state
+from megatron.training.global_vars import get_args
+from megatron.training.utils import print_rank_0
+from megatron.core import mpu, tensor_parallel
+
+from .state_dict import shard_state_dict, clean_ignored_modules, use_zero1_params
+from .optim_state import _shard_optim_state_dict
+
+PARALLE_STATE_KAY = "parallel_state"
+MODEL_KEY = "model"
+RNG_STATE_KEY = "rng_state"
+SHRAD_KEY = "shard_state_dict"
+EMA_MODEL_KEY = "ema_model"
+OPTIM_STATE_KEY = "optimizer"
+OPTIM_INFO_KEY = "optimizer_param_key_to_fqn"
+OPTIM_SCHEDULER_KEY = "opt_param_scheduler"
+LR_SCHEDULER_KEY = "lr_scheduler"
+
+
+def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far=None, checkpointing_context=None,
+ pipeline_rank=None, expert_rank=None, tensor_rank=None, pipeline_parallel=None, expert_parallel=None):
+ """Save a model checkpoint.
+
+ Checkpointing context is used to persist some checkpointing state
+ throughout a single job. Must be initialized externally (not used if None).
+ """
+ args = get_args()
+ if not hasattr(args, "save"):
+ setattr(args, "save", "ckpt")
+ print_rank_0('saving checkpoint at iteration {:7d} to {} '.format(
+ iteration, args.save))
+ rng_state = get_rng_state(False)
+ checkpoint_name = get_checkpoint_name(args.save, iteration, release=False)
+
+ # Collect args, model, RNG.
+ state_dict = generate_state_dict(args, model, optimizer, opt_param_scheduler, rng_state,
+ False, iteration)
+ state_dict[PARALLE_STATE_KAY] = generate_3D_parallel_state()
+ state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far
+
+ ensure_directory_exists(checkpoint_name)
+ print_rank_0(f"Start Saving to {checkpoint_name}!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+ torch.save(state_dict, checkpoint_name)
+
+ if dist.is_initialized():
+ dist.barrier()
+
+
+def generate_3D_parallel_state():
+ # Ensure the distributed environment is initialized
+ if not dist.is_initialized():
+ raise RuntimeError("Distributed environment is not initialized.")
+
+ # Ensure Megatron's parallel utilities are initialized
+ if not mpu.is_initialized():
+ raise RuntimeError(
+ "Megatron's parallel utilities are not initialized.")
+
+ # Get global rank
+ global_rank = dist.get_rank()
+ # Get tensor parallel rank
+ tp_rank = mpu.get_tensor_model_parallel_rank()
+ # Get pipeline parallel rank
+ pp_rank = mpu.get_pipeline_model_parallel_rank()
+ # Get data parallel rank
+ dp_rank = mpu.get_data_parallel_rank()
+ # Get tensor parallel degree
+ tp_degree = mpu.get_tensor_model_parallel_world_size()
+ # Get pipeline parallel degree
+ pp_degree = mpu.get_pipeline_model_parallel_world_size()
+ # Get data parallel degree
+ dp_degree = mpu.get_data_parallel_world_size()
+
+ # Assemble the dictionary
+ parallel_state = {
+ 'tp_rank': tp_rank,
+ 'pp_rank': pp_rank,
+ 'dp_rank': dp_rank,
+ 'tp_degree': tp_degree,
+ 'pp_degree': pp_degree,
+ 'dp_degree': dp_degree,
+ 'global_rank': global_rank
+ }
+
+ return parallel_state
+
+
+def generate_state_dict(args, model, optimizer, opt_param_scheduler,
+ rng_state, use_dist_ckpt=False, iteration=None,
+ optim_sd_kwargs=None):
+ # Arguments, iteration, and model.
+ state_dict = {}
+ state_dict['args'] = args
+ state_dict['checkpoint_version'] = 3.0
+ if iteration is not None:
+ state_dict['iteration'] = iteration
+
+ if not len(model) == 1:
+ raise ValueError(f"Only single model is supported, VPP not supported")
+ use_zero1_params(model[0])
+ state_dict[MODEL_KEY] = clean_ignored_modules(
+ model[0], model[0].state_dict())
+ state_dict[SHRAD_KEY] = shard_state_dict(model[0], state_dict[MODEL_KEY])
+
+ # Optimizer stuff.
+ if not args.no_save_optim:
+ if optimizer is not None:
+ state_dict[OPTIM_STATE_KEY] = optimizer.state_dict()
+ state_dict[OPTIM_INFO_KEY] = _shard_optim_state_dict(
+ model[0], optimizer.optimizer, state_dict[OPTIM_STATE_KEY])
+ if getattr(args, "optimizer_selection", None) == 'fused_ema_adamw':
+ try:
+ ema_optimizer_applier(optimizer)
+ state_dict[EMA_MODEL_KEY] = clean_ignored_modules(
+ model[0], model[0].state_dict())
+ state_dict = ema_state_dict_to_cpu(
+ state_dict, EMA_MODEL_KEY)
+ ema_optimizer_restore(optimizer)
+ print_rank_0("Ema model successful saved in state_dict")
+ except KeyError:
+ warnings.warn(
+ f"ema_optimizer_applier failed with KeyError, ema_model not saved")
+ if opt_param_scheduler is not None:
+ state_dict[OPTIM_SCHEDULER_KEY] = \
+ opt_param_scheduler.state_dict()
+ # RNG states.
+ if not args.no_save_rng:
+ state_dict[RNG_STATE_KEY] = rng_state
+ return state_dict
+
+
+def get_checkpoint_name(checkpoints_path, iteration, release=False):
+ """Determine the directory name for this rank's checkpoint."""
+ if checkpoints_path is None:
+ raise ValueError("checkpoints_path cannot be None")
+ if release:
+ directory = 'release'
+ else:
+ directory = 'iter_{:07d}'.format(iteration)
+ common_path = os.path.join(checkpoints_path, directory)
+ global_rank = dist.get_rank()
+ return os.path.join(common_path, f"model_{global_rank}.pt")
+
+
+def ensure_directory_exists(filename, check_parent=True):
+ """Build filename's path if it does not already exists."""
+ if filename is None:
+ raise AssertionError(f"Got {filename} filename")
+ dirname = os.path.dirname(filename) if check_parent else filename
+ os.makedirs(dirname, exist_ok=True)
+
+
+def load_layerzero_checkpoint(models, ckpt_dir, optimizer=None, opt_param_scheduler=None):
+ if ckpt_dir is None:
+ raise AssertionError(f"Got {ckpt_dir} filename")
+ if len(models) != 1:
+ raise ValueError(f"VPP is not supported by layerzero currently")
+ rank = dist.get_rank()
+ sd_file = os.path.join(ckpt_dir, f"model_{rank}.pt")
+ if not os.path.exists(sd_file):
+ raise FileNotFoundError(
+ f"No checkpoint found in load directory or pretrained directory: no such file {sd_file}")
+ args = get_args()
+ state_dict = torch.load(sd_file)
+ for i in range(len(models)):
+ models[i].load_state_dict(state_dict[MODEL_KEY], strict=False)
+ if not args.finetune and not args.no_load_optim:
+ try:
+ # Load state dict.
+ if optimizer is not None:
+ optimizer.load_state_dict(state_dict[OPTIM_STATE_KEY])
+ if opt_param_scheduler is not None:
+ if LR_SCHEDULER_KEY in state_dict: # backward compatbility
+ opt_param_scheduler.load_state_dict(
+ state_dict[LR_SCHEDULER_KEY])
+ else:
+ opt_param_scheduler.load_state_dict(
+ state_dict[OPTIM_SCHEDULER_KEY])
+ except KeyError as e:
+ raise RuntimeError('Unable to load optimizer from checkpoint {}. '
+ 'Specify --no-load-optim or --finetune to prevent '
+ 'attempting to load the optimizer state, '
+ 'exiting ...'.format(ckpt_dir)) from e
+ args.num_floating_point_operations_so_far = state_dict.get(
+ 'num_floating_point_operations_so_far', 0)
+ if args.finetune:
+ iteration = 0
+ else:
+ try:
+ iteration = state_dict['iteration']
+ except KeyError:
+ iteration = 0
+ args.iteration = iteration
+
+ # Check arguments.
+ update_consumed_samples(args, state_dict)
+ # rng states.
+ resume_rng_states(args, state_dict)
+
+ # Some utilities want to load a checkpoint without distributed being initialized
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier()
+
+ print_rank_0(f' successfully loaded checkpoint from {ckpt_dir} '
+ f'[ t {mpu.get_tensor_model_parallel_rank()}, '
+ f'p {mpu.get_pipeline_model_parallel_rank()} ] '
+ f'at iteration {iteration}')
+ return args.iteration, args.num_floating_point_operations_so_far
+
+
+def update_consumed_samples(args, state_dict):
+ if 'args' in state_dict and not args.finetune:
+ checkpoint_args = state_dict['args']
+ args.consumed_train_samples = getattr(checkpoint_args,
+ 'consumed_train_samples', 0)
+ try:
+ from megatron.core.num_microbatches_calculator import update_num_microbatches
+ update_num_microbatches(
+ consumed_samples=args.consumed_train_samples)
+ except ImportError:
+ pass
+ args.consumed_valid_samples = getattr(checkpoint_args,
+ 'consumed_valid_samples', 0)
+ else:
+ print_rank_0('could not find arguments in the checkpoint ...')
+
+
+def resume_rng_states(args, state_dict):
+ if not args.finetune and not args.no_load_rng:
+ try:
+ if RNG_STATE_KEY in state_dict:
+ # access rng_state for data parallel rank
+ if args.data_parallel_random_init:
+ rng_state = state_dict[RNG_STATE_KEY][mpu.get_data_parallel_rank(
+ )]
+ else:
+ rng_state = state_dict[RNG_STATE_KEY][0]
+ random.setstate(rng_state['random_rng_state'])
+ np.random.set_state(rng_state['np_rng_state'])
+ torch.set_rng_state(rng_state['torch_rng_state'])
+ torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
+ # Check for empty states array
+ if not rng_state['rng_tracker_states']:
+ raise KeyError
+ tensor_parallel.get_cuda_rng_tracker().set_states(
+ rng_state['rng_tracker_states'])
+ else: # backward compatability
+ random.setstate(state_dict['random_rng_state'])
+ np.random.set_state(state_dict['np_rng_state'])
+ torch.set_rng_state(state_dict['torch_rng_state'])
+ torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+ # Check for empty states array
+ if not state_dict['rng_tracker_states']:
+ raise KeyError
+ tensor_parallel.get_cuda_rng_tracker().set_states(
+ state_dict['rng_tracker_states'])
+ except KeyError as e:
+ raise RuntimeError('Unable to load rng state from checkpoint '
+ 'Specify --no-load-rng or --finetune to prevent '
+ 'attempting to load the rng state, '
+ 'exiting ...') from e
+
+
+def ema_optimizer_applier(optimizer):
+ if hasattr(optimizer, "optimizer"):
+ optimizer.optimizer.store(optimizer.optimizer.param_groups)
+ optimizer.optimizer.copy_to()
+ return
+
+
+def ema_optimizer_restore(optimizer):
+ if hasattr(optimizer, "optimizer"):
+ optimizer.optimizer.restore(optimizer.optimizer.param_groups)
+ return
+
+
+def ema_state_dict_to_cpu(state_dict, ema_key):
+ for k, v in state_dict[ema_key].items():
+ if not torch.is_tensor(v):
+ continue
+ new_v = v.detach().cpu().clone()
+ state_dict[ema_key][k] = new_v
+ return state_dict
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/optim_state.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/optim_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..859551d444248065632b1bef3df954b3458b91fc
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/optim_state.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+import warnings
+from typing import Dict, List, Optional, Iterable, Union, Any
+
+import torch
+import torch.nn as nn
+
+from ..zero3._common_utils import (
+ clean_tensor_name,
+ _named_parameters_with_duplicates
+)
+
+
+@torch.no_grad()
+def _shard_optim_state_dict(
+ model: nn.Module,
+ optim: torch.optim.Optimizer,
+ optim_state_dict: Dict[str, Any],
+) -> Dict[str, Any]:
+ """
+
+ Args:
+ model (nn.Module): Root module (which may or may not be a
+ :class:`FullyShardedDataParallel` instance) whose parameters
+ were passed into the optimizer ``optim``.
+ optim (torch.optim.Optimizer): Optimizer for ``model`` 's
+ parameters.
+ rank0_only (bool): If ``True``, saves the populated :class:`dict`
+ only on rank 0; if ``False``, saves it on all ranks. (Default:
+ ``True``)
+ shard_state (bool): If ``True``, shard and distribute all
+ non-zero-dimension states.
+
+ Returns:
+ Dict[str, Any]: A :class:`dict` containing the optimizer state that is sharded: FQN - > state_dict.
+ """
+ param_to_fqns = _get_param_to_fqns(model)
+ is_named_optimizer = _is_named_optimizer(optim_state_dict)
+
+ param_key_to_param = _get_param_key_to_param(
+ optim, model, is_named_optimizer, param_to_fqns
+ )
+ param_key_to_fqns, missing_keys = _get_param_key_to_fqns(
+ param_to_fqns, param_key_to_param)
+ if missing_keys:
+ warnings.warn(
+ f"Missing keys that do not have FQN mappings {missing_keys}")
+ return param_key_to_fqns
+
+
+def _get_param_key_to_fqns(param_to_fqns, param_key_to_param):
+ param_key_to_fqns = {}
+ missing_keys = set()
+ for param_key, param in param_key_to_param.items():
+ if param in param_to_fqns:
+ param_key_to_fqns[param_key] = param_to_fqns[param]
+ else:
+ missing_keys.add(param_key)
+ return param_key_to_fqns, missing_keys
+
+
+def _get_param_to_fqns(
+ model: torch.nn.Module,
+ dedup_shared_params: bool = True,
+) -> Dict[nn.Parameter, List[str]]:
+ """
+ Constructs a mapping from parameter to a list of its \"canonical\" FQNs. Here,
+ we use canonical to mean the fully-qualified name assigned to the parameter
+ based on its position in the original nn.Module hierarchy before any wrapper
+ or parallelism has been applied to it. This is in contrast to FQNs that may be
+ generated after parallelisms or wrappers have been applied to the model.
+
+ Each normal parameter maps to a singleton list containing its FQN, while each
+ ``FlatParameter`` maps to a list of its original parameter FQNs, which may
+ have length greater than one. All FQNs are prefixed starting from ``model``.
+ """
+ param_to_fqns = {}
+ for param_name, param in _named_parameters_with_duplicates(
+ model
+ ):
+ local_fqns = [param_name]
+ global_fqns = [
+ clean_tensor_name(name) for name in local_fqns
+ ] # prefixed from the top level `model` (i.e. including `prefix`)
+ is_shared_param = param in param_to_fqns
+ if not is_shared_param:
+ param_to_fqns[param] = global_fqns
+ elif not dedup_shared_params:
+ param_to_fqns[param].extend(global_fqns)
+
+ return param_to_fqns
+
+
+def _is_named_optimizer(optim_state_dict: Dict[str, Any]) -> bool:
+ """
+ Returns whether the state_dict is from a NamedOptimizer.
+ This function checks that the keys in the state_dict['state'] are strings
+ (which usually are FQNs) versus integers (which usually refer to param_ids
+ from a vanilla torch.optim.Optimizer).
+ """
+ state = optim_state_dict.get("state", None)
+ if not state:
+ # If we cannot find a state, assume it is not NamedOptimizer as
+ # NamedOptimizer has eager initialization.
+ return False
+ try:
+ key = next(iter(state.keys()))
+ except Exception as e:
+ raise Exception(optim_state_dict) from e # noqa: TRY002
+ return isinstance(key, str)
+
+
+def _get_param_key_to_param(
+ optim: torch.optim.Optimizer,
+ model: Optional[nn.Module] = None,
+ is_named_optimizer: bool = False,
+ param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None,
+) -> Dict[Union[int, str], nn.Parameter]:
+ """
+ Constructs a mapping from parameter keys to parameters. For the regular
+ optimizers, the keys are parameter IDs. For NamedOptimizer, the keys
+ are FQNs. This API may be used both for models with ``FlatParameter`` s and
+ without.
+ """
+ clean_fqn_to_fsdp_fqn: Dict[str, str] = {}
+ if is_named_optimizer:
+ if param_to_fqns is None or model is None:
+ raise AssertionError("The optimizer is a NamedOptimizer, `param_to_fqns` must not be None.")
+ for key, _ in _named_parameters_with_duplicates(model):
+ clean_fqn_to_fsdp_fqn[clean_tensor_name(key)] = key
+
+ param_key_to_param: Dict[Union[str, int], nn.Parameter] = {}
+ pid = 0
+ for param_group in optim.param_groups:
+ if is_named_optimizer:
+ for param in param_group["params"]:
+ # use_orig_params case
+ if len(param_to_fqns[param]) != 1:
+ raise AssertionError("More than one fqn matches this param")
+ key = param_to_fqns[param][0]
+ try:
+ key = clean_fqn_to_fsdp_fqn[key]
+ except KeyError as e:
+ raise KeyError(
+ f"Can't find {key} from {list(clean_fqn_to_fsdp_fqn.keys())}."
+ ) from e
+ param_key_to_param[key] = param
+ else:
+ for param in param_group["params"]:
+ param_key_to_param[pid] = param
+ pid += 1
+
+ return param_key_to_param
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/scripts/__init__.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/scripts/convert_to_megatron.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/scripts/convert_to_megatron.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f976e25f0a2c8d6165a162c70def7d7eaef1039
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/scripts/convert_to_megatron.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+import argparse
+import os
+from collections import OrderedDict
+
+import torch
+import mindspeed.megatron_adaptor
+from mindspeed.core.distributed.layerzero.state.scripts import layerzero_checkpointer
+from mindspeed.core.distributed.layerzero.state.scripts.layerzero_checkpointer import LayerzeroCheckpoint
+ARGS_KEY = 'args'
+
+FINAL_LAYER_NORM_KEY = 'final_layernorm'
+CHECKPOINT_VERSION_KEY = 'checkpoint_version'
+CHECKPOINT_VERSION_VALUE = 3.0
+ITERATION_KEY = 'iteration'
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input_folder', default=None,
+ type=str, help='Input DeepSpeed Checkpoint folder')
+ parser.add_argument('--output_folder', default=None,
+ type=str, help='Output Megatron checkpoint folder')
+ parser.add_argument('--prefix', default="predictor",
+ help='Model prefix used in Layerzero')
+ parser.add_argument('--target_tp', default=1,
+ type=int, help='Target TP degree')
+ parser.add_argument('--target_pp', default=1,
+ type=int, help='Target PP degree')
+ parser.add_argument('--for_release', action='store_true',
+ help='Convert for release purpose, reset some (progress) counters.')
+ parser.add_argument('--ema_model', action='store_true',
+ help='Convert Ema models')
+ args = parser.parse_args()
+ print(f'args = {args}')
+ return args
+
+
+def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
+ path_list = []
+ iter_folder = f'iter_{iteration:07d}'
+ for i in range(0, tp_degree):
+ path_list.append([])
+ for j in range(0, pp_degree):
+ rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}'
+ ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt')
+ path_list[i].append(os.path.join(
+ base_folder, iter_folder, ckpt_path))
+
+ return path_list
+
+
+def _save_checkpoint(file_path, chkpt_sd):
+ ckpt_dir, _ = os.path.split(file_path)
+ os.makedirs(ckpt_dir, exist_ok=True)
+ torch.save(chkpt_sd, file_path)
+
+
+def _create_rank_checkpoint(zero_checkpoint, tp_index, pp_index, tp_degree, pp_degree, for_release=False):
+ checkpoint_sd = OrderedDict()
+ checkpoint_sd[layerzero_checkpointer.MODEL_SD_KEY] = zero_checkpoint.create_rank_checkpoint(
+ tp_index, pp_index, tp_degree, pp_degree)
+ iteration = zero_checkpoint.get_iteration()
+ checkpoint_sd[ITERATION_KEY] = iteration
+ checkpoint_sd[ARGS_KEY] = zero_checkpoint.get_args()
+ # Adjust specific fields
+ checkpoint_sd[ARGS_KEY].tensor_model_parallel_size = tp_degree
+ checkpoint_sd[ARGS_KEY].pipeline_model_parallel_size = pp_degree
+ if for_release:
+ checkpoint_sd[ARGS_KEY].consumed_train_samples = 0
+ checkpoint_sd[ARGS_KEY].consumed_valid_samples = 0
+ checkpoint_sd[CHECKPOINT_VERSION_KEY] = CHECKPOINT_VERSION_VALUE
+ return checkpoint_sd
+
+
+def _create_latest_file(base_folder, iteration):
+ file_path = os.path.join(base_folder, 'latest_checkpointed_iteration.txt')
+ os.makedirs(base_folder, exist_ok=True)
+ with open(file_path, 'w') as f:
+ f.write(str(iteration))
+
+
+def main():
+ print(f'Convert Layerzero dist Checkpoint to a SINGLE Megatron Checkpoint')
+
+ args = parse_arguments()
+ print(f'Converting Layerzero checkpoint in {args.input_folder} to Megatron checkpoint in {args.output_folder}')
+ if args.ema_model:
+ from mindspeed.core.distributed.layerzero.state.scripts.layerzero_checkpointer import set_ema_model
+ set_ema_model()
+ if args.prefix is not None:
+ from mindspeed.core.distributed.layerzero.state.scripts.layerzero_checkpointer import remove_model_prefix
+ remove_model_prefix(args.prefix)
+
+ lz_checkpoint = LayerzeroCheckpoint(args.input_folder)
+ iteration = lz_checkpoint.get_iteration()
+ _create_latest_file(args.output_folder, iteration)
+ checkpoint_paths = _create_checkpoint_paths(
+ args.output_folder, iteration, args.target_tp, args.target_pp)
+ for i in range(0, args.target_tp):
+ for j in range(0, args.target_pp):
+ sd = _create_rank_checkpoint(
+ lz_checkpoint, i, j, args.target_tp, args.target_pp, args.for_release)
+ _save_checkpoint(checkpoint_paths[i][j], sd)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/scripts/layerzero_checkpointer.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/scripts/layerzero_checkpointer.py
new file mode 100644
index 0000000000000000000000000000000000000000..28ecc79100c8058b282127a26b0dd66b5abbad67
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/scripts/layerzero_checkpointer.py
@@ -0,0 +1,404 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+import os
+import re
+from typing import Dict, List, Tuple, Any
+from dataclasses import dataclass
+from collections import OrderedDict, defaultdict
+
+import torch
+from mindspeed.core.distributed.layerzero.zero3._common_utils import (
+ clean_tensor_name,
+)
+
+ITERATION_KEY = "iteration"
+ARGS_KEY = "args"
+LOCAL_NAME_TO_FQN_KEY = "shard_state_dict"
+PARALLE_STATE_KAY = "parallel_state"
+MODEL_SD_KEY = "model"
+PP_LAYER_PATTERN = re.compile(r"(layers\.)(\d+)(\..*)")
+
+MODEL_FILE_KEY = "model_"
+NUM_LAYERS_KEY = "num_layers"
+PP_LAYERS_KEY = "layers_per_pp"
+
+EMA_MODEL_SD_KEY = "ema_model"
+MODEL_PREFIX = None
+
+
+def remove_model_prefix(prefix):
+ print(f"[debug] set model prefix =", prefix)
+ global MODEL_PREFIX
+ if prefix:
+ MODEL_PREFIX = prefix + '.'
+
+
+def clean_prefix(fqn, prefix):
+ if prefix:
+ fqn = fqn.replace(prefix, "")
+ return fqn
+
+
+def set_ema_model():
+ global MODEL_SD_KEY
+ global EMA_MODEL_SD_KEY
+ MODEL_SD_KEY = EMA_MODEL_SD_KEY
+
+
+class ShardStateDict:
+
+ def __init__(self, filename) -> None:
+ self.filename = filename
+ self._init_metadata()
+
+ def _init_metadata(self):
+ state_dict = torch.load(self.filename, map_location='cpu')
+
+ self.parallel_info = state_dict[PARALLE_STATE_KAY]
+ self._param_key_to_shard_info = state_dict[LOCAL_NAME_TO_FQN_KEY]
+ self.model_state_dict = state_dict[MODEL_SD_KEY]
+
+ self.tp_rank = self.parallel_info["tp_rank"]
+ self.pp_rank = self.parallel_info["pp_rank"]
+ self.global_rank = self.parallel_info["global_rank"]
+ self.tp_degree = self.parallel_info["tp_degree"]
+ self.pp_degree = self.parallel_info["pp_degree"]
+ self.dp_degree = self.parallel_info["dp_degree"]
+
+ def _get_param_by_param_key(self, param_key) -> torch.Tensor:
+ param = self.model_state_dict.get(param_key, None)
+ return param
+
+ def _get_shape_by_param_key(self, key: str) -> torch.Tensor:
+ shard_info = self._get_shard_info_by_fqn(key)
+ return shard_info.shape
+
+ def _get_tp_pp_rank(self) -> Tuple[int, int]:
+ return (self.tp_rank, self.pp_rank)
+
+ def __lt__(self, rhs):
+ return self.global_rank < rhs.global_rank
+
+ def __len__(self):
+ return len(self.model_state_dict)
+
+ def _get_shard_info_by_fqn(self, key: str):
+ shard_info = self._param_key_to_shard_info.get(key, None)
+ return shard_info
+
+
+class LayerzeroCheckpoint(object):
+ def __init__(self, ckpt_dir):
+ self.ckpt_dir = ckpt_dir
+ self.file_list = self._get_files_by_key(ckpt_dir, MODEL_FILE_KEY)
+ self.global_state = {}
+ self._build_global_state()
+ self.state_dicts = [ShardStateDict(f) for f in self.file_list]
+ self.pp_degree = self.state_dicts[0].pp_degree
+ self.tp_degree = self.state_dicts[0].tp_degree
+ self.layer_state_dicts = [{} for _ in range(self.num_layers)]
+ self.pre_process_sd = {}
+ self.post_process_sd = {}
+ self.other_sd = {}
+ self._sanity_check()
+ self.convert_to_full_state_dict()
+
+ def _sanity_check(self):
+ pass
+
+ def _build_global_state(self):
+ sd = torch.load(self.file_list[0], map_location=torch.device('cpu'))
+ self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
+ self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)
+ args = self.get_args()
+ self.global_state[NUM_LAYERS_KEY] = args.num_layers
+ self.global_state[PP_LAYERS_KEY] = args.num_layers // args.pipeline_model_parallel_size
+
+ @property
+ def pp_layers_per_rank(self):
+ return self.global_state[PP_LAYERS_KEY]
+
+ @property
+ def num_layers(self):
+ return self.global_state[NUM_LAYERS_KEY]
+
+ def get_iteration(self):
+ if ITERATION_KEY not in self.global_state:
+ sd = torch.load(
+ self.mp_rank_files[0], map_location=torch.device('cpu'))
+ self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
+
+ return self.global_state[ITERATION_KEY]
+
+ def get_args(self):
+ if ARGS_KEY not in self.global_state:
+ sd = torch.load(
+ self.mp_rank_files[0], map_location=torch.device('cpu'))
+ self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)
+
+ return self.global_state[ARGS_KEY]
+
+ def _get_files_by_key(self, ckpt_dir, key):
+ file_list = []
+ for root, dirs, files in os.walk(ckpt_dir):
+ for file in files:
+ if file.startswith(key):
+ file_list.append(os.path.join(root, file))
+ return file_list
+
+ def convert_to_full_state_dict(self) -> Dict[str, Any]:
+ state_dicts: List[ShardStateDict] = self.state_dicts
+ same_pp_groups = _get_same_pp_ranks(state_dicts)
+ for pp_rank, pp_groups in same_pp_groups.items():
+ self.build_layer_state_dict(pp_rank, pp_groups)
+ return
+
+ def build_layer_state_dict(self, pp_rank: int, state_dicts: List[ShardStateDict]) -> Dict:
+ '''
+ This function converts dist layerzero state_dict file for each pp model
+
+ Input: sorted state_dict based on global rank and belongs to same pp stage
+
+ output: A single full_state_dict for this pp stage. (TP=1)
+ '''
+ tp_zero_index = get_TP_unshard_idx_same_pp(state_dicts)
+ non_zero_keys = set()
+ for key, param in state_dicts[0].model_state_dict.items():
+ fqn = clean_tensor_name(key)
+ shard_info = state_dicts[0]._get_shard_info_by_fqn(fqn)
+ if shard_info is None:
+ full_tensor = param
+ non_zero_keys.add(fqn)
+ else:
+ shape = shard_info.shape
+ tensor_model_parallel = shard_info.tensor_model_parallel
+ partition_dim = shard_info.partition_dim
+
+ shard_lists = _get_shard_list_by_param_key(state_dicts, key)
+ if self.tp_degree > 1 and tensor_model_parallel:
+ full_tensor = zero_tp_to_full_tensor(
+ shard_lists, tp_zero_index, shape, partition_dim, self.tp_degree)
+ else:
+ full_tensor = zero_to_full_tensor(shard_lists, shape)
+ layer_num = _get_layer_num(fqn)
+ if layer_num is not None:
+ global_layer_num = self.local_to_global_layer_num(
+ layer_num, pp_rank)
+ self.layer_state_dicts[global_layer_num][key] = full_tensor
+ else:
+ if pp_rank == 0:
+ self.pre_process_sd[fqn] = full_tensor
+ if pp_rank == self.pp_degree - 1:
+ self.post_process_sd[fqn] = full_tensor
+ if not (pp_rank == 0) or (pp_rank == self.pp_degree - 1):
+ self.other_sd[fqn] = full_tensor
+ print(f"{non_zero_keys=}")
+ return
+
+ def local_to_global_layer_num(self, layer_num: int, pp_rank: int):
+ return layer_num + pp_rank * self.pp_layers_per_rank
+
+ def create_rank_checkpoint(self, tp_index: int, pp_index: int, tp_degree: int, pp_degree: int) -> Dict[str, torch.Tensor]:
+ '''
+ 为指定的 tp_index 和 pp_index 生成对应的状态字典,并根据 tp_degree 对张量进行分片。
+
+ Args:
+ tp_index (int): 目标 TP 阶段的索引。
+ pp_index (int): 目标 PP 阶段的索引。
+ tp_degree (int): TP 的总阶段数。
+ pp_degree (int): PP 的总阶段数。
+
+ Returns:
+ Dict[str, torch.Tensor]: 目标 TP 和 PP 阶段的状态字典。
+ '''
+ # 获取目标 PP 阶段的状态字典
+ state_dict = self.get_layer_state_dict(pp_index, pp_degree)
+ # 对状态字典中的张量进行 TP 分片
+ rank_state_dict = {}
+ for fqn, tensor in state_dict.items():
+ shard_info = self.state_dicts[0]._get_shard_info_by_fqn(fqn)
+
+ if MODEL_PREFIX:
+ fqn = clean_prefix(fqn, MODEL_PREFIX)
+
+ if shard_info is not None and shard_info.tensor_model_parallel:
+ # 如果张量是 TP 分片的,则根据 tp_index 和 tp_degree 进行分片
+ partition_dim = shard_info.partition_dim
+ stride = shard_info.partition_stride
+ rank_state_dict[fqn] = shard_tensor(
+ tensor, tp_degree, tp_index, partition_dim, stride)
+ else:
+ # 如果张量不是 TP 分片的,则直接使用原张量
+ rank_state_dict[fqn] = tensor
+ return rank_state_dict
+
+ def get_layer_state_dict(self, pp_index: int, pp_degree: int) -> Dict[str, torch.Tensor]:
+ '''
+ 获取指定 pp_index 的状态字典,包括预处理、后处理以及该 pp_index 对应的层状态字典。
+
+ Args:
+ pp_index (int): 目标 PP 阶段的索引。
+ pp_degree (int): PP 的总阶段数。
+
+ Returns:
+ Dict[str, torch.Tensor]: 目标 PP 阶段的状态字典。
+ '''
+ state_dict = {}
+
+ # 添加预处理部分(仅在 pp_index == 0 时)
+ if pp_index == 0:
+ state_dict.update(self.pre_process_sd)
+
+ # 添加后处理部分(仅在 pp_index == pp_degree - 1 时)
+ if pp_index == pp_degree - 1:
+ state_dict.update(self.post_process_sd)
+ state_dict.update(self.other_sd)
+ pp_layers_per_rank = self.pp_layers_per_rank
+ # 添加该 PP 阶段对应的层状态字典
+ start_layer = pp_index * pp_layers_per_rank
+ end_layer = start_layer + pp_layers_per_rank
+
+ for layer_idx, layer_state_dict in enumerate(self.layer_state_dicts[start_layer:end_layer]):
+ layer_state_dict = _rename_layer_sd_key(
+ layer_state_dict, layer_idx)
+ state_dict.update(layer_state_dict)
+
+ return state_dict
+
+
+def _get_layer_num(key: str) -> int:
+ match = PP_LAYER_PATTERN.match(key)
+
+ if match:
+ # 提取前缀、层号和后缀
+ prefix, layer_num, suffix = match.groups()
+ # 构建新的键
+ return int(layer_num)
+ else:
+ return None
+
+
+def _rename_layer_sd_key(layer_state_dict: Dict, layer_idx: int):
+ state_dict = {}
+ for key, value in layer_state_dict.items():
+ state_dict[_rename_layer_key(key, layer_idx)] = value
+ return state_dict
+
+
+def _rename_layer_key(old_key: str, idx: int) -> str:
+ """Generate new key based for pp stage, old_key -> new_key
+
+ Args:
+ old_key (str): layers.{i}.name
+ idx (int): num_layers_idx new
+
+ Returns:
+ str: layers.{idx}.name
+ """
+ match = PP_LAYER_PATTERN.match(old_key)
+
+ if match:
+ # 提取前缀、层号和后缀
+ prefix, layer_num, suffix = match.groups()
+ # 构建新的键
+ new_key = f"{prefix}{idx}{suffix}"
+ return new_key
+ else:
+ return old_key
+
+
+def _get_shard_list_by_param_key(state_dicts, key):
+ '''
+ Return the sharded paramter that belongs to same param key!!!
+
+ Be aware of TP condition, the parameter is shard by TP then by ZeRO3
+ '''
+ if not state_dicts:
+ return []
+ resutls = [sd._get_param_by_param_key(key) for sd in state_dicts]
+ return resutls
+
+
+def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
+ setattr(tensor, 'tensor_model_parallel', is_parallel)
+ setattr(tensor, 'partition_dim', dim)
+ setattr(tensor, 'partition_stride', stride)
+
+
+def shard_tensor(full_tensor: torch.tensor,
+ tp_degree: int,
+ tp_rank: int,
+ partition_dim: int, stride: int = 1
+ ) -> List[torch.tensor]:
+ shards = torch.chunk(full_tensor, tp_degree, dim=partition_dim)
+ set_tensor_model_parallel_attributes(
+ shards[tp_rank], is_parallel=True, dim=partition_dim, stride=stride)
+ return shards[tp_rank]
+
+
+def zero_to_full_tensor(shards, global_shape):
+ if not isinstance(global_shape, torch.Size):
+ raise TypeError(f"Expect Type torch.Size, got {type(global_shape)}")
+ if not all(len(param.shape) <= 1 for param in shards):
+ raise AssertionError(f"Expect all zero param to be 1D, Got non Flat param")
+ return torch.cat(shards).reshape(global_shape)
+
+
+def tp_full_shape(shape: torch.Size, partition_dim: int, tp_degree: int):
+ if len(shape) <= partition_dim:
+ raise AssertionError(f"{partition_dim} greater or equal to shape len {len(shape)}")
+ shape_list = list(shape)
+ # 修改指定维度的大小
+ shape_list[partition_dim] *= tp_degree
+ return torch.Size(shape_list)
+
+
+def zero_tp_to_full_tensor(shards: List[torch.tensor],
+ tp_zero_index: List[int],
+ shape: torch.Size,
+ partition_dim: int,
+ tp_degree: int):
+ if tp_degree > 1:
+ if len(shards) != len(tp_zero_index):
+ raise AssertionError(f"Not enough zero params for {tp_degree=}")
+ full_shape = tp_full_shape(shape, partition_dim, tp_degree)
+ shards = [shards[i] for i in tp_zero_index]
+ else:
+ full_shape = shape
+ return zero_to_full_tensor(shards, full_shape)
+
+
+def _get_same_pp_ranks(shard_dict_list: List[ShardStateDict]) -> Dict[int, List[ShardStateDict]]:
+ results = defaultdict(list)
+ for shard_dict in shard_dict_list:
+ pp_rank = shard_dict.pp_rank
+ results[pp_rank].append(shard_dict)
+
+ # 对每组进行 sanity check 和排序
+ for pp_rank, group in results.items():
+ # 检查所有状态字典是否具有相同的模型键
+ model_keys = [set(sd.model_state_dict.keys()) for sd in group]
+ if not all(keys == model_keys[0] for keys in model_keys):
+ raise ValueError(
+ f"All state dicts in PP rank {pp_rank} must have the same model keys. "
+ f"Found mismatched keys: {model_keys}"
+ )
+ # 按全局rank排序排序
+ sort_shard_dict_by_global_rank(group)
+ return results
+
+
+def sort_shard_dict_by_global_rank(shard_list: List[ShardStateDict]) -> None:
+ shard_list.sort()
+
+
+def get_TP_unshard_idx_same_pp(state_dicts: List[ShardStateDict]) -> List[int]:
+ pp_ranks = set(sd.pp_rank for sd in state_dicts)
+ if len(pp_ranks) != 1:
+ raise AssertionError("Got more than 1 pp rank")
+
+ tp_global_index = [(idx, sd.tp_rank, sd.global_rank)
+ for idx, sd in enumerate(state_dicts)]
+ sorted_list = sorted(tp_global_index, key=lambda x: (x[1], x[2]))
+ sorted_index = [x[0] for x in sorted_list]
+ return sorted_index
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/state_dict.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/state_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..3663a32b6baa62c914597e3ab06544e1dedb1fa6
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/state/state_dict.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+from collections import OrderedDict
+from typing import Dict, List
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from megatron.training.utils import print_rank_0
+
+from .fqn import ShardFlattenInfo
+from ..zero3.fsdp import LayerZeRO3
+from ..zero3._common_utils import (
+ clean_tensor_name,
+ _apply_to_modules,
+)
+from ..zero3._init_utils import _get_ignored_params
+from ..runtime._initialize import _lazy_init
+
+TP_SHARD_ARGS = "tensor_model_parallel"
+
+
+def clean_state_dict(state_dict: Dict):
+ sd = OrderedDict()
+ for key, param in state_dict.items():
+ fqn = clean_tensor_name(key)
+ sd[fqn] = param
+ return sd
+
+
+def use_zero1_params(zero3_model: LayerZeRO3):
+ if zero3_model._is_root is None:
+ _lazy_init(zero3_model, zero3_model)
+ for handle in zero3_model._all_handles:
+ if handle:
+ already_resharded = handle.flat_param.data_ptr(
+ ) == handle.flat_param._zero1_shard.data_ptr()
+ if already_resharded:
+ handle._use_sharded_views()
+ return
+ else:
+ with zero3_model._device_handle.stream(zero3_model._default_stream):
+ event = zero3_model._device_handle.Event()
+ event.record()
+ event.wait()
+ handle.reshard(True)
+ handle._prefetched = False
+ handle._use_sharded_views()
+ return
+
+
+def clean_ignored_modules(zero3_model: LayerZeRO3, state_dict):
+ if zero3_model._is_root is None:
+ _lazy_init(zero3_model, zero3_model)
+ ignored_params = _get_ignored_params(
+ zero3_model, zero3_model._ignored_modules, zero3_model._ignored_params)
+ ignored_keys = set()
+ for key, param in zero3_model.named_parameters():
+ if param in ignored_params:
+ ignored_keys.add(key)
+ new_state_dict = OrderedDict()
+ ignored_param_keys = set()
+ for key, param in state_dict.items():
+ if key in ignored_keys:
+ ignored_param_keys.add(key)
+ else:
+ new_state_dict[key] = param
+ print_rank_0(f"Ignored parameter keys: {ignored_param_keys}")
+ return new_state_dict
+
+
+def shard_state_dict(zero3_model: LayerZeRO3, state_dict):
+ '''This function returns a dict of FQN to shard info mappings for later converting to megatron ckpt.
+ missing keys maybe params that are not managed by Layerzero3,
+ These params later will directly convert to megatron with no-op
+ '''
+ if zero3_model._is_root is None:
+ _lazy_init(zero3_model, zero3_model)
+ if not zero3_model._is_root:
+ raise ValueError("Expected a root zero3 model")
+ shard_infos = _get_param_fqns_to_shards(zero3_model)
+ missing_keys = set()
+ for key in state_dict.keys():
+ fqn = clean_tensor_name(key)
+ if fqn not in shard_infos:
+ missing_keys.add(fqn)
+ print_rank_0(f"Layerzero3 Shard info {missing_keys=}")
+ return shard_infos
+
+
+def _get_param_fqns_to_shards(
+ model: torch.nn.Module,
+) -> Dict[str, ShardFlattenInfo]:
+
+ def module_fn(module, prefix, tree_level, shard_infos):
+ if isinstance(module, LayerZeRO3):
+ handle = module._handle
+ if handle:
+ flat_param = handle.flat_param
+ for param, shard_param_info, fqn, shape in zip(
+ flat_param._params,
+ flat_param._shard_param_infos,
+ flat_param._fqns,
+ flat_param._shapes
+ ):
+ if hasattr(param, TP_SHARD_ARGS):
+ tensor_model_parallel = param.tensor_model_parallel
+ partition_dim = param.partition_dim
+ partition_stride = param.partition_stride
+ else:
+ tensor_model_parallel = False
+ partition_dim = -1,
+ partition_stride = 1,
+ global_fqn = prefix + fqn
+ shard_infos[global_fqn] = ShardFlattenInfo(
+ shard_param_info.in_shard,
+ shard_param_info.numel_in_shard,
+ shard_param_info.intra_param_start_idx,
+ shard_param_info.intra_param_end_idx,
+ shape,
+ tensor_model_parallel,
+ partition_dim,
+ partition_stride)
+
+ def return_fn(shard_infos):
+ return shard_infos
+
+ param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
+ return _apply_to_modules(
+ model,
+ module_fn,
+ return_fn,
+ [],
+ param_to_unflat_param_names,
+ )
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/__init__.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ac9b11176400e60ad34c3bfe197afe33e848b9b
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/__init__.py
@@ -0,0 +1,8 @@
+from .flat_param import FlatParameter
+from .fsdp import (
+ BackwardPrefetch,
+ LayerZeRO3,
+ MixedPrecision,
+)
+from .flat_param import FlatParamHandle, FlatParameter
+from ._common_utils import _ZeRO3State
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_common_utils.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_common_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cedf37f864d2fd9e63e41426f75fdc44e892244d
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_common_utils.py
@@ -0,0 +1,367 @@
+"""
+This file includes private common utilities for FSDP.
+"""
+import traceback
+import warnings
+import weakref
+from enum import auto, Enum
+from typing import (
+ Any,
+ Callable,
+ cast,
+ Dict,
+ Generator,
+ List,
+ no_type_check,
+ Optional,
+ Set,
+ Tuple,
+ TYPE_CHECKING
+)
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.utils.hooks import RemovableHandle
+from torch.distributed._composable_state import _get_module_state, _State
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+ _CHECKPOINT_PREFIX,
+)
+from torch.utils._mode_utils import no_dispatch
+
+if TYPE_CHECKING:
+ from mindspeed.core.distributed.layerzero.zero3._exec_order_utils import _ExecOrderData
+from mindspeed.core.distributed.layerzero.comm.hookwrap import CriticalPathEventQueue
+
+ZERO3_WRAPPED_MODULE = "_zero3_wrapped_module"
+ZERO3_PREFIX = ZERO3_WRAPPED_MODULE + "."
+ZERO3_FLATTENED = "_zero3_flattened"
+CRITICAL_EVENT_QUEUE = CriticalPathEventQueue()
+
+
+class _DeviceHandle:
+ """
+ This is a simple abstraction for FSDP computing devices,
+ which enables custom backends that implement CUDA-like
+ semantics to be integrated with FSDP.
+ """
+
+ def __init__(self, device: torch.device, backend: Any = None):
+ if backend is None:
+ try:
+ self.__backend = getattr(torch, device.type)
+ self.__device = device
+ except AttributeError as e:
+ raise AttributeError(
+ f"Device '{device}' does not have a corresponding backend registered as 'torch.{device.type}'."
+ ) from e
+ else:
+ self.__backend = backend
+
+ @classmethod
+ def from_device(cls, device: torch.device) -> "_DeviceHandle":
+ """
+ Return an device handle corresponding to the device, and through this handle,
+ operations with the same semantics as CUDA can be performed on the device.
+ Just return torch.cuda if the device is cuda to make attribute-access faster.
+ Custom backend must first register a module with the same name with {device.type} on torch.
+ """
+ if device.type == "cuda":
+ return cast(_DeviceHandle, torch.cuda)
+ return cls(device)
+
+ def __getattr__(self, __name: str) -> Any:
+ try:
+ return getattr(self.__backend, __name)
+ except AttributeError as e:
+ raise AttributeError(
+ f"Custom backend '{self.__device.type}' not implement 'torch.{self.__device.type}.{__name}'"
+ ) from e
+
+
+class _UninitializedDeviceHandle:
+ def __init__(self):
+ pass
+
+ def __getattribute__(self, __name: str) -> Any:
+ raise RuntimeError("Trying to use an uninitialized device handle.")
+
+
+class _ZeRO3State(_State):
+
+ def __init__(self) -> None:
+ self._debug_level = None
+ #! zero3 related attributes
+ self._ignored_modules: Set[nn.Module] = set()
+ self._ignored_params: Set[nn.Parameter] = set()
+ # Buffer names are cleaned (without wrapper prefixes)
+ self._ignored_buffer_names: Set[str] = set()
+ self.zero3_process_group: Optional[dist.ProcessGroup] = None
+ #!=========================zero1 pg state===================
+ self.zero1_process_group: Optional[dist.ProcessGroup] = None
+ self.global_rank: int = -1
+ self.world_size: int = -1
+ #!==========================================================
+ self.zero3_rank: int = -1
+ self.zero3_world_size: int = -1
+ self.limit_all_gathers: bool = False
+ self.training_state = TrainingState.IDLE
+ self._unshard_params_ctx: Dict[nn.Module, Generator] = {}
+ self._is_root: Optional[bool] = None
+ self._handle = None
+ # : Dict[nn.Module, Optional[flat_param_file.FlatParamHandle]]
+ self._zero3_module_to_handle = {}
+ self.compute_device: Optional[torch.device] = None
+ self._gradient_predivide_factor: int = 0
+ # Abstract device handle for fsdp compute device. For now,
+ # the compute device must implement cuda semantics used by fsdp
+ self._device_handle: _DeviceHandle = _UninitializedDeviceHandle()
+ # All following attributes should only be used for root states:
+ # Save these static lists to avoid the repeated tree traversals
+ self._all_zero3_states: List[_ZeRO3State] = []
+ self._all_handles = [] # : List[flat_param_file.FlatParamHandle] = []
+ self.mixed_precision = None
+ self._offload_grads = False
+ #!===========================streams==================================
+ self._unshard_stream = None
+ self._post_backward_stream = None
+ self._pre_unshard_stream = None
+ self._default_stream = None
+ self._offload_stream = None
+ self._exec_order_data: "_ExecOrderData" = None
+ self._free_event_queue = None
+ self._rs_event_queue = None
+ self._offload_event_queue = None
+ #!==========================runtime state =========================
+ self.backward_prefetch = None
+ self.backward_reduce_scatter = None
+ self.forward_prefetch: bool = None
+ self._root_pre_forward_handles: List[RemovableHandle] = []
+ self._pre_forward_handles: List[RemovableHandle] = []
+ self._post_forward_handles: List[RemovableHandle] = []
+ self._sync_gradients: bool = False
+ self._root_needs_param_sync: bool = True
+ #!==========================hook state===========================
+ self._post_backward_callback_queued: bool = False
+ #!=================================================================
+
+ def wait_critical_path_events(self):
+ if CRITICAL_EVENT_QUEUE is None or CRITICAL_EVENT_QUEUE.empty():
+ return
+ with torch.profiler.record_function("LayerZeRO3: wait critical path events"):
+ with CRITICAL_EVENT_QUEUE.block():
+ while not CRITICAL_EVENT_QUEUE.empty():
+ event = CRITICAL_EVENT_QUEUE.pop_left()
+ if event is not None:
+ with torch.profiler.record_function(
+ "LayerZeRO3.critical_path_events"
+ ):
+ event.wait()
+
+ @classmethod
+ def record_critical_event(cls):
+ if dist.get_rank() == 0:
+ print("Record a critical event")
+ event = torch.cuda.Event()
+ event.record()
+ CRITICAL_EVENT_QUEUE.enqueue(event)
+
+
+def _get_module_zero3_state(module: nn.Module) -> Optional[_ZeRO3State]:
+ state = _get_module_state(module)
+ if state is None or not isinstance(state, _ZeRO3State):
+ return None
+ return state
+
+
+class TrainingState(Enum):
+ """
+ An enum that indicates the state of a ``FullyShardedDataParallel` instance.
+ """
+
+ IDLE = auto()
+ FORWARD_BACKWARD = auto()
+ SUMMON_FULL_PARAMS = auto()
+
+
+class HandleTrainingState(Enum):
+ """
+ An enum that indicates the state of a ``FlatParamHandle`.
+ """
+
+ IDLE = auto()
+ FORWARD = auto()
+ BACKWARD_PRE = auto()
+ BACKWARD_POST = auto()
+ SUMMON_FULL_PARAMS = auto()
+ SYNC_PARAMS = auto()
+
+
+def _is_composable(state: _ZeRO3State):
+ return not isinstance(state, nn.Module)
+
+
+@no_type_check
+def _module_handle(state: _ZeRO3State, module: nn.Module):
+ """
+ Returns the ``FlatParamHandle`` s corresponding to ``module``. This is
+ the handle that contains some parameter in ``module``.
+ """
+ if _is_composable(state):
+ # A valid FSDP state may have no managed parameters and hence no
+ # handles, meaning no entry in `_fully_sharded_module_to_handles`
+ if state._handle is None:
+ return None
+ if module not in state._zero3_module_to_handle:
+ raise AssertionError(f"Expects a fully sharded module but got {module} on rank {state.zero3_rank}")
+ return state._zero3_module_to_handle[module]
+ else:
+ # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
+ return module._handle
+
+
+@no_type_check
+def _has_zero3_params(state: _ZeRO3State, module: nn.Module) -> bool:
+ """Returns if ``module`` has parameters managed by LayerZeRO3."""
+ return _module_handle(state, module) is not None
+
+
+def clean_tensor_name(tensor_name: str) -> str:
+ """
+ Cleans the parameter or buffer name by removing any module wrapper
+ prefixes.
+ """
+ tensor_name = tensor_name.replace(ZERO3_PREFIX, "")
+ # it couples `CheckpointWrapper` and FSDP and also does not scale for more
+ # module wrappers.
+ tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "")
+ return tensor_name
+
+
+def _set_zero3_flattened(tensor: torch.Tensor) -> None:
+ """
+ Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is to
+ avoid re-flattening it during nested construction.
+ """
+ setattr(tensor, ZERO3_FLATTENED, True)
+
+
+def _is_zero3_flattened(tensor: torch.Tensor) -> bool:
+ """Returns if ``tensor`` has been marked as flattened by FSDP."""
+ return getattr(tensor, ZERO3_FLATTENED, False)
+
+
+def _named_parameters_with_duplicates(
+ module: nn.Module, **kwargs: Any
+) -> List[Tuple[str, nn.Parameter]]:
+ """
+ This API is required as some modules overwrite `named_parameters()` but do not support
+ `remove_duplicate`.
+ """
+ kwargs["remove_duplicate"] = False
+ try:
+ ret = list(module.named_parameters(**kwargs))
+ except AssertionError as e:
+ kwargs.pop("remove_duplicate")
+ ret = list(module.named_parameters(**kwargs))
+ return ret
+
+
+def _apply_to_modules(
+ root_module: torch.nn.Module,
+ module_fn: Callable,
+ return_fn: Callable,
+ filter_fqns: Optional[List[str]] = None,
+ *args,
+ **kwargs,
+):
+ """
+ Performs a pre-order traversal of the modules in the hierarchy rooted at
+ ``root_module``, applying ``module_fn`` at each module and finally
+ returning a value using ``return_fn``. The traversal constructs the full
+ module prefix name (e.g. "module.submodule." just like in model state dict)
+ and makes that available to ``module_fn``.
+
+ ``filter_fqns`` is used because some module may have its own prefix similar
+ to ``FullyShardedDataParallel`` and the ``named_parameters()`` is overwritten
+ to remove the prefix.
+ """
+
+ def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs):
+ # Call the module function before recursing over children (pre-order)
+ module_fn(module, prefix, tree_level, *args, **kwargs)
+ for submodule_name, submodule in module.named_children():
+ if submodule is None:
+ continue
+ new_prefix = prefix + submodule_name + "."
+ new_tree_level = tree_level + 1
+ if filter_fqns is not None:
+ for fqn in filter_fqns:
+ if fqn.startswith(new_prefix):
+ break
+ else:
+ # DMP's named_parameter() will mess up the traversal with
+ # ``named_children`` + `named_parameter(recurse=False)``.
+ # This hack is a must to make the traversal work.
+ if (
+ submodule_name == "_zero3_wrapped_module"
+ or submodule_name == "_dmp_wrapped_module"
+ ):
+ if (
+ not torch.distributed._functional_collectives.is_torchdynamo_compiling()
+ ):
+ warnings.warn(
+ "An unexpected prefix is detected. This case "
+ " should only happen when using DMP with FSDP. "
+ f"prefix = {prefix}, "
+ f"submodule_name = {submodule_name}"
+ )
+ new_prefix = prefix
+ elif submodule_name == "module":
+ warnings.warn(
+ "An unexpected prefix is detected. This case "
+ " should only happen when DDP wraps the outer "
+ " modules while FSDP wraps the inner ones."
+ f"prefix = {prefix}, "
+ f"submodule_name = {submodule_name}"
+ )
+ new_prefix = prefix
+ f(submodule, new_prefix, new_tree_level, *args, **kwargs)
+
+ f(root_module, "", 0, *args, **kwargs)
+ return return_fn(*args, **kwargs)
+
+
+@no_type_check
+def _assert_in_training_states(
+ state: _ZeRO3State,
+ training_states: List[TrainingState],
+) -> None:
+ """Asserts that zero3 is in the states ``_training_states``."""
+ # Raise a `ValueError` instead of using `assert` to ensure that these
+ # logical assertions run even if `assert`s are disabled
+ if state.training_state not in training_states:
+ msg = (
+ f"expected to be in states {training_states} but current state is "
+ f"{state.training_state}"
+ )
+ # Print the error on rank 0 in case this is called in the backward pass
+ if state.zero3_rank == 0:
+ if isinstance(state, nn.Module):
+ print(f"Asserting FSDP instance is: {state}")
+ print(f"ERROR: {msg}")
+ traceback.print_stack()
+ raise ValueError(msg)
+
+
+def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None:
+ if tensor.device.type not in ["cuda", torch._C._get_privateuse1_backend_name(), "npu"]:
+ return
+
+ # Don't no dispatch under torch compile like this
+ with no_dispatch():
+ tensor.record_stream(stream)
+
+
+def _same_storage_as_data_ptr(x: torch.Tensor, data_ptr: int) -> bool:
+ return x._typed_storage()._data_ptr() == data_ptr
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_exec_order_utils.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_exec_order_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1230e2482a96f20699dd4b50b566d32e391f55b2
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_exec_order_utils.py
@@ -0,0 +1,153 @@
+import logging
+from enum import auto, Enum
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch.distributed as dist
+import torch.nn as nn
+from mindspeed.core.distributed.layerzero.zero3.flat_param import FlatParamHandle
+import mindspeed.core.distributed.layerzero.zero3._traversal_utils as traversal_utils
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+
+class _ExecOrderWarnStatus(Enum):
+ """Used internally for execution order validation."""
+
+ NONE = auto() # no deviation yet
+ WARNING = auto() # deviated this iteration; currently issuing warnings
+ WARNED = auto() # deviated in a previous iteration
+
+
+class _ExecOrderData:
+
+ def __init__(
+ self,
+ backward_prefetch_limit: int,
+ forward_prefetch_limit: int,
+ ) -> None:
+ # Tracks the (static) pre-forward order for execution order validation
+ # and forward prefetching
+ self.handles_pre_forward_order: List[FlatParamHandle] = []
+ # Tracks the post-forward order for pre-backward prefetching
+ self.handles_post_forward_order: List[Optional[FlatParamHandle]] = []
+ self._iter = 0
+
+ # Gives the max number of backward/forward prefetched all-gathers by a
+ # single module
+ self._backward_prefetch_limit = backward_prefetch_limit
+ self._forward_prefetch_limit = forward_prefetch_limit
+
+ self.process_group: Optional[dist.ProcessGroup] = None
+ self.world_size: Optional[int] = None
+ self.all_handles: List[FlatParamHandle] = []
+
+ def init(
+ self,
+ state,
+ root_module: nn.Module,
+ process_group: dist.ProcessGroup,
+ ) -> None:
+ """
+ Initializes the data structures needed for checking the forward order.
+ This should be called after a root FSDP instance has been set during
+ lazy initialization.
+ """
+ self.process_group = process_group
+ self.rank = process_group.rank()
+ self.world_size = process_group.size()
+ # Fix an order over the handles, which should be the same across ranks
+ for handle in traversal_utils._get_zero3_handles(root_module):
+ index = len(self.all_handles)
+ self.all_handles.append(handle)
+ handle._handle_index = index
+
+ @property
+ def is_first_iter(self) -> bool:
+ return self._iter == 0
+
+ def get_handle_to_backward_prefetch(
+ self,
+ current_handle: FlatParamHandle,
+ ) -> Optional[FlatParamHandle]:
+ """
+ Returns a :class:`list` of the handles keys of the handles to backward
+ prefetch given the current handles key. If there are no valid handles
+ keys to prefetch, then this returns an empty :class:`list`.
+ """
+ current_index = current_handle._post_forward_index
+ if current_index is None:
+ return None
+ target_index = current_index - 1
+ target_handle: Optional[FlatParamHandle] = None
+ for _ in range(self._backward_prefetch_limit):
+ if target_index < 0:
+ break
+ target_handle = self.handles_post_forward_order[target_index]
+ target_index -= 1
+ return target_handle
+
+ def get_handle_to_forward_prefetch(
+ self,
+ current_handle: FlatParamHandle,
+ ) -> Optional[FlatParamHandle]:
+ """
+ Returns a :class:`list` of the handles keys of the handles to forward
+ prefetch given the current handles key. If there are no valid handles
+ keys to prefetch, then this returns an empty :class:`list`.
+ """
+ current_index = current_handle._pre_forward_order_index
+ if current_index is None:
+ return None
+ target_index = current_index + 1
+ target_handle: Optional[FlatParamHandle] = None
+ for _ in range(self._forward_prefetch_limit):
+ if target_index >= len(self.handles_pre_forward_order):
+ break
+ target_handle = self.handles_pre_forward_order[target_index]
+ target_index += 1
+ return target_handle
+
+ def get_handle_to_post_backward(
+ self,
+ current_handle: FlatParamHandle,
+ ) -> List[FlatParamHandle]:
+ current_index = current_handle._pre_forward_order_index
+ if current_index is None:
+ return []
+ target_index = current_index + 1
+ target_handle: List[FlatParamHandle] = []
+ for _ in range(len(self.handles_pre_forward_order)):
+ if target_index >= len(self.handles_pre_forward_order):
+ break
+ target_handle.append(self.handles_pre_forward_order[target_index])
+ target_index += 1
+ return target_handle
+
+ def record_post_forward(self, handle: Optional[FlatParamHandle]) -> None:
+ if not handle or handle._post_forward_index is not None:
+ return
+ index = len(self.handles_post_forward_order)
+ handle._post_forward_index = index
+
+ self.handles_post_forward_order.append(handle)
+
+ def record_pre_forward(
+ self, handle: Optional[FlatParamHandle], is_training: bool
+ ) -> None:
+ if not handle:
+ return
+ # Fix the order after the first iteration and only record the first
+ # usage of a handles key
+ if not self.is_first_iter or handle._pre_forward_order_index is not None:
+ return
+ index = len(self.handles_pre_forward_order)
+ handle._pre_forward_order_index = index
+ self.handles_pre_forward_order.append(handle)
+
+ def next_iter(self):
+ self._iter += 1
+ self.handles_post_forward_order.clear()
+
+ def next_iter_during_accumulation(self):
+ self._iter += 1
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_init_utils.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_init_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..033a3a5911171e56cf3d3f9996f44b91db57f3b0
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_init_utils.py
@@ -0,0 +1,665 @@
+import collections
+import warnings
+from typing import (
+ Any,
+ Callable,
+ Deque,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ no_type_check,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
+
+import torch
+import torch.distributed as dist
+
+import torch.nn as nn
+from torch.utils.hooks import RemovableHandle
+from megatron.core import mpu
+
+from mindspeed.core.distributed.layerzero.zero3._common_utils import (
+ _DeviceHandle,
+ _ZeRO3State,
+ _get_module_zero3_state,
+ _is_zero3_flattened,
+ _named_parameters_with_duplicates,
+ clean_tensor_name,
+ TrainingState,
+)
+from mindspeed.core.distributed.layerzero.zero3.api import (
+ BackwardPrefetch,
+ BackwardReduceScatter,
+ MixedPrecision,
+)
+from mindspeed.core.distributed.layerzero.zero3.flat_param import (
+ FlatParameter,
+ FlatParamHandle,
+)
+from mindspeed.core.distributed.layerzero.zero3._limiter import _FreeEventQueue
+import mindspeed.core.distributed.layerzero.zero3._exec_order_utils as exec_order_utils
+import mindspeed.core.distributed.layerzero.zero3._traversal_utils as traversal_utils
+import mindspeed.core.distributed.layerzero.zero3.fsdp as zero3_file
+
+
+PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024)
+ZERO3_SYNCED = "_zero3_synced"
+# Overall specification of process group.
+ProcessGroupType = Tuple[dist.ProcessGroup, dist.ProcessGroup]
+
+
+def _get_gradient_predivide_factor(world_size: int) -> float:
+ factor: int = 1
+ while world_size % factor == 0 and world_size / factor > factor:
+ factor *= 2
+ return float(factor)
+
+
+@no_type_check
+def _init_process_group_state(
+ state: _ZeRO3State,
+ process_group: ProcessGroupType,
+) -> _ZeRO3State:
+
+ state.zero3_process_group, state.zero1_process_group = process_group
+ state.zero3_rank = state.zero3_process_group.rank()
+ data_parallel_world_size = state.zero1_process_group.size()
+ state.world_size = data_parallel_world_size
+ state.global_rank = dist.get_rank()
+ if mpu.is_initialized():
+ state._gradient_predivide_factor = float(dist.get_world_size(
+ mpu.get_data_parallel_group(with_context_parallel=True)))
+ else:
+ state._gradient_predivide_factor = data_parallel_world_size
+ return state
+
+
+@no_type_check
+def _init_ignored_module_states(
+ state: _ZeRO3State,
+ module: nn.Module,
+ ignored_modules: Optional[Iterable[torch.nn.Module]],
+ ignored_states: Union[
+ Optional[Iterable[torch.nn.Parameter]
+ ], Optional[Iterable[torch.nn.Module]]
+ ] = None,
+) -> _ZeRO3State:
+ if ignored_modules is not None and ignored_states is not None:
+ raise ValueError(
+ "Cannot pass both ignored_modules and ignored_states at the "
+ "same time. Please just pass ignored_states."
+ )
+ ignored_parameters = None
+ passed_as_ignored_states = ignored_states is not None
+ if passed_as_ignored_states:
+ ignored_states_list = list(ignored_states)
+ _check_ignored_states(ignored_states_list, True)
+ else:
+ ignored_states_list = []
+ _check_ignored_states(
+ list(ignored_modules) if ignored_modules is not None else [], False
+ )
+ if len(ignored_states_list) > 0:
+ if isinstance(ignored_states_list[0], nn.Parameter):
+ ignored_parameters = ignored_states_list
+ else:
+ ignored_modules = ignored_states_list
+ state._ignored_modules = _get_ignored_modules(module, ignored_modules)
+ state._ignored_params = _get_ignored_params(
+ module,
+ state._ignored_modules,
+ ignored_parameters,
+ )
+ state._ignored_buffer_names = _get_ignored_buffer_names(
+ module,
+ state._ignored_modules,
+ )
+ return state
+
+
+def _check_ignored_states(
+ ignored_states: List[Any], passed_as_ignored_states: bool
+) -> None:
+ """
+ Checks that the ignored states are uniformly parameters or uniformly
+ modules. We may remove this check in the future if we permit mixing.
+ """
+ if len(ignored_states) == 0:
+ return
+ if passed_as_ignored_states:
+ all_params = all(isinstance(state, nn.Parameter)
+ for state in ignored_states)
+ all_modules = all(isinstance(state, nn.Module)
+ for state in ignored_states)
+ if not all_params and not all_modules:
+ # Sort for consistent ordering for unit test regex matching
+ sorted_types = sorted(
+ {type(state) for state in ignored_states}, key=lambda x: repr(x)
+ )
+ raise ValueError(
+ "ignored_states expects all nn.Parameter or all nn.Module list "
+ f"elements but got types {sorted_types}"
+ )
+ else:
+ if not all(isinstance(state, nn.Module) for state in ignored_states):
+ sorted_types = sorted(
+ {type(state) for state in ignored_states}, key=lambda x: repr(x)
+ )
+ raise ValueError(
+ "ignored_modules expects nn.Module list elements but got "
+ f"types {sorted_types}"
+ )
+
+
+@no_type_check
+def _init_device_handle(
+ state: _ZeRO3State,
+ module: nn.Module,
+ ignored_params: Set[nn.Parameter],
+ device_id: Optional[Union[int, torch.device]],
+) -> _ZeRO3State:
+ determined_device = None
+ if device_id is not None:
+ determined_device = (
+ device_id
+ if isinstance(device_id, torch.device)
+ else torch.device(device_id)
+ )
+ if determined_device is None:
+ for param in _get_orig_params(module, ignored_params):
+ if param.device.type in {"cpu", "meta"}:
+ continue
+ if determined_device is None:
+ determined_device = param.device
+ else:
+ if param.device.type != determined_device.type:
+ raise RuntimeError(
+ f"FSDP does not support modules with different device types "
+ f"but got params on {determined_device.type} and {param.device.type}"
+ )
+ determined_device = determined_device or torch.device(
+ "cuda", torch.cuda.current_device()
+ )
+
+ state._device_handle = _DeviceHandle.from_device(determined_device)
+ return state
+
+
+@no_type_check
+def _init_buffer_state(
+ state: _ZeRO3State,
+ module: nn.Module,
+) -> _ZeRO3State:
+ state._buffer_names = _get_buffer_names(module)
+ # Save a mapping from clean fully-qualified buffer name (starting from
+ # `module`) to its original dtype for restoring that dtype during model
+ # checkpointing when buffer mixed precision is enabled. The names should
+ # be clean since the casting happens in a `summon_full_params()` context.
+ _buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
+ for buffer_name, buffer in module.named_buffers():
+ buffer_name = clean_tensor_name(buffer_name)
+ _buffer_name_to_orig_dtype[buffer_name] = buffer.dtype
+ state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype
+ return state
+
+
+@no_type_check
+def _init_core_state(
+ state: _ZeRO3State,
+ mixed_precision: Optional[MixedPrecision],
+ limit_all_gathers: bool,
+ backward_prefetch_limit: int,
+ forward_prefetch_limit: int,
+ offload_grads: bool = False
+) -> _ZeRO3State:
+ # We clamp the strategy to `NO_SHARD` for world size of 1 since they are
+ # currently functionally equivalent. This may change if/when we integrate
+ # FSDP with MoE.
+ state.mixed_precision = mixed_precision or MixedPrecision()
+ if mixed_precision is not None:
+ torch._C._log_api_usage_once(
+ f"mixed_precision.{str(state.mixed_precision)}"
+ )
+
+ state.limit_all_gathers = limit_all_gathers
+ state.training_state = TrainingState.IDLE
+ state._is_root = None
+ state._free_event_queue = _FreeEventQueue()
+ state._rs_event_queue = _FreeEventQueue()
+ state._offload_event_queue = _FreeEventQueue()
+ state._offload_grads = offload_grads
+ # ==========================================
+ state._debug_level = dist.get_debug_level()
+ state._exec_order_data = exec_order_utils._ExecOrderData(
+ backward_prefetch_limit,
+ forward_prefetch_limit,
+ )
+ #! add support for zero1 events
+ # Mapping from fully sharded module to the handles it is responsible to
+ # unshard and reshard (see [Note: Fully Sharded Module])
+ _fully_sharded_module_to_handle: Dict[nn.Module, FlatParamHandle] = dict()
+ state._zero3_module_to_handle = _fully_sharded_module_to_handle
+ # Invariant: `state.params` contains exactly the `FlatParameter`s of the
+ # handles in `state._handle`
+ _handle: FlatParamHandle = None
+ state._handle = _handle
+ params: List[FlatParameter] = []
+ state.params = params
+ return state
+
+
+@no_type_check
+def _init_runtime_state(
+ state: _ZeRO3State,
+) -> _ZeRO3State:
+ _root_pre_forward_handles: List[RemovableHandle] = []
+ state._root_pre_forward_handles = _root_pre_forward_handles
+ _pre_forward_handles: List[RemovableHandle] = []
+ state._pre_forward_handles = _pre_forward_handles
+ _post_forward_handles: List[RemovableHandle] = []
+ state._post_forward_handles = _post_forward_handles
+ state._sync_gradients = True
+ # Used to prevent running the pre-backward hook multiple times
+ return state
+
+
+@no_type_check
+def _init_prefetching_state(
+ state: _ZeRO3State,
+ backward_prefetch: BackwardPrefetch,
+ forward_prefetch: bool,
+ backward_reduce_scatter: BackwardReduceScatter
+) -> _ZeRO3State:
+ state.backward_prefetch = backward_prefetch
+ state.forward_prefetch = forward_prefetch
+ state.backward_reduce_scatter = backward_reduce_scatter
+ # The data structures use tuples of handles to generalize over the case
+ # where a module's forward involves multiple handles.
+ return state
+
+
+@no_type_check
+def _init_param_handle_from_module(
+ state: _ZeRO3State,
+ zero3_module: nn.Module,
+ device_id: Optional[Union[int, torch.device]],
+ param_init_fn: Optional[Callable[[nn.Module], None]],
+) -> _ZeRO3State:
+ """
+ Initializes a ``FlatParamHandle`` from a module ``fully_sharded_module``.
+ """
+ _check_single_device_module(zero3_module, state._ignored_params, device_id)
+ device_from_device_id = _get_device_from_device_id(
+ device_id, state.global_rank)
+ _move_module_to_device(
+ zero3_module, state._ignored_params, device_from_device_id
+ )
+ state.compute_device = _get_compute_device(
+ zero3_module,
+ state._ignored_params,
+ device_from_device_id,
+ state.global_rank,
+ )
+
+ managed_params = list(_get_orig_params(
+ zero3_module, state._ignored_params))
+ for param in managed_params:
+ if len(param.shape) == 1:
+ param._is_1D_param = True
+ _init_param_handle_from_params(
+ state, managed_params, zero3_module)
+ return state
+
+
+@no_type_check
+def _init_param_handle_from_params(
+ state: _ZeRO3State,
+ params: List[nn.Parameter],
+ zero3_module: nn.Module,
+):
+ if len(params) == 0:
+ return
+ handle = FlatParamHandle(
+ params,
+ zero3_module,
+ state.compute_device,
+ state.mixed_precision.param_dtype,
+ state.mixed_precision.reduce_dtype,
+ state.zero3_process_group,
+ state.zero1_process_group,
+ state._offload_grads
+ )
+ handle.shard()
+ if state._handle is not None:
+ raise ValueError(f"state handle has been initialized")
+ state.params.append(handle.flat_param)
+ state._handle = handle
+ state._zero3_module_to_handle[handle._zero3_module] = handle
+
+
+def _get_ignored_modules(
+ root_module: nn.Module,
+ _ignored_modules: Optional[Iterable[torch.nn.Module]],
+) -> Set[nn.Module]:
+ """
+ Checks that ``_ignored_modules`` is an iterable of ``nn.Module`` s without
+ any FSDP instances, and returns the modules contained in their module
+ subtrees as a :class:`set`. Nested FSDP instances are excluded, but their
+ already-computed ignored modules are included.
+
+ ``_ignored_modules`` represents the argument passed by the user to FSDP.
+ """
+ msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s "
+ try:
+ ignored_root_modules = (
+ set(_ignored_modules) if _ignored_modules is not None else set()
+ )
+ except TypeError as e:
+ raise TypeError(
+ msg_prefix + f"but got {type(_ignored_modules)}") from e
+ for module in ignored_root_modules:
+ if not isinstance(module, torch.nn.Module):
+ raise TypeError(
+ msg_prefix + f"but got an iterable with {type(module)}")
+ if _get_module_zero3_state(module):
+ raise ValueError(
+ "`ignored_modules` should not include FSDP modules")
+ # Treat modules that cannot compose with `fully_shard` as ignored modules,
+ # meaning that their subtrees are ignored
+ for module in root_module.modules():
+ if not traversal_utils._composable(module):
+ ignored_root_modules.add(module)
+ # NOTE: Even if `ignored_root_modules` is empty, do not return early so
+ # that this FSDP instance can get any ignored modules from its children.
+
+ # Include child modules and exclude nested FSDP modules themselves
+ ignored_modules = {
+ child
+ for module in ignored_root_modules
+ for child in module.modules()
+ if not isinstance(child, zero3_file.LayerZeRO3)
+ }
+ if root_module in ignored_modules:
+ warnings.warn(
+ "Trying to ignore the top-level module passed into the FSDP "
+ "constructor itself will result in all parameters being "
+ f"ignored and is not well-supported: {module}"
+ )
+ # Include nested FSDP modules' ignored modules
+ for submodule in root_module.modules():
+ optional_fsdp_state = _get_module_zero3_state(submodule)
+ if optional_fsdp_state is not None:
+ if not hasattr(optional_fsdp_state, "_ignored_modules"):
+ raise AttributeError(
+ "State has not attribute _ignored_modules")
+ ignored_modules.update(optional_fsdp_state._ignored_modules)
+ return ignored_modules
+
+
+def _get_ignored_params(
+ root_module: torch.nn.Module,
+ ignored_modules: Set[torch.nn.Module],
+ ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
+) -> Set[torch.nn.Parameter]:
+ """
+ Returns the parameters of the modules in ``ignored_modules`` and
+ the parameters in ``ignored_parameters``, excluding any :class:`FlatParameter` s.
+ """
+ all_ignored_params: Set[torch.nn.Parameter] = set()
+
+ params_in_ignored_modules = {
+ p for m in ignored_modules for p in m.parameters() if not _is_zero3_flattened(p)
+ }
+
+ all_ignored_params.update(params_in_ignored_modules)
+
+ if ignored_parameters is not None:
+ params_in_ignored_parameters = {
+ p for p in ignored_parameters if not _is_zero3_flattened(p)
+ }
+ all_ignored_params.update(params_in_ignored_parameters)
+
+ # Always include nested FSDP modules' ignored parameters
+ for submodule in root_module.modules():
+ optional_fsdp_state = _get_module_zero3_state(submodule)
+ if optional_fsdp_state is not None:
+ if not hasattr(optional_fsdp_state, "_ignored_params"):
+ raise AttributeError("State has not attribute _ignored_params")
+ all_ignored_params.update(optional_fsdp_state._ignored_params)
+
+ return all_ignored_params
+
+
+def _get_ignored_buffer_names(
+ root_module: torch.nn.Module,
+ ignored_modules: Set[torch.nn.Module],
+) -> Set[str]:
+ """
+ Returns the cleaned buffer FQNs in ``ignored_modules``
+ """
+ all_ignored_buffer_names: Set[str] = set()
+
+ buffers_in_ignored_modules = {
+ buffer for m in ignored_modules for buffer in m.buffers()
+ }
+
+ all_ignored_buffer_names.update(
+ {
+ clean_tensor_name(buffer_name)
+ for buffer_name, buffer in root_module.named_buffers()
+ if buffer in buffers_in_ignored_modules
+ }
+ )
+
+ # Always include nested FSDP modules' ignored buffer names
+ for submodule in root_module.modules():
+ optional_fsdp_state = _get_module_zero3_state(submodule)
+ if optional_fsdp_state is not None:
+ if not hasattr(optional_fsdp_state, "_ignored_buffer_names"):
+ raise AttributeError(
+ "State has not attribute _ignored_buffer_names")
+ all_ignored_buffer_names.update(
+ optional_fsdp_state._ignored_buffer_names)
+
+ return all_ignored_buffer_names
+
+
+def _get_buffer_names(root_module: nn.Module) -> Set[str]:
+ """
+ Returns the fully prefixed names of all buffers in the module hierarchy
+ rooted at ``root_module`` as a class:`set`.
+ """
+ return {
+ clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers()
+ }
+
+
+def _check_single_device_module(
+ module: nn.Module,
+ ignored_params: Set[nn.Parameter],
+ device_id: Optional[Union[int, torch.device]],
+) -> None:
+ """
+ Raises an error if ``module`` has original parameters on multiple devices,
+ ignoring the parameters in ``ignored_params``. Thus, after this method, the
+ module must be either fully on the CPU or fully on a non-CPU device.
+ """
+ devices = {param.device for param in _get_orig_params(
+ module, ignored_params)}
+
+ if len(devices) == 2 and torch.device("cpu") in devices:
+ if device_id is None:
+ raise RuntimeError(
+ "To support a module with both CPU and GPU params, "
+ "please pass in device_id argument."
+ )
+ elif len(devices) > 1:
+ raise RuntimeError(
+ f"ZeRO3 only supports single device modules but got params on {devices}"
+ )
+
+
+def _get_device_from_device_id(
+ device_id: Optional[Union[int, torch.device]],
+ rank: int,
+) -> Optional[torch.device]:
+ """
+ Processes ``device_id`` and returns either the corresponding device or
+ ``None`` if ``device_id`` is ``None``.
+ """
+ if device_id is None:
+ return None
+ device = (
+ device_id if isinstance(
+ device_id, torch.device) else torch.device(device_id)
+ )
+ return device
+
+
+def _move_module_to_device(
+ module: nn.Module,
+ ignored_params: Set[nn.Parameter],
+ device_from_device_id: Optional[torch.device],
+) -> None:
+ cpu_device = torch.device("cpu")
+ if device_from_device_id is not None:
+ # BFS from `module` without traversing any nested FSDP instances to
+ # collect the parameters/buffers that have not yet been managed
+ queue: Deque[nn.Module] = collections.deque()
+ queue.append(module)
+ params: List[nn.Parameter] = []
+ buffers: List[torch.Tensor] = []
+ while queue:
+ curr_module = queue.popleft()
+ params.extend(
+ param
+ for param in curr_module.parameters(recurse=False)
+ if param.device == cpu_device
+ )
+ buffers.extend(
+ buffer
+ for buffer in curr_module.buffers(recurse=False)
+ if buffer.device == cpu_device
+ )
+ for submodule in curr_module.children():
+ if not isinstance(submodule, zero3_file.LayerZeRO3):
+ queue.append(submodule)
+
+ _move_states_to_device(params, buffers, device_from_device_id)
+ return
+ param = next(_get_orig_params(module, ignored_params), None)
+ if param is not None and param.device == cpu_device:
+ _warn_cpu_init()
+
+
+def _move_states_to_device(
+ params: List[nn.Parameter],
+ buffers: List[torch.Tensor],
+ device_from_device_id: Optional[torch.device],
+) -> None:
+ """
+ Precondition: ``_check_single_device_module()`` and module's parameters and
+ buffers have been materialized if needed.
+ """
+ if len(params) == 0 and len(buffers) == 0:
+ return
+ if len(params) > 0:
+ current_device = params[0].device
+ elif len(buffers) > 0:
+ current_device = buffers[0].device
+ cpu_device = torch.device("cpu")
+ if device_from_device_id is not None:
+ # Move the parameters and buffers like the `.data` code path in
+ # `nn.Module._apply()`, which underlies `nn.Module.to()`
+ for param in params:
+ with torch.no_grad():
+ param.data = param.to(device_from_device_id)
+ if param.grad is not None:
+ param.grad.data = param.grad.to(device_from_device_id)
+ for buffer in buffers:
+ buffer.data = buffer.to(device_from_device_id)
+ elif current_device == cpu_device:
+ _warn_cpu_init()
+
+
+def _warn_cpu_init():
+ warnings.warn(
+ "The passed-in `module` is on CPU and will thus have FSDP's sharding "
+ "initialization run on CPU, which may be slower than on GPU. We "
+ "recommend passing in the `device_id` argument for FSDP to move "
+ "`module` to GPU for the sharding initialization. `module` must also "
+ "be on GPU device to work with the `sync_module_states=True` flag "
+ "since that requires GPU communication."
+ )
+
+
+def _get_compute_device(
+ module: nn.Module,
+ ignored_params: Set[nn.Parameter],
+ device_from_device_id: Optional[torch.device],
+ rank: int,
+) -> torch.device:
+ """
+ Determines and returns this FSDP instance's compute device. If a device is
+ specified by ``device_id``, then returns that device. Otherwise, If the
+ module is already on a non-CPU device, then the compute device is that non-CPU
+ device. If the module is on CPU, then the compute device is the current
+ device.
+
+ Since this method should be called after materializing the module, any
+ non-CPU device should not be meta device. For now, the compute device is
+ always a CUDA GPU device with its explicit index.
+
+ Precondition: ``_check_single_device_module()`` and
+ ``_move_module_to_device()``.
+ """
+ param = next(_get_orig_params(module, ignored_params), None)
+ if param is not None and param.device.type != "cpu":
+ compute_device = param.device
+ else:
+ if device_from_device_id is not None and device_from_device_id.type != "cuda":
+ compute_device = device_from_device_id
+ else:
+ compute_device = torch.device("cuda", torch.cuda.current_device())
+ if device_from_device_id is not None and compute_device != device_from_device_id:
+ raise ValueError(
+ f"Inconsistent compute device and `device_id` on rank {rank}: "
+ f"{compute_device} vs {device_from_device_id}"
+ )
+ return compute_device
+
+
+def _get_orig_params(
+ module: nn.Module,
+ ignored_params: Set[nn.Parameter],
+) -> Iterator[nn.Parameter]:
+ param_gen = module.parameters()
+ try:
+ while True:
+ param = next(param_gen)
+ if param not in ignored_params and not _is_zero3_flattened(param):
+ yield param
+ except StopIteration:
+ pass
+
+
+def _check_orig_params_flattened(
+ zero3_module,
+ ignored_params: Set[nn.Parameter],
+) -> None:
+ """
+ Checks that all original parameters have been flattened and hence made
+ invisible to ``named_parameters()`` for the module hierarchy rooted at
+ ``zero3_module``. This should be called as a sanity check after flattening
+ the wrapped module's parameters.
+ """
+ for param_name, param in _named_parameters_with_duplicates(zero3_module):
+ if param not in ignored_params and not _is_zero3_flattened(param):
+ raise RuntimeError(
+ f"Found an unflattened parameter: {param_name}; "
+ f"{param.size()} {param.__class__}"
+ )
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_limiter.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_limiter.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a12a4f50411049b528ca37fce0cfbf00304ab3
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_limiter.py
@@ -0,0 +1,25 @@
+import collections
+
+
+class _FreeEventQueue:
+
+ def __init__(self, num_inflights: int = 3) -> None:
+ self._queue = collections.deque()
+ self._max_num_inflight_all_gathers = num_inflights
+
+ def enqueue(self, free_event) -> None:
+ """Enqueues a free event."""
+ self._queue.append(free_event)
+
+ def dequeue_if_needed(self):
+ """Dequeues a single event if the limit is reached."""
+ if len(self._queue) >= self._max_num_inflight_all_gathers:
+ return self._dequeue()
+ return None
+
+ def _dequeue(self):
+ """Dequeues a free event if possible."""
+ if self._queue:
+ event = self._queue.popleft()
+ return event
+ return None
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_traversal_utils.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_traversal_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c2b333eeef2c04913fd6432a85635d44a77dd63
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_traversal_utils.py
@@ -0,0 +1,107 @@
+"""
+NOTE: This file must be imported like
+``import torch.distributed.fsdp._traversal_utils`` and not like
+``from torch.distirbuted.fsdp._traversal_utils import ...`` to avoid circular
+imports. For brevity, we may import the file as ``traversal_utils``.
+"""
+
+import collections
+from typing import Deque, List, Set, Tuple, TYPE_CHECKING
+
+import torch.nn as nn
+from torch.distributed._composable.contract import _get_registry
+from mindspeed.core.distributed.layerzero.zero3._common_utils import _get_module_zero3_state
+
+if TYPE_CHECKING:
+ from mindspeed.core.distributed.layerzero.zero3._common_utils import _ZeRO3State
+
+"""
+[Note: ZeRO3 State Traversal]
+For the wrapper code path, ``_ZeRO3PState`` is the ``ZeRO3``
+module wrapping a fully sharded module, and for the non-wrapper code path,
+``_ZeRO3PState`` is an object that gets embedded on a fully sharded module.
+
+There are three common traversal idioms: Given a root module,
+- ``_get_zero3_states()`` returns all ``_ZeRO3PState`` s in the tree.
+- ``get_zero3_root_states()`` returns all local root ``_ZeRO3PState`` s in the
+tree (i.e. those with ``_is_root == True``).
+- ``_get_zero3_handles()``returns all ``FlatParamHandle`` s in the tree.
+
+All of these methods must take in the root module (i.e. an ``nn.Module``) and
+not a general ``_ZeRO3PState`` because ``_ZeRO3PState`` does not support a graph
+traversal, whereas ``nn.Module`` has ``nn.Module.modules()`` for traversal.
+"""
+
+
+def _composable(module: nn.Module) -> bool:
+ """
+ Returns if ``module`` can compose with ``fully_shard``.
+ """
+ return "replicate" not in _get_registry(module)
+
+
+def _get_zero3_states_with_modules(
+ module: nn.Module,
+) -> Tuple[List["_ZeRO3State"], List[nn.Module]]:
+ """
+ Returns a tuple containing:
+ 1. A list of the ``"_ZeRO3State"`` instances in the module tree rooted at
+ ``module`` without any duplicates and following the ``module.modules()``
+ traversal order (which is assumed to be depth-first).
+ 2. A corresponding list of the modules owning the states in the first list.
+
+ For the wrapper code path, both returned lists are the same, each
+ containing all ``FullyShardedDataParallel`` instances. For the composable
+ code path, this returns a list of all composable state instances and a list
+ of the corresponding fully sharded modules. See [Note: Fully Sharded
+ Module].
+
+ NOTE: The traversal does not proceed into any module annotated by an
+ incompatible API (e.g. ``replicate``).
+ """
+ zero3_states: List["_ZeRO3State"] = []
+ zero3_modules: List[nn.Module] = []
+ # Track the visited FSDP states since multiple modules may share the same
+ # one and we want to return a de-duplicated list
+ visited_states: Set["_ZeRO3State"] = set()
+ # Track the visited modules in case of shared modules, which implies the
+ # module graph is no longer a tree
+ visited_modules: Set[nn.Module] = set()
+
+ # Perform depth-first search from `module` to ensure that we do not
+ # traverse into an incompatible API's subtree (use DFS instead of BFS to
+ # match `.modules()` order)
+ deque: Deque[nn.Module] = collections.deque([module])
+ while deque:
+ submodule = deque.popleft()
+ visited_modules.add(submodule)
+ if not _composable(submodule):
+ continue
+ for child_module in reversed(list(submodule.children())):
+ if child_module not in visited_modules:
+ deque.appendleft(child_module)
+ optional_state = _get_module_zero3_state(submodule)
+ if optional_state is not None and optional_state not in visited_states:
+ visited_states.add(optional_state)
+ zero3_states.append(optional_state)
+ zero3_modules.append(submodule)
+ return zero3_states, zero3_modules
+
+
+def _get_zero3_states(module: nn.Module) -> List["_ZeRO3State"]:
+ """See :func:`_get_zero3_states_with_modules`."""
+ zero3_states, _ = _get_zero3_states_with_modules(module)
+ return zero3_states
+
+
+def _get_zero3_handles(module: nn.Module) -> List:
+ """
+ Returns all ``FlatParamHandle`` s in the module tree rooted at ``module``
+ following the rules in :func:`_get_zero3_state`.
+ """
+ handles = [
+ zero3_state._handle
+ for zero3_state in _get_zero3_states(module)
+ if zero3_state._handle is not None
+ ]
+ return handles
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_wrap_utils.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_wrap_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc11cdf9d4bda12cf6fa510bb9ce7e21010a0679
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/_wrap_utils.py
@@ -0,0 +1,129 @@
+import collections
+import functools
+import inspect
+import warnings
+from functools import partial
+from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union
+
+import torch.nn as nn
+from torch.distributed.fsdp._wrap_utils import _override_module_mixed_precision, _validate_frozen_params, _warn_on_overridden_mixed_precision
+from megatron.training.global_vars import get_args
+from megatron.core.tensor_parallel.layers import (
+ ColumnParallelLinear,
+ RowParallelLinear,
+ VocabParallelEmbedding
+)
+from mindspeed.core.distributed.layerzero.zero3._common_utils import (
+ _get_module_zero3_state,
+)
+from mindspeed.core.distributed.layerzero.zero3.wrap import (
+ _construct_wrap_fn,
+ _or_policy,
+ _Policy,
+ _post_order_apply,
+ _recursive_wrap,
+ _run_mixed_precision_override_policy,
+ _run_tensor_parallel_pg_override_policy,
+ _wrap_module_cls_individually,
+)
+
+
+def _auto_wrap(
+ root_module: nn.Module,
+ policy: Union[Callable, _Policy],
+ ignored_modules: Set[nn.Module],
+ ignored_params: Set[nn.Parameter],
+ root_kwargs: Dict[str, Any],
+ fsdp_fn: Callable, # e.g. `FullyShardedDataParallel` or `fully_shard`
+):
+ """
+ Auto wraps modules in ``root_module`` 's tree according to ``policy``
+ following a post-order traversal.
+
+ Precondition: ``root_kwargs`` should contain all arguments except
+ ``module``. This function accepts the kwargs dict directly since it gets
+ forwarded into the post-order traversal function.
+ """
+ mixed_precision = root_kwargs["mixed_precision"]
+ is_wrapper = inspect.isclass(fsdp_fn)
+ _check_nested_wrapping(root_module)
+
+ if isinstance(policy, _Policy):
+ root_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
+ target_module_to_kwargs = policy._run_policy(
+ root_module, ignored_modules, root_kwargs
+ )
+ if mixed_precision is not None:
+ target_module_to_kwargs = _run_mixed_precision_override_policy(
+ root_module,
+ mixed_precision._module_classes_to_ignore,
+ ignored_modules,
+ root_kwargs,
+ target_module_to_kwargs,
+ )
+ overridden_module_classes = _override_module_mixed_precision(
+ root_module, mixed_precision._module_classes_to_ignore
+ )
+ _warn_on_overridden_mixed_precision(overridden_module_classes)
+ try:
+ args = get_args()
+ if args.tensor_model_parallel_size > 1:
+ _run_tensor_parallel_pg_override_policy(
+ root_module,
+ {ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding},
+ ignored_modules,
+ root_kwargs,
+ target_module_to_kwargs,
+ )
+ except AssertionError:
+ warnings.warn(
+ "Global args is not correctly initialized, skip TP wrapping...")
+
+ _validate_frozen_params(
+ root_module,
+ set(target_module_to_kwargs.keys()),
+ ignored_params,
+ True,
+ )
+ wrap_fn = _construct_wrap_fn(
+ root_module, target_module_to_kwargs, fsdp_fn)
+ _post_order_apply(root_module, wrap_fn)
+ return
+
+ recursive_wrap_kwargs = {
+ "module": root_module,
+ "auto_wrap_policy": policy,
+ "wrapper_cls": fsdp_fn,
+ "ignored_modules": ignored_modules,
+ "ignored_params": ignored_params,
+ "only_wrap_children": True,
+ }
+ if mixed_precision is not None:
+ # Wrap modules of the ignored types separately and register forward
+ # hooks to cast to fp32 and back to the original dtype, respectively
+ overridden_module_classes = _override_module_mixed_precision(
+ root_module, mixed_precision._module_classes_to_ignore
+ )
+ policy = functools.partial(
+ _or_policy,
+ policies=[
+ policy,
+ partial(
+ _wrap_module_cls_individually,
+ module_classes=mixed_precision._module_classes_to_ignore,
+ ),
+ ],
+ )
+ recursive_wrap_kwargs["auto_wrap_policy"] = policy
+ _warn_on_overridden_mixed_precision(overridden_module_classes)
+ # type: ignore[arg-type]
+ _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)
+
+
+def _check_nested_wrapping(root_module: nn.Module):
+ for module_name, module in root_module.named_modules():
+ if _get_module_zero3_state(module) is not None:
+ raise ValueError(
+ "FSDP auto wrapping requires modules to not already have "
+ f"FSDP applied but found {module_name} in\n{root_module}"
+ )
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/api.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..18a7630e7614c2ec75dc1eba643a936f6c119e3c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/api.py
@@ -0,0 +1,34 @@
+"""
+This file includes public APIs for FSDP such as the classes used for the
+constructor arguments.
+"""
+
+from dataclasses import dataclass
+from enum import auto, Enum
+from typing import Optional, Sequence, Type
+
+import torch
+from torch.nn.modules.batchnorm import _BatchNorm
+
+__all__ = [
+ "BackwardPrefetch",
+ "MixedPrecision",
+]
+
+
+class BackwardPrefetch(Enum):
+ BACKWARD_PRE = auto()
+ BACKWARD_POST = auto()
+
+
+class BackwardReduceScatter(Enum):
+ BACKWARD_PRE = auto()
+ BACKWARD_POST = auto()
+
+
+@dataclass
+class MixedPrecision:
+ param_dtype: Optional[torch.dtype] = None
+ reduce_dtype: Optional[torch.dtype] = None
+ buffer_dtype: Optional[torch.dtype] = None
+ _module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,)
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/flat_param.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/flat_param.py
new file mode 100644
index 0000000000000000000000000000000000000000..743497c912cafa571f6ce4c2e214ed822b2de366
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/flat_param.py
@@ -0,0 +1,1938 @@
+import contextlib
+import functools
+import logging
+import os
+import warnings
+from itertools import accumulate, chain
+from typing import (
+ Any,
+ Callable,
+ cast,
+ Dict,
+ Generator,
+ Iterator,
+ List,
+ NamedTuple,
+ no_type_check,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Union,
+)
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from torch.distributed.utils import _alloc_storage, _free_storage, _p_assert
+from torch.nn.parameter import _ParameterMeta # type: ignore[attr-defined]
+from mindspeed.core.distributed.layerzero.zero3._common_utils import (
+ _DeviceHandle,
+ _named_parameters_with_duplicates,
+ _no_dispatch_record_stream,
+ _set_zero3_flattened,
+ HandleTrainingState,
+)
+
+__all__ = [
+ "FlatParameter",
+ "FlatParamHandle",
+ "FlatParamShardMetadata",
+ "ParamInfo",
+ "SharedParamInfo",
+ "HandleShardingStrategy",
+]
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+
+"""
+[Note: Fully Sharded Module]
+We define the "fully sharded module" to be the original ``nn.Module`` that owns
+a ``FlatParamHandle``. It is the *single* module logically responsible for the
+*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given
+forward or backward pass. The fully sharded module should be passed to the
+``FlatParamHandle`` constructor.
+
+For the wrapper code path:
+- The ``FullyShardedDataParallel`` module wrapping the fully sharded module
+runs the unshard/reshard on behalf of the ful+ly sharded module by overriding
+``nn.Module.forward``.
+- The fully sharded module is exactly the module passed to the
+``FullyShardedDataParallel`` constructor's ``module`` argument.
+
+For the non-wrapper code path:
+- Hooks registered on the fully sharded module run the unshard/reshard.
+- The fully sharded module may either be the direct argument to ``fully_shard``
+or a submodule chosen by the provided wrapping policy.
+"""
+
+# We should use 'safe' by default since it respects method overrides, but for
+# special cases such as for high CPU overhead or for intentionally bypassing
+# checks in the overrides, we may use 'unsafe'.
+_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR"
+
+# Some value to set padding in tensors to for debuggability
+_FLAT_PARAM_PADDING_VALUE = 42
+
+
+class ParamInfo(NamedTuple):
+ """Information for an original parameter."""
+
+ param_name: str # unprefixed
+ module: nn.Module
+ module_name: str
+
+
+class SharedParamInfo(NamedTuple):
+ """
+ Additional information for a shared parameter.
+
+ For each shared parameter, we designate one module and its parameter
+ variable to be the primary owner, determined as the first one encountered
+ in the parameter walk. These are prefixed with "prim". The primary module
+ and parameter do not have their own :class:`SharedParamInfo` instance.
+ """
+
+ param_name: str # unprefixed
+ module: nn.Module
+ module_name: str
+ prim_param_name: str # unprefixed
+ prim_module: nn.Module
+ prim_module_name: str
+
+
+class _ShardParamInfo(NamedTuple):
+ """Shard-related information for an original parameter."""
+
+ in_shard: bool
+ # Use to index into the sharded flat parameter, e.g.
+ # `flat_param[offset_in_shard : offset_in_shard + numel_in_shard]`
+ offset_in_shard: Optional[int]
+ numel_in_shard: Optional[int]
+ # Use to get part of the parameter in the local shard from a flattened
+ # version of the unsharded parameter, e.g.
+ # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]`
+ intra_param_start_idx: Optional[int]
+ intra_param_end_idx: Optional[int] # inclusive
+ # `unshard_data [flat_param_start_idx : flat_param_end_idx]`
+ flat_param_start_idx: Optional[int] = None
+ flat_param_end_idx: Optional[int] = None # inclusive
+
+
+class FlatParamShardMetadata(NamedTuple):
+ """
+ This holds metadata specific to this rank's shard of the flat parameter.
+
+ Attributes:
+ param_names (Tuple[str, ...]): Prefixed parameter names of this rank's
+ shard of the parameters; see :class:`FlatParameter`.
+ param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's
+ shard of the parameters; see :class:`FlatParameter`.
+ param_numels (Tuple[int, ...]): Parameter numels of this rank's shard
+ of the parameters; see :class:`FlatParameter`.
+ param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in
+ units of numels) giving this rank's part of each flattened
+ original parameter.
+ """
+
+ param_names: Tuple[str, ...]
+ param_shapes: Tuple[torch.Size, ...]
+ param_numels: Tuple[int, ...]
+ param_offsets: Tuple[Tuple[int, int], ...]
+
+
+class _FlatParameterMeta(_ParameterMeta):
+ # Make `isinstance(t, FlatParameter)` return True for custom tensor
+ # instances that have the _is_flat_param flag for BC
+ def __instancecheck__(self, instance):
+ # NB: do NOT test the super implementation
+ return isinstance(instance, torch.Tensor) and getattr(
+ instance, "_is_flat_param", False
+ )
+
+
+class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
+ _unpadded_unsharded_size: torch.Size
+ _padded_unsharded_size: torch.Size
+ _sharded_size: torch.Size
+ _num_params: int
+ _param_infos: Tuple[ParamInfo, ...]
+ _shapes: Tuple[torch.Size, ...]
+ _fqns: Tuple[str, ...]
+ _numels_with_padding: Tuple[int, ...]
+ _numels: Tuple[int, ...]
+ _shard_param_infos: Tuple[_ShardParamInfo, ...]
+ _shared_param_infos: Tuple[SharedParamInfo, ...]
+ _modules: Set[nn.Module]
+ _shard_numel_padded: int
+ _zero1_shard: Tensor
+ _zero3_shard: Tensor
+ _full_param_padded: Tensor
+ _full_grad_padded: Tensor
+ _full_prec_grad_padded: Tensor
+ _post_backward_hook_state: Tuple[Any, Any]
+ _saved_grad: Tensor
+ _params: Optional[List[nn.Parameter]]
+ _shared_params: Optional[List[nn.Parameter]]
+ _tensors: Optional[List[Optional[Tensor]]]
+ _is_grad_none_mask: Optional[List[bool]]
+ _is_padding_mask: List[bool]
+ _cpu_grad: Tensor = None
+
+ def __new__(cls, data=None, requires_grad=True):
+ if cls is not FlatParameter:
+ raise ValueError("subclasses FlatParameter not supported")
+ r = nn.Parameter.__new__(nn.Parameter, data, requires_grad) # type: ignore[call-arg]
+ r._is_flat_param = True # type: ignore[attr-defined]
+ return r
+
+ # NB: This is not a regular method, because FlatParameters are not actually
+ # instances of this class (see __new__ above). So you must indirectly
+ # call this directly through the classmethod.
+ @classmethod
+ def _init_metadata(
+ cls,
+ self,
+ param_infos: List[ParamInfo],
+ numels: List[int],
+ shapes: List[torch.Size],
+ fqns: List[str],
+ shared_param_infos: List[SharedParamInfo],
+ params: Optional[List[nn.Parameter]],
+ shared_params: Optional[List[nn.Parameter]],
+ is_padding_mask: List[bool],
+ ) -> None:
+ """
+ Initializes attributes holding metadata about the original parameters
+ comprising the flat parameter.
+
+ We expose this method separate from the constructor to keep the
+ constructor only responsible for the flat parameter's tensor data. This
+ method should only be called once per model, while the constructor may
+ be called multiple times, e.g. when reloading from a checkpoint, in
+ which case only the tensor data needs to be passed to the constructor.
+
+ Args:
+ See the Attributes in the class docstring.
+ """
+ if len(param_infos) != len(shapes) or len(param_infos) != len(fqns):
+ raise ValueError("Incorrect number of param_infos")
+
+ self._num_params = len(param_infos)
+ self._param_infos = param_infos
+ self._shapes = shapes
+ self._fqns = fqns
+ self._is_padding_mask = is_padding_mask
+
+ numels_without_padding: List[int] = []
+ for numel, is_padding in zip(numels, is_padding_mask):
+ if not is_padding:
+ numels_without_padding.append(numel)
+ self._numels = tuple(numels_without_padding)
+ self._numels_with_padding = tuple(numels)
+ if len(self._numels) != self._num_params:
+ raise AssertionError("self._numels do not match num_param")
+
+ self._shared_param_infos = tuple(shared_param_infos)
+ self._modules = {pi.module for pi in self._param_infos}.union(
+ {spi.module for spi in self._shared_param_infos}
+ )
+ if (params is None) != (shared_params is None):
+ raise AssertionError("Param and Shared_param should be both None or non-None")
+ if params is not None:
+ if len(shared_params) != len(shared_param_infos):
+ raise AssertionError("shared_params do not match shared_param_infos")
+ self._params = []
+ for param, is_padding in zip(params, is_padding_mask):
+ if not is_padding:
+ self._params.append(param)
+ self._shared_params = shared_params
+ # Mark the original parameters to avoid flattening them into
+ # another `FlatParameter` during recursive construction
+ for param in chain(self._params, self._shared_params):
+ _set_zero3_flattened(param)
+ self._is_grad_none_mask = [False for _ in range(self._num_params)]
+ self._tensors = [None for _ in range(self._num_params)]
+ else:
+ self._params = None
+ self._shared_params = None
+ self._is_grad_none_mask = None
+ self._tensors = None
+ self._unpadded_unsharded_size = self.size()
+ _set_zero3_flattened(self)
+ # Tracks whether the `FlatParameter`'s post-backward hook has been
+ # called to modify the behavior of the post-backward callback
+ self._post_backward_called = False
+
+
+class FlatParamHandle:
+ ##################
+ # INITIALIZATION #
+ ##################
+ def __init__(
+ self,
+ params: Sequence[Union[nn.Parameter, Tensor]],
+ zero3_module: nn.Module,
+ device: torch.device,
+ mp_param_dtype: Optional[torch.dtype],
+ mp_reduce_dtype: Optional[torch.dtype],
+ zero3_process_group: dist.ProcessGroup,
+ zero1_process_group: dist.ProcessGroup,
+ offload_grads: bool = False
+ ):
+ self.initialize(params,
+ zero3_module,
+ device=device,
+ mp_param_dtype=mp_param_dtype,
+ mp_reduce_dtype=mp_reduce_dtype,
+ zero3_process_group=zero3_process_group,
+ zero1_process_group=zero1_process_group,
+ offload_grads=offload_grads
+ )
+ self._init_flat_param_and_metadata(
+ params, zero3_module, self._aligned_numel, self.zero1_world_size # type: ignore[arg-type]
+ )
+ self._use_unsharded_views(as_params=False)
+
+
+ def initialize(
+ self,
+ params: Sequence[Union[nn.Parameter, Tensor]],
+ zero3_module: nn.Module,
+ device: torch.device,
+ mp_param_dtype: Optional[torch.dtype],
+ mp_reduce_dtype: Optional[torch.dtype],
+ zero3_process_group: dist.ProcessGroup,
+ zero1_process_group: dist.ProcessGroup,
+ offload_grads: bool = False
+ ):
+ params = list(params)
+ if len(params) == 0:
+ raise ValueError(
+ f"Cannot construct a {self.__class__.__name__} with an empty parameter list"
+ )
+ self._init_setattr_fns()
+ align_addresses = True
+ self._init_get_unflat_views_fn(align_addresses)
+ self.device = device
+ self._device_handle = _DeviceHandle.from_device(self.device)
+ self.zero3_process_group = zero3_process_group
+ self.zero1_process_group = zero1_process_group
+ self.zero1_world_size = zero1_process_group.size()
+ self.zero1_group_rank = zero1_process_group.rank()
+ self.zero3_group_rank = zero3_process_group.rank()
+ self.zero3_group_size = zero3_process_group.size()
+ self._training_state = HandleTrainingState.IDLE
+ self._debug_level = dist.get_debug_level()
+ self._zero3_module = zero3_module
+ # For strategies that do not free after forward, we skip using sharded
+ # views after forward since the unsharded data exists. We still switch
+ # `self.flat_param` to point to the sharded flat parameter since what
+ # it points to parameterizes behavior. We use the following attribute
+ # to track which tensor data the parameters are unsharded views into.
+ self._unsharded_flat_param_for_skipped_views: Optional[Tensor] = None
+ # The index in the state's `all_handles`, which must be the
+ # same across ranks for the execution order validation to work
+ self._handle_index: Optional[int] = None
+ # Index in handles_to_pre_forward_order
+ self._pre_forward_order_index: Optional[int] = None
+ # Index in `handles_post_forward_order`
+ self._post_forward_index: Optional[int] = None
+ # Used for guarding against mistargeted forward prefetches
+ self._needs_pre_forward_unshard = False
+ # Used for guarding against mistargeted backward prefetches
+ self._needs_pre_backward_unshard = False
+ # Was the handle prefetched? Set on successful _prefetch_handle and unshard
+ self._prefetched = False
+ self._ran_pre_backward_hook = False
+ self._ran_post_backward_hook = False
+ #!==================== add support for zero1 param & grad sync state=========================
+ self._needs_param_sync = True
+ self._param_synced = False
+ self._grad_synced = False
+ self.enter_backward = False
+ #!===================================================================================
+ self._offload_grads = offload_grads
+ self.prev_iter_synced = True
+ # Optimistically assume a valid input `params` and set dtype attributes
+ # before `_init_flat_param()`, which performs the actual validation
+ self._orig_param_dtype = params[0].dtype
+ self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
+ self._aligned_numel = (
+ _get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype)
+ if align_addresses
+ else 0
+ )
+ if self.zero1_world_size % self.zero3_group_size != 0:
+ raise ValueError(f"The dp {self.zero1_world_size=} is not multiply of {self.zero3_group_size=}")
+
+ @property
+ def full_prec_dtype(self):
+ return torch.float32
+
+ @property
+ def param_dtype(self):
+ return self._fwd_bwd_param_dtype
+
+ @property
+ def grad_dtype(self):
+ return self._reduce_dtype
+
+ def _init_setattr_fns(self):
+ use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1"
+ self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]
+ self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None]
+ if use_unsafe_setattr:
+ self._setattr_tensor = _unsafe_setattr_tensor
+ self._setattr_param = _unsafe_setattr_param
+ else:
+ self._setattr_tensor = _safe_setattr_tensor_or_param
+ self._setattr_param = _safe_setattr_tensor_or_param
+
+ def _init_get_unflat_views_fn(self, align_addresses: bool):
+ self._get_unflat_views = (
+ self._get_unflat_views_aligned
+ if align_addresses
+ else self._get_unflat_views_unaligned
+ )
+
+ def _init_flat_param_and_metadata(
+ self,
+ params: List[Union[Tensor, nn.Parameter]],
+ module: nn.Module,
+ aligned_numel: int,
+ div: int
+ ) -> None:
+ """
+ NOTE: This should only be called once at construction time, after which
+ the ``FlatParameter`` metadata is assumed to be static.
+
+ NOTE: The elements of ``params`` should only be ``Tensor`` s when
+ composing with ``DTensor`` -based tensor parallelism, in which case the
+ elements may be ``DTensor`` local shards.
+ """
+ if len(params) == 0:
+ raise ValueError("Expects non-empty `params`")
+ if aligned_numel < 0:
+ raise ValueError(
+ f"Expects non-negative `aligned_numel` but got {aligned_numel}"
+ )
+ (
+ dtype,
+ flat_param_requires_grad,
+ device,
+ ) = self._validate_tensors_to_flatten(params)
+ params_set = set(params)
+ # For alignment padding, only `numels` gets strictly non-`None`
+ # elements, and all other lists get `None` elements for padding.
+ param_infos: List[ParamInfo] = []
+ numels: List[int] = []
+ shapes: List[torch.Size] = []
+ fqns: List[str] = []
+ shared_param_infos: List[SharedParamInfo] = []
+ shared_param_memo: Dict[
+ Union[Tensor, nn.Parameter], Tuple[nn.Module, str, str]
+ ] = {}
+ params_to_flatten: List[Union[Tensor, nn.Parameter]] = []
+ shared_params: List[Union[Tensor, nn.Parameter]] = []
+ is_padding_mask: List[bool] = []
+ total_numel = total_numel_without_padding = 0
+ for submodule_name, submodule in module.named_modules(remove_duplicate=False):
+ for param_name, param in _named_parameters_with_duplicates(
+ submodule, recurse=False
+ ):
+ if param not in params_set:
+ continue
+ if param in shared_param_memo: # shared reference
+ prim_module, prim_module_name, prim_param_name = shared_param_memo[
+ param
+ ]
+ shared_params.append(param)
+ shared_param_infos.append(
+ SharedParamInfo(
+ param_name,
+ submodule,
+ submodule_name,
+ prim_param_name,
+ prim_module,
+ prim_module_name,
+ )
+ )
+ else:
+ if aligned_numel > 0:
+ numel_to_pad = aligned_numel - (total_numel % aligned_numel)
+ if numel_to_pad > 0 and numel_to_pad < aligned_numel:
+ padding_tensor = _construct_padding_tensor(
+ numel_to_pad, dtype, False, device
+ )
+ params_to_flatten.append(padding_tensor)
+ is_padding_mask.append(True)
+ numels.append(numel_to_pad)
+ total_numel += numel_to_pad
+ param = cast(nn.Parameter, param)
+ shared_param_memo[param] = (submodule, submodule_name, param_name)
+ params_to_flatten.append(param)
+ is_padding_mask.append(False)
+ param_infos.append(ParamInfo(param_name, submodule, submodule_name))
+ numels.append(param.numel())
+ shapes.append(param.shape)
+ fqn = (
+ submodule_name + "." + param_name
+ if submodule_name
+ else param_name
+ )
+ fqns.append(fqn)
+ total_numel += param.numel()
+ total_numel_without_padding += param.numel()
+ if len(params_to_flatten) == 0:
+ raise ValueError(
+ f"`params` were not found in `module`'s tree"
+ f"params: {params}\nmodule: {module}"
+ )
+ if (
+ self.zero1_group_rank == 0
+ and aligned_numel > 0
+ and total_numel != total_numel_without_padding
+ ):
+ logger.info(
+ "ZeRo3 FlatParameter address alignment created "
+ "%s numel of padding (%s vs. %s)",
+ total_numel - total_numel_without_padding,
+ total_numel,
+ total_numel_without_padding,
+ )
+ # if aligned_numel > 0:
+ # Pad to be divisible by world size to avoid a copy for the
+ # post-backward reduce-scatter
+ numel_to_pad = div - (total_numel % div)
+ if numel_to_pad > 0 and numel_to_pad < div:
+ if self.zero1_group_rank == 0:
+ logger.info(
+ "ZeRO3 FlatParameter world size divisibility created "
+ "%s numel of padding",
+ numel_to_pad,
+ )
+ padding_tensor = _construct_padding_tensor(
+ numel_to_pad, dtype, False, device
+ )
+ params_to_flatten.append(padding_tensor)
+ is_padding_mask.append(True)
+ numels.append(numel_to_pad)
+ total_numel += numel_to_pad
+ # Pass `aligned_numel=0` since we already included padding tensors
+ self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param(
+ params_to_flatten,
+ aligned_numel=0,
+ requires_grad=flat_param_requires_grad,
+ div=div
+ )
+ FlatParameter._init_metadata(
+ self.flat_param,
+ param_infos,
+ numels,
+ shapes,
+ fqns,
+ shared_param_infos,
+ _convert_to_params(params_to_flatten),
+ _convert_to_params(shared_params),
+ is_padding_mask,
+ )
+
+ def _validate_tensors_to_flatten(
+ self, tensors: List[Union[Tensor, nn.Parameter]]
+ ) -> Tuple:
+ """
+ Validates the tensors to flatten and returns any necessary metadata.
+ """
+ dtype: Optional[torch.dtype] = None
+ # Return as the logical OR over each tensor's value
+ flat_param_requires_grad: Optional[bool] = None
+ device: Optional[torch.device] = None
+ for tensor in tensors:
+ if isinstance(tensor, FlatParameter):
+ raise ValueError("Cannot flatten a `FlatParameter`")
+ if dtype is None and not tensor.is_floating_point():
+ raise ValueError("Cannot flatten integer dtype tensors")
+ if dtype is not None and tensor.dtype != dtype:
+ raise ValueError(
+ f"Must flatten tensors with uniform dtype but got {dtype} "
+ f"and {tensor.dtype}"
+ )
+ if device is not None and tensor.device != device:
+ raise ValueError(
+ "Must flatten tensors on the same device but got both "
+ f"{device} and {tensor.device}"
+ )
+ dtype = tensor.dtype
+ flat_param_requires_grad = flat_param_requires_grad or tensor.requires_grad
+ device = tensor.device
+ return dtype, flat_param_requires_grad, device
+
+ def flatten_tensors(
+ self,
+ tensors: List[Tensor],
+ aligned_numel: int,
+ div: int
+ ) -> Tensor:
+ """
+ Flattens ``tensors`` into a single flat tensor optionally including
+ padding if ``aligned_numel`` is greater than 0, where ``aligned_numel``
+ gives the numel required to have address alignment.
+
+ div: The total tensor numel is a multipy of div to avoid different size among rank
+ NOTE: The padding alignment algorithm must be kept in sync with
+ :meth:`_init_flat_param_metadata`. We separate the two methods because
+ the initialization happens once, whereas this method may be called
+ multiple times throughout training (e.g. for checkpointing).
+ """
+ if len(tensors) == 0:
+ raise ValueError("Expects non-empty `tensors`")
+ if aligned_numel < 0:
+ raise ValueError(
+ f"Expects non-negative `aligned_numel` but got {aligned_numel}"
+ )
+ dtype, _, device = self._validate_tensors_to_flatten(tensors)
+ flat_tensors: List[Tensor] = []
+ if aligned_numel > 0:
+ total_numel = 0
+ for tensor in tensors:
+ numel_to_pad = aligned_numel - (total_numel % aligned_numel)
+ if numel_to_pad > 0 and numel_to_pad < aligned_numel:
+ padding_tensor = _construct_padding_tensor(
+ numel_to_pad, dtype, False, device
+ )
+ flat_tensors.append(padding_tensor)
+ total_numel += numel_to_pad
+ flat_tensors.append(torch.flatten(_detach_if_needed(tensor)))
+ total_numel += tensor.numel()
+ numel_to_pad = div - (total_numel % div)
+ if numel_to_pad > 0 and numel_to_pad < div:
+ padding_tensor = _construct_padding_tensor(
+ numel_to_pad, dtype, False, device
+ )
+ flat_tensors.append(padding_tensor)
+ total_numel += numel_to_pad
+ else:
+ flat_tensors = [
+ torch.flatten(_detach_if_needed(tensor)) for tensor in tensors
+ ]
+ return torch.cat(flat_tensors, dim=0)
+
+ def flatten_tensors_into_flat_param(
+ self,
+ tensors: List[Tensor],
+ aligned_numel: int,
+ requires_grad: bool,
+ div: int
+ ) -> FlatParameter:
+ flat_param_data = self.flatten_tensors(tensors, aligned_numel, div)
+ return FlatParameter(flat_param_data, requires_grad=requires_grad)
+
+ def _init_param_reduce_dtypes(
+ self,
+ mp_param_dtype: Optional[torch.dtype],
+ mp_reduce_dtype: Optional[torch.dtype],
+ ) -> None:
+ """
+ Precondition: ``self.flat_param`` is set. This ensures that this
+ handle's parameters have a single dtype.
+
+ Postcondition: This sets ``self._fwd_bwd_param_dtype`` and
+ ``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype``
+ is ``None``, then we assume the original parameter dtype. One special
+ case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype``
+ is ``None``, in which case we assume the gradient reduction dtype
+ matches the forward/backward parameter dtype.
+ """
+ # Save whether these dtypes were specified so that we permit the
+ # parameter dtype to change up until the lazy initialization
+ self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
+ self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
+ if self._fwd_bwd_param_dtype is None or self._reduce_dtype is None:
+ raise ValueError(f"Runtime dtype not set")
+
+ ###################################
+ # SHARD INITIALIZATION & METADATA #
+ ###################################
+ @torch.no_grad()
+ def shard(self):
+ """
+ Shards the handle's ``FlatParameter``. This allocates new memory for
+ the sharded flat parameter and frees the unsharded flat parameter's
+ storage.
+
+ Postcondition: ``self.flat_param`` is the sharded flat parameter. Shard
+ metadata attributes are set for all sharding strategies.
+ """
+ flat_param = self.flat_param
+ _p_assert(
+ flat_param.storage_offset() == 0,
+ "The `FlatParameter` is not the sole occupant of its storage",
+ )
+ orig_storage = flat_param._typed_storage()
+ #! _get_shard returns a clone of original parameter
+ zero1_flat_param, zero1_padded = FlatParamHandle._get_shard(
+ flat_param, self.zero1_group_rank, self.zero1_world_size
+ )
+ zero1_flat_param = zero1_flat_param.to(self.full_prec_dtype)
+ flat_param._zero1_shard = zero1_flat_param
+ flat_param.data = zero1_flat_param # type: ignore[call-overload]
+
+ start_idx = zero1_flat_param.numel() * self.zero1_group_rank
+ end_idx = zero1_flat_param.numel() * (self.zero1_group_rank + 1) - 1 # inclusive
+
+ self._init_shard_metadata(zero1_padded, start_idx, end_idx)
+ if orig_storage._size() > 0:
+ orig_storage._resize_(0)
+ self._use_sharded_views()
+
+ def _init_shard_metadata(
+ self,
+ numel_padded: int,
+ unsharded_start_idx: int,
+ unsharded_end_idx: int,
+ ) -> None:
+ """
+ Initializes shard-related metadata for this rank's shard of the flat
+ parameter: ``_sharded_size``, ``_shard_param_infos``, and
+ ``_shard_numel_padded``.
+
+ Args:
+ numel_padded (int): Numel padded for this rank's sharded flat
+ parameter.
+ unsharded_start_idx (int): Start index in the unsharded flat
+ parameter assigned to this rank.
+ unsharded_end_idx (int): End index (inclusive) in the unsharded
+ flat parameter assigned to this rank.
+
+ Precondition: ``self.flat_param`` 's data is the sharded flat
+ parameter.
+ """
+ flat_param = self.flat_param
+ flat_param._sharded_size = flat_param.size() # type: ignore[attr-defined]
+ sharded_flat_param_numel = flat_param.numel() # includes `numel_padded`
+ _p_assert(
+ unsharded_start_idx >= 0 and unsharded_start_idx <= unsharded_end_idx,
+ f"unsharded_start_idx: {unsharded_start_idx} unsharded_end_idx: {unsharded_end_idx}",
+ )
+ _p_assert(
+ numel_padded <= sharded_flat_param_numel,
+ f"numel_padded: {numel_padded} "
+ f"sharded_flat_param_numel: {sharded_flat_param_numel}",
+ )
+ shard_param_infos = self._get_shard_metadata(
+ unsharded_start_idx, unsharded_end_idx
+ )
+ _p_assert(
+ len(shard_param_infos) == flat_param._num_params,
+ f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
+ )
+ flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined]
+ flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined]
+
+ def _get_shard_metadata(
+ self,
+ unsharded_start_idx: int,
+ unsharded_end_idx: int,
+ ) -> Tuple[_ShardParamInfo, ...]:
+ """
+ Computes the shard metadata based on ``unsharded_start_idx`` and
+ ``unsharded_end_idx`` (inclusive), which give the interval of the
+ unsharded flat parameter specifying the shard.
+ """
+ flat_param_offsets = self._get_flat_param_offsets()
+ _p_assert(len(flat_param_offsets) == len(
+ self.flat_param._numels_with_padding
+ ), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
+ )
+ shard_param_infos: List[_ShardParamInfo] = []
+ sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
+ # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
+ # into the unsharded flat parameter (inclusive) of the given parameter
+ for i, (
+ (unsharded_param_start_idx, unsharded_param_end_idx),
+ is_padding,
+ ) in enumerate(zip(flat_param_offsets, self.flat_param._is_padding_mask)):
+ if is_padding:
+ continue
+ in_sharded_flat_param = (
+ unsharded_start_idx <= unsharded_param_end_idx
+ and unsharded_end_idx >= unsharded_param_start_idx
+ )
+ if not in_sharded_flat_param:
+ shard_param_info = _ShardParamInfo(False, None, None, None, None, unsharded_param_start_idx, unsharded_param_end_idx)
+ else:
+ if unsharded_start_idx <= unsharded_param_start_idx:
+ # This branch can only happen once since the rank's
+ # unsharded start index can only intersect one parameter
+ intra_param_start_idx = 0
+ offset_in_shard = unsharded_param_start_idx - unsharded_start_idx
+ else:
+ intra_param_start_idx = (
+ unsharded_start_idx - unsharded_param_start_idx
+ )
+ offset_in_shard = 0
+ if not (
+ offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel
+ ):
+ raise ValueError(
+ f"Invalid `offset_in_shard` of {offset_in_shard} for "
+ f"sharded flat parameter with {sharded_flat_param_numel} numel"
+ )
+ intra_param_end_idx = (
+ min(unsharded_param_end_idx, unsharded_end_idx)
+ - unsharded_param_start_idx
+ )
+ numel_in_shard = intra_param_end_idx - intra_param_start_idx + 1
+ shard_param_info = _ShardParamInfo(
+ True,
+ offset_in_shard,
+ numel_in_shard,
+ intra_param_start_idx,
+ intra_param_end_idx,
+ unsharded_param_start_idx,
+ unsharded_param_end_idx,
+ )
+ shard_param_infos.append(shard_param_info)
+ return tuple(shard_param_infos)
+
+ @staticmethod
+ def _get_unpadded_shard(
+ tensor: Tensor,
+ rank: int,
+ world_size: int,
+ ) -> Tuple[Tensor, int]:
+ """
+ Returns the shard of ``tensor`` without any padding for the given
+ ``rank`` and ``world_size`` and the numel to pad for that shard.
+
+ If ``tensor`` is already flattened or may be viewed in the flattened
+ shape (which is true in the expected usage), then this method does not
+ allocate any new tensor memory.
+ """
+ if rank >= world_size:
+ raise ValueError(f"Shard rank should be small than shard world size, got {rank} and {world_size}")
+ chunks = torch.flatten(tensor).chunk(world_size)
+ if len(chunks) < (rank + 1):
+ # This rank gets an empty chunk fully padded with zeros since there
+ # are not enough chunks across ranks
+ chunk = chunks[0].new_empty(0)
+ else:
+ chunk = chunks[rank]
+ numel_to_pad = chunks[0].numel() - chunk.numel()
+ return chunk, numel_to_pad
+
+ @staticmethod
+ def _get_shard(
+ tensor: Tensor,
+ rank: int,
+ world_size: int,
+ ) -> Tuple[Tensor, int]:
+ """
+ Returns the shard of ``tensor`` with padding for the given ``rank`` and
+ ``world_size`` and the numel padded for that shard.
+
+ This method allocates new memory (via :meth:`clone`) since the
+ unsharded ``tensor`` may be deallocated after this method returns.
+ """
+ chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(
+ tensor, rank, world_size
+ )
+ shard = chunk.clone()
+ if numel_to_pad > 0:
+ shard = F.pad(shard, [0, numel_to_pad])
+ return shard, numel_to_pad
+
+ @staticmethod
+ def _get_shard_from_padded_unshard_tensor(
+ tensor: Tensor,
+ rank: int,
+ world_size: int,
+ ) -> Tuple[Tensor, int]:
+ """
+ Returns the shard of ``tensor`` with padding for the given ``rank`` and
+ ``world_size`` and the numel padded for that shard.
+
+ This method allocates new memory (via :meth:`clone`) since the
+ unsharded ``tensor`` may be deallocated after this method returns.
+ """
+ chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(
+ tensor, rank, world_size
+ )
+ shard = chunk.clone()
+ _p_assert(numel_to_pad == 0, f"The padded unshard flat param should be dividable with {world_size=}")
+ return shard
+
+ def _get_flat_param_offsets(self) -> List[Tuple[int, int]]:
+ """
+ Returns [start, end] offsets of each original parameter's flattened
+ data in the unsharded flat parameter (without padding).
+ NOTE: The returned list includes elements for alignment padding.
+ """
+ cumulative_sum = list(accumulate(self.flat_param._numels_with_padding))
+ starts = [0] + cumulative_sum[:-1]
+ ends = [end - 1 for end in cumulative_sum] # inclusive
+ param_offsets = list(zip(starts, ends))
+ return param_offsets
+
+ @no_type_check
+ @torch.no_grad()
+ def init_flat_param_attributes(self) -> None:
+ """
+ This initializes some attributes on the handle's ``FlatParameter``.
+ This should be called during lazy initialization since it requires the
+ parameter to be on the compute device if not offloading to CPU and we
+ want to give users the chance to move the parameter appropriately after
+ the FSDP constructor.
+
+ For each tensor attribute on the ``FlatParameter``, see the unshard and
+ reshard methods in this class for the allocation and free pattern.
+ """
+ flat_param = self.flat_param
+ self._check_on_compute_device(self.flat_param)
+ # We maintain a padded unsharded tensor that serves as the
+ # all-gather destination and owns the original parameter storages.
+ padded_unsharded_numel = flat_param.numel() * self.zero1_world_size
+ flat_param._full_param_padded = torch.empty(
+ padded_unsharded_numel,
+ device=self.device,
+ dtype=self._fwd_bwd_param_dtype,
+ )
+ flat_param._padded_unsharded_size = flat_param._full_param_padded.size()
+ _free_storage(flat_param._full_param_padded)
+ #! add support for grad saving
+ flat_param._full_grad_padded = torch.empty(
+ padded_unsharded_numel,
+ device=self.device,
+ dtype=self._fwd_bwd_param_dtype,
+ )
+ _free_storage(flat_param._full_grad_padded)
+ #! grad accumulation support
+ flat_param._full_prec_grad_padded = torch.empty(
+ padded_unsharded_numel,
+ device=self.device,
+ dtype=self.full_prec_dtype,
+ )
+ _free_storage(flat_param._full_prec_grad_padded)
+ if self._offload_grads:
+ cpu_device = torch.device("cpu")
+ flat_param._cpu_grad = torch.zeros(
+ padded_unsharded_numel,
+ device=cpu_device,
+ dtype=self.full_prec_dtype,
+ ).pin_memory(device=self.device)
+ ###################
+ # UNSHARD/RESHARD #
+ ###################
+
+ def pre_unshard(self) -> bool:
+ """
+ Returns: ``False`` if this is a no-op and ``True`` otherwise.
+
+ Postcondition: ``self.flat_param`` 's data is on the device for
+ communication and is what should be all-gathered.
+ """
+ if (
+ self._training_state in [HandleTrainingState.SUMMON_FULL_PARAMS, HandleTrainingState.SYNC_PARAMS]
+ and self._skipped_use_sharded_views
+ ):
+ self._use_sharded_views()
+ self._check_on_compute_device(self.flat_param)
+ if self.needs_unshard():
+ self._alloc_padded_unsharded_flat_tensor()
+
+ def unshard(self):
+ padded_unsharded_flat_param = self._get_padded_unsharded_flat_tensor(param=True, free=False)
+ padded_unsharded_flat_param = self._all_gather_flat_param(padded_unsharded_flat_param)
+ self._use_unpadded_unsharded_flat_param(padded_unsharded_flat_param)
+
+ def needs_unshard(self) -> bool:
+ """Returns if the handle's flat parameter needs to be unsharded."""
+ padded_unsharded_flat_param = self._get_padded_unsharded_flat_tensor(free=False)
+ already_unsharded = (
+ padded_unsharded_flat_param._typed_storage()._size()
+ == padded_unsharded_flat_param.numel()
+ )
+ return not already_unsharded
+
+ def _alloc_padded_unsharded_flat_tensor(self, param: bool = True):
+ flat_param = self.flat_param
+ unsharded_flat_tensor = self._get_padded_unsharded_flat_tensor(param)
+ self._check_storage_freed(unsharded_flat_tensor)
+ _alloc_storage(unsharded_flat_tensor,
+ flat_param._padded_unsharded_size)
+ return unsharded_flat_tensor
+
+ def _get_padded_unsharded_flat_tensor(self, param: bool = True, free: bool = True) -> torch.Tensor:
+ """
+ Returns a reference to the padded unsharded flat parameter depending on
+ the calling context. This should only be called if using a sharded
+ strategy.
+ """
+ flat_param = self.flat_param
+ if param:
+ padded_unsharded_flat_tensor = flat_param._full_param_padded
+ dtype = self._fwd_bwd_param_dtype
+ else:
+ padded_unsharded_flat_tensor = flat_param._full_grad_padded
+ dtype = self._fwd_bwd_param_dtype
+ _p_assert(
+ padded_unsharded_flat_tensor.dtype == dtype,
+ f"Expects same precision but got {padded_unsharded_flat_tensor.dtype} vs {dtype}",
+ )
+
+ if free and padded_unsharded_flat_tensor.untyped_storage().size() > 0:
+ _free_storage(padded_unsharded_flat_tensor)
+ return padded_unsharded_flat_tensor
+
+ def _all_gather_flat_param(
+ self,
+ padded_unsharded_flat_param: Tensor,
+ ) -> Tensor:
+ """
+ All-gathers the handle's flat parameter to the destination
+ ``padded_unsharded_flat_param``, and switches to using the all-gathered
+ tensor.
+ """
+ _p_assert(
+ hasattr(self, "zero3_process_group") and hasattr(self, "zero3_group_size"),
+ "Expects a process group and world size to have been set via `shard()`",
+ )
+ #! cast zero1 param to zero3 param
+ #! be careful of recompute
+ if self._needs_param_sync and not self._param_synced:
+ sharded_flat_param = self.flat_param._zero1_shard.to(self._fwd_bwd_param_dtype)
+ expected_numel = sharded_flat_param.numel() * self.zero1_world_size
+ process_group = self.zero1_process_group
+ source = "zero1 shard"
+ else:
+ sharded_flat_param = self.flat_param._zero3_shard.to(self._fwd_bwd_param_dtype)
+ expected_numel = sharded_flat_param.numel() * self.zero3_group_size
+ process_group = self.zero3_process_group
+ source = "zero3 shard"
+
+ _p_assert(
+ padded_unsharded_flat_param.numel() == expected_numel,
+ f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}")
+ log0(f"All gather into full parameter from {source} with {process_group.size()=}")
+ dist.all_gather_into_tensor(
+ padded_unsharded_flat_param,
+ sharded_flat_param,
+ process_group,
+ )
+ return padded_unsharded_flat_param
+
+ def _use_unpadded_unsharded_flat_param(
+ self,
+ padded_unsharded_flat_param: torch.Tensor,
+ ) -> None:
+ """
+ Switches to using the *unpadded* unsharded flat parameter, which is a
+ view into the *padded* unsharded flat parameter.
+ """
+ unsharded_size = self.flat_param._unpadded_unsharded_size
+ self.flat_param.data = padded_unsharded_flat_param[:unsharded_size.numel()].view(unsharded_size)
+ # this `.view()` is not autograd visible
+ in_forward = self._training_state == HandleTrainingState.FORWARD
+ in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
+ if in_forward or in_pre_backward:
+ self._use_unsharded_views(as_params=False)
+ else:
+ self._use_unsharded_views(as_params=True)
+
+ def _use_unpadded_unsharded_flat_grad(
+ self,
+ padded_unsharded_flat_grad: torch.Tensor,
+ ) -> None:
+ """
+ Switches to using the *unpadded* unsharded flat parameter, which is a
+ view into the *padded* unsharded flat parameter.
+ """
+ unsharded_size = self.flat_param._unpadded_unsharded_size
+ self.flat_param.grad.data = padded_unsharded_flat_grad[:unsharded_size.numel()].view(unsharded_size)
+ self._use_unsharded_grad_views()
+
+ def post_unshard(self):
+ """
+ Runs the post-unshard logic. This includes freeing the low precision
+ shard if needed.
+ """
+ self._check_on_compute_device(self.flat_param)
+
+ @torch.no_grad()
+ def unshard_grad(self):
+ """
+ Unshard the handle's ``FlatParameter``'s gradient.
+
+ If all ranks have
+ ``None`` gradient, then all original parameters will as well. This
+ method performs an all-reduce and an all-gather. The additional
+ all-reduce is tolerable since this method is not meant to be used on
+ the computation critical path.
+
+ Postcondition: ``_saved_grad_shard`` is defined and contains the value
+ to set ``flat_param.grad`` after gradients are resharded.
+ """
+ flat_param = self.flat_param
+ self._check_unsharded(flat_param)
+
+ # Check if all ranks have a `None` gradient
+ num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device)
+ num_grad_none[0] = flat_param.grad is None
+ dist.all_reduce(num_grad_none, group=self.zero1_process_group)
+ if num_grad_none[0] == self.zero1_world_size:
+ flat_param._saved_grad_shard = None # type: ignore[assignment]
+ self._use_unsharded_grad_views()
+ return
+ if flat_param.grad is None:
+ # In the case that only some ranks have `None` gradient, we use
+ # zeros to approximate as a best effort attempt
+ if self._debug_level == dist.DebugLevel.INFO:
+ warnings.warn(
+ f"[Rank {self.rank}] Only some but not all ranks have a "
+ "`None` `FlatParameter` gradient, so FSDP is using zeros to "
+ "approximate those ranks' sharded gradients being `None`"
+ )
+ flat_param._saved_grad = None # type: ignore[assignment]
+ sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device, dtype=self._fwd_bwd_param_dtype) # type: ignore[attr-defined]
+ #如果该rank上有梯度,保存在flat_param._saved_grad中
+ else:
+ self._check_sharded(flat_param.grad)
+ # flat_param._saved_grad = flat_param.grad # type: ignore[attr-defined]
+ sharded_grad = flat_param.grad.to(self._fwd_bwd_param_dtype) # type: ignore[attr-defined]
+ # 分配内存,全聚合
+ padded_unsharded_grad = torch.zeros(
+ flat_param._padded_unsharded_size, # type: ignore[attr-defined]
+ device=self.device,
+ dtype=self._fwd_bwd_param_dtype,
+ )
+ dist.all_gather_into_tensor(
+ padded_unsharded_grad, sharded_grad, self.zero1_process_group
+ )
+ # 使用非分片的梯度视图
+ unsharded_size = self.flat_param._unpadded_unsharded_size
+ flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view(
+ unsharded_size
+ )
+ self._use_unsharded_grad_views()
+
+ def reshard_grad(self):
+ self.flat_param.grad = self.flat_param._saved_grad # type: ignore[attr-defined]
+ self._use_sharded_grad_views()
+ delattr(self.flat_param, "_saved_grad")
+
+ def offload_grad(self):
+ if not self._offload_grads:
+ warnings.warn(f"Call offload grad when offload grads is False")
+ return
+ cpu_tensor = self.flat_param._cpu_grad
+ gpu_tensor = self.flat_param._full_prec_grad_padded
+ self._check_on_cpu(cpu_tensor)
+ self._check_on_compute_device(gpu_tensor)
+ self._check_padded_unsharded(gpu_tensor)
+ cpu_tensor.untyped_storage().copy_(gpu_tensor.untyped_storage(), non_blocking=True)
+
+ def alloc_full_prec_grad(self):
+ if not self.already_load_full_prec_grad():
+ flat_param = self.flat_param
+ full_prec_grad = flat_param._full_prec_grad_padded
+ self._check_storage_freed(full_prec_grad)
+ _alloc_storage(full_prec_grad, flat_param._padded_unsharded_size)
+ full_prec_grad.zero_()
+ return
+
+ def reload_full_prec_grad(self):
+ if not self._offload_grads:
+ return
+ with torch.no_grad():
+ gpu_tensor = self.flat_param._full_prec_grad_padded
+ self._check_padded_unsharded(gpu_tensor)
+ self._check_on_compute_device(gpu_tensor)
+ cpu_tensor = self.flat_param._cpu_grad
+ self._check_on_cpu(cpu_tensor)
+ gpu_tensor.untyped_storage().copy_(cpu_tensor.untyped_storage(), non_blocking=True)
+
+ def already_load_full_prec_grad(self):
+ gpu_tensor = self.flat_param._full_prec_grad_padded
+ return gpu_tensor.device == self.device and gpu_tensor.untyped_storage().size() > 0
+
+ def free_full_prec_grad(self):
+ full_prec_grad = self.flat_param._full_prec_grad_padded
+ self._check_on_compute_device(full_prec_grad)
+ _free_storage(full_prec_grad)
+
+ def accumulate_grad(self):
+ '''
+ Precondition:
+ runtime_grad: _full_grad_padded finished grad compute
+
+ Postcondition:
+ grad is accumulated to full_prec_grad
+ '''
+ full_prec_grad = self.flat_param._full_prec_grad_padded
+ runtime_grad = self.flat_param._full_grad_padded
+ self._check_padded_unsharded(full_prec_grad)
+ self._check_padded_unsharded(runtime_grad)
+ self._check_on_compute_device(full_prec_grad)
+ self._check_on_compute_device(runtime_grad)
+ full_prec_grad.add_(runtime_grad)
+ return
+
+ def prepare_gradient_for_backward(self):
+ """
+ Prepares the gradient for the backward computation by saving and
+ clearing any existing sharded gradient in ``.grad`` to enable computing
+ a new unsharded gradient.
+
+ #! optimize this logic:
+ 1. if grad is not freed, Then last iter must not synced grad, then we use use_unshard_grad_view to accumulate grad
+
+ 2. if grad is freed, Then last iter must synced grad. alloc memeory for grad.
+ 2.1 alloc memory for grad computation
+ 2.2 set grad views
+
+ PostCondition:
+ flat_param.grad is the padded_unshard_grad
+ return the views of grad in correct position
+ """
+
+ _p_assert(
+ self._training_state
+ in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
+ "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
+ )
+
+ flat_param = self.flat_param
+ if not flat_param.requires_grad:
+ return
+ _p_assert(flat_param._full_grad_padded is not None, f"{self} got a None _full_grad_padded tensor for unshard flat parameters...")
+ self._check_on_compute_device(flat_param)
+ self._check_unsharded(flat_param.data)
+ #! 1. alloc memory if needed
+ padded_unsharded_flat_grad = flat_param._full_grad_padded
+ if self._is_storage_freed(padded_unsharded_flat_grad):
+ #! alloc memory
+ self._alloc_padded_unsharded_flat_tensor(param=False)
+ padded_unsharded_flat_grad.zero_()
+ else:
+ self._check_padded_unsharded(padded_unsharded_flat_grad)
+ #! 2. point grad to the reference tensor set proper view and grad view
+ flat_param.grad = flat_param._full_grad_padded
+ self._use_unpadded_unsharded_flat_grad(padded_unsharded_flat_grad)
+
+ def set_shard_grad(self, shard_grad):
+ flat_param = self.flat_param
+ _p_assert(not self._grad_synced, "A parameter should only sync its grad only once during one grad sync cycle")
+ flat_param._saved_grad = shard_grad
+ self._grad_synced = True
+
+ def free_runtime_unshard_grad(self):
+ self._free_unsharded_flat_tensor(param=False)
+
+ def prepare_gradient_for_zero1(self):
+ """
+ Prepares the gradient for optimizer computation by moving the sharded
+ gradient to the ``.grad`` attribute for the convienience of later reduce op
+ Precondition : saved_grad is the sharded grad
+
+ Postcondition: storage of saved_grad is freed
+
+ Post Condition:
+ ``.grad`` contains only the ``shard grad`` : Note : unshard grad storage free is done after zero1 grad sync
+ the full unsharded grad storage is freed
+ """
+ self._use_sharded_views()
+ self._use_sharded_grad_views()
+ del self.flat_param._saved_grad
+
+ def _get_reduce_scatter_tensors(self):
+ tensor = self.flat_param._full_prec_grad_padded
+ _p_assert(tensor.dtype == self.full_prec_dtype, "full_prec grad is not full prec.")
+ self._check_padded_unsharded(tensor)
+ self._check_on_compute_device(tensor)
+ chunks = tensor.chunk(self.zero1_world_size)
+ new_tensor = torch.empty_like(chunks[0])
+ return tensor, new_tensor
+
+ def _get_reduce_scatter_group(self):
+ return self.zero1_process_group
+
+ def reshard(self, free_unsharded_flat_param: bool):
+ """
+ Runs the reshard logic. This includes freeing the unsharded flat
+ parameter if ``free_unsharded_flat_param`` and switching to using the
+ sharded flat parameter.
+ """
+ if self._needs_param_sync and not self._param_synced:
+ zero3_shard = FlatParamHandle._get_shard_from_padded_unshard_tensor(self.flat_param.data, self.zero3_group_rank, self.zero3_group_size)
+ self.flat_param._zero3_shard = zero3_shard
+ self._param_synced = True
+
+ if free_unsharded_flat_param:
+ self._use_sharded_flat_param()
+ self._free_unsharded_flat_tensor()
+
+
+ def post_reshard(self):
+ """
+ Runs the post-reshard logic.
+ Precondition: ``self.flat_param`` 's data points to the full precision
+ sharded flat parameter.
+ """
+ pass
+
+ def _free_unsharded_flat_tensor(self, param: bool = True):
+ """
+ Frees the padded unsharded flat parameter. The tensor to free depends
+ on the calling context since the unshard may have forced full
+ precision, in which case a different tensor is used.
+ """
+ msg = "Parameter" if param else "Gradient"
+ log0(f"Freeing {msg} memory on handle {self}, {self._pre_forward_order_index=} {self._post_forward_index=}")
+
+ unsharded_flat_tensor = self._get_padded_unsharded_flat_tensor(param)
+ self._check_on_compute_device(unsharded_flat_tensor)
+ # Do not free the memory until all ops in the current stream finish
+ _no_dispatch_record_stream(
+ unsharded_flat_tensor, self._device_handle.current_stream()
+ )
+ _free_storage(unsharded_flat_tensor)
+
+ def _use_sharded_flat_param(self) -> None:
+ """Switches to using the sharded flat parameter."""
+ flat_param = self.flat_param
+ flat_param.data = flat_param._zero1_shard # type: ignore[attr-defined]
+ self._use_sharded_views()
+ #########
+ # VIEWS #
+ #########
+
+ @no_type_check
+ def _get_unflat_views_unaligned(
+ self,
+ tensor: Optional[torch.Tensor] = None,
+ ) -> Iterator[Tensor]:
+ """
+ Returns unflattened ``Tensor`` views into ``tensor`` if it is not
+ ``None`` or ``flat_param`` otherwise, where the unflattening is based
+ on ``flat_param`` 's metadata.
+
+ Examples for ``tensor`` include ``flat_param.grad`` or unsharded
+ tensor optimizer state.
+ """
+ flat_param = self.flat_param
+ if tensor is None:
+ tensor = flat_param
+
+ views = (
+ subtensor.view(shape)
+ for (subtensor, shape) in zip(
+ torch.split(tensor, flat_param._numels, dim=0),
+ flat_param._shapes,
+ )
+ )
+ return views
+
+ @no_type_check
+ def _get_unflat_views_aligned(
+ self,
+ tensor: Optional[Tensor] = None,
+ ) -> List[Tensor]:
+ """
+ This has the same contract as :meth:`_get_unflat_views_unaligned`
+ except it checks for ``None`` placeholders representing padding for
+ alignment, which may incur slightly more CPU overhead.
+ """
+ flat_param = self.flat_param
+ if tensor is None:
+ tensor = flat_param
+ splits: List[Tensor] = torch.split(
+ tensor, flat_param._numels_with_padding, dim=0
+ )
+ idx = 0
+ views: List[Tensor] = []
+ for split, is_padding in zip(splits, flat_param._is_padding_mask):
+ if is_padding:
+ continue
+ views.append(
+ split.view(flat_param._shapes[idx])
+ )
+ idx += 1
+ return views
+
+ @no_type_check
+ @torch.enable_grad()
+ def _use_unsharded_views(self, as_params: bool) -> None:
+ """
+ Unflattens the unsharded flat parameter by setting the original
+ parameter variables to be views into it.
+
+ unsharded unpadded and restore original parameter views
+
+ Args:
+ as_params (bool): If ``True``, then registers the original
+ parameters as ``nn.Parameter`` s; if ``False``, then registers
+ the original parameters only as ``Tensor`` s. ``False`` should
+ be used during forward/backward computation and when hiding the
+ original parameters from :meth:`nn.Module.named_parameters`.
+ """
+ log0(f"Change to unsharded Parameter View on {self._pre_forward_order_index=} {self._post_forward_index=}")
+
+ flat_param = self.flat_param
+ self._check_unsharded(flat_param)
+ views = self._get_unflat_views()
+
+ for i, (view, (param_name, module, _)) in enumerate(
+ zip(views, flat_param._param_infos)
+ ):
+ if as_params:
+ param = self.flat_param._params[i]
+ self._setattr_param(module, param_name, param)
+ param.data = view
+ else: # `as_params=False`
+ param_var: Tensor = view
+ if self.flat_param._tensors[i] is None:
+ # Save the `Tensor` for the pre-backward
+ self.flat_param._tensors[i] = view # save for pre-backward
+ else:
+ # Use the saved `Tensor` variable from the forward to
+ # preserve the autograd graph so that the post-backward
+ # hook fires (e.g. for reentrant AC)
+ tensor = self.flat_param._tensors[i]
+ tensor.data = view
+ param_var = tensor
+ self._setattr_tensor(module, param_name, param_var)
+ if self._training_state == HandleTrainingState.FORWARD:
+ module._parameters[param_name] = param_var
+ for i, (
+ param_name,
+ module,
+ _,
+ prim_param_name,
+ prim_module,
+ _,
+ ) in enumerate(self.flat_param._shared_param_infos):
+ prim_param: Union[Tensor, nn.Parameter] = getattr(
+ prim_module, prim_param_name
+ )
+ _p_assert(
+ not as_params or isinstance(prim_param, nn.Parameter),
+ f"as_params={as_params} type(prim_param)={type(prim_param)}",
+ )
+ if as_params:
+ shared_param = self.flat_param._shared_params[i]
+ self._setattr_param(module, param_name, shared_param)
+ shared_param.data = prim_param
+ else:
+ self._setattr_tensor(module, param_name, prim_param)
+ if self._training_state == HandleTrainingState.FORWARD:
+ module._parameters[param_name] = prim_param
+
+ @no_type_check
+ def _use_unsharded_grad_views(self) -> None:
+ """
+ Unflattens the unsharded flat parameter's gradient by setting the
+ original parameter variables' gradients to be views into it.
+
+ From the unpadded unshard grad to set parameter grad views at corresponing position relative to param
+ SO basically this is a similiar function to use_unsharded_param_views
+ """
+ log0(f"Change to unsharded Gradient View on {self._pre_forward_order_index=} {self._post_forward_index=}")
+
+ if self.flat_param.grad is None:
+ for param in chain(self.flat_param._params, self.flat_param._shared_params):
+ param.grad = None
+ return
+ # Expects the gradient to be in `flat_param.grad`
+ self._check_unsharded(self.flat_param.grad)
+
+ views = self._get_unflat_views(self.flat_param.grad)
+ for i, (view, (param_name, module, _)) in enumerate(
+ zip(views, self.flat_param._param_infos)
+ ):
+ _p_assert(
+ hasattr(module, param_name),
+ f"{self.flat_param._fqns[i]} is missing",
+ )
+ param = getattr(module, param_name)
+ if (
+ param.shape != view.shape
+ or param.dtype != view.dtype
+ or param.device != view.device
+ ):
+ # NOTE: This is a hack using `.data` to side step the check
+ # that parameter/gradient sizes/dtypes/devices match. From
+ # calling `reshard()`, `param` has the sharded size, has the
+ # full precision dtype, and if CPU offloading is enabled, is on
+ # CPU. Thus, one or more of the following cases can hold when
+ # in `no_sync()`, where `view` is the original parameter's
+ # gradient:
+ # 1. `view` can have the unsharded size.
+ # 2. `view` can have the parameter low precision dtype.
+ # 3. `view` can be on GPU.
+ if param.grad is None:
+ param.grad = torch.empty_like(param)
+ param.grad.data = view
+ else:
+ param.grad = view
+ for i, (
+ param_name,
+ module,
+ module_name,
+ prim_param_name,
+ prim_module,
+ _,
+ ) in enumerate(self.flat_param._shared_param_infos):
+ _p_assert(
+ hasattr(module, param_name),
+ f"{module_name + '.' + param_name if module_name else param_name} is missing",
+ ) # did not save FQN info in `_shared_param_infos`
+ param = getattr(module, param_name)
+ prim_param = getattr(prim_module, prim_param_name)
+ if (
+ param.shape != prim_param.grad.shape
+ or param.dtype != prim_param.grad.dtype
+ or param.device != prim_param.grad.device
+ ):
+ # NOTE: This is the same hack to use `.data` to side step the
+ # size check.
+ if param.grad is None:
+ param.grad = torch.empty_like(param)
+ param.grad.data = prim_param.grad
+ else:
+ param.grad = prim_param.grad
+
+ @contextlib.contextmanager
+ def unflatten_as_params(self) -> Generator:
+ """
+ Assumes the flat parameter is unsharded. When in the context,
+ unflattens the original parameters as ``nn.Parameter`` views into the
+ flat parameter, and after the context, restores the original parameters
+ as ``Tensor`` views into the flat parameter.
+ """
+ self._use_unsharded_views(as_params=True)
+ try:
+ yield
+ finally:
+ self._use_unsharded_views(as_params=False)
+
+ @no_type_check
+ @torch.no_grad()
+ def _use_sharded_views(self) -> None:
+ """
+ Sets the original parameter variables' data to be flattened views into
+ the sharded flat parameter.
+
+ The views are kept as flattened to simplify the case where a parameter
+ is sharded across ranks. Parameters whose data is not present in the
+ sharded flat parameter have their data set to a size-0 empty tensor. We
+ do not delete them to ensure to preserve expected behaviors like model
+ printability. Parameters whose data is present must preserve their
+ variables to be passable to an optimizer.
+ """
+ log0(f"Change to sharded Parameter View on {self._pre_forward_order_index=} {self._post_forward_index=}")
+ self._unsharded_flat_param_for_skipped_views = None
+ flat_param = self.flat_param
+ self._check_sharded(flat_param)
+ # Construct once and reuse for all parameters not in the local shard
+ size_0_empty_tensor = torch.empty(
+ 0,
+ dtype=self.flat_param.dtype, # in case `flat_param` changed dtype
+ device=self.flat_param.device,
+ requires_grad=False,
+ )
+ for param, shard_param_info, (param_name, module, _) in zip(
+ flat_param._params,
+ flat_param._shard_param_infos,
+ flat_param._param_infos
+ ):
+ self._setattr_param(module, param_name, param)
+ if not shard_param_info.in_shard:
+ # Allow the original data to be freed via garbage collection
+ param.data = size_0_empty_tensor
+ else:
+ offset = shard_param_info.offset_in_shard
+ numel_in_shard = shard_param_info.numel_in_shard
+ param.data = flat_param[offset : offset + numel_in_shard]
+ for i, (
+ param,
+ (param_name, module, _, prim_param_name, prim_module, _),
+ ) in enumerate(
+ zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
+ ):
+ self._setattr_param(module, param_name, param)
+ prim_param = getattr(prim_module, prim_param_name)
+ param.data = prim_param # could be both empty and non-empty
+ if self._training_state == HandleTrainingState.BACKWARD_POST:
+ # Clear the saved `Tensor`s since they are unneeded now
+ for i in range(len(self.flat_param._tensors)):
+ self.flat_param._tensors[i] = None
+
+ @no_type_check
+ @torch.no_grad()
+ def _use_sharded_grad_views(self) -> None:
+ """
+ Set the original parameter variables' gradients to be flattened views into the sharded flat parameter's gradient.
+
+ This is a no-op if there is no gradient.
+
+ Parameters whose data is not present in the sharded flat parameter and
+ parameters with ``requires_grad=False`` have their gradients set to
+ ``None``. Since the gradient variables do not need to be preserved,
+ this method does not manipulate existing ``Tensor`` data directly and
+ creates new ``Tensor`` variables instead.
+ """
+ log0(f"Change to sharded Gradient View on {self._pre_forward_order_index=} {self._post_forward_index=}")
+
+ flat_param = self.flat_param
+ self._check_sharded(flat_param)
+ grad = self.sharded_grad
+ if grad is None:
+ for param in chain(flat_param._params, flat_param._shared_params):
+ param.grad = None
+ return
+ self._check_sharded(grad)
+ for param, shard_param_info, is_grad_none in zip(
+ flat_param._params,
+ flat_param._shard_param_infos,
+ flat_param._is_grad_none_mask,
+ ):
+ if not shard_param_info.in_shard:
+ param.grad = None
+ else:
+ numel_in_shard = shard_param_info.numel_in_shard
+
+ if param.requires_grad and not is_grad_none:
+ offset = shard_param_info.offset_in_shard
+ if param.dtype != grad.dtype:
+ if param.grad is None:
+ # `.grad` must have the same shape as `param`
+ param.grad = torch.empty_like(param)
+ param.grad.data = grad[
+ offset : offset + numel_in_shard
+ ]
+ else:
+ param.grad = grad[
+ offset : offset + numel_in_shard
+ ]
+
+ else:
+ param.grad = None
+
+ for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate(
+ zip(flat_param._shared_params, flat_param._shared_param_infos)
+ ):
+ in_sharded_flat_param = hasattr(prim_module, prim_param_name)
+ if in_sharded_flat_param and param.requires_grad:
+ prim_param = getattr(prim_module, prim_param_name)
+ param.grad = prim_param.grad
+ else:
+ param.grad = None
+
+ def _reset_flat_param_grad_info_if_needed(self):
+ """
+
+ (1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the
+ original parameters' ``.grad`` are ``None``, and
+ (2) sets ``flat_param.requires_grad=False`` if *none* of the original
+ parameters require gradient.
+ For (1), this is targeting ``optim.zero_grad(set_to_none=True)``, in
+ which case we want to free the gradients as soon after the
+ ``zero_grad()`` call as possible.
+ """
+ flat_param = self.flat_param
+ all_grad_none = True
+ requires_grad = False
+ for param in flat_param._params:
+ all_grad_none &= param.grad is None
+ requires_grad |= param.requires_grad
+ if all_grad_none:
+ flat_param.grad = None
+ # As long as one parameter requires gradient, then the flat parameter
+ # must require gradient
+ flat_param.requires_grad = requires_grad
+
+ def _deregister_orig_params(self):
+ for param_info in self.flat_param._param_infos:
+ param_name, module, _ = param_info
+ if hasattr(module, param_name):
+ delattr(module, param_name)
+ for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos:
+ if hasattr(module, param_name):
+ delattr(module, param_name)
+
+ ###########
+ # HELPERS #
+ ###########
+ def _get_modules(self) -> Set[nn.Module]:
+ """
+ Returns a :class:`set` of the modules whose parameters are included
+ in this handle's flat parameter.
+ """
+ return {pi.module for pi in self.flat_param._param_infos}.union(
+ {spi.module for spi in self.flat_param._shared_param_infos}
+ )
+
+ def is_sharded(self, tensor: Tensor) -> bool:
+ """
+ Returns if ``tensor`` is *currently* sharded. For ``NO_SHARD``, we
+ choose to have this always return ``False`` for clarity.
+ """
+ if (
+ not hasattr(self.flat_param, "_sharded_size")
+ ):
+ # `_sharded_size` is defined iff `handle.shard()` has been called
+ return False
+ sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
+ return tensor.size() == sharded_size
+
+ def param_module_names(self) -> Iterator[Tuple[str, str]]:
+ shared_param_infos = [
+ ParamInfo(param_name, module, module_name)
+ for (
+ param_name,
+ module,
+ module_name,
+ _,
+ _,
+ _,
+ ) in self.flat_param._shared_param_infos
+ ]
+ for param_info in chain(self.flat_param._param_infos, shared_param_infos):
+ param_name, _, module_name = param_info
+ yield (param_name, module_name)
+
+ def shared_param_module_names(self) -> Iterator[Tuple[str, str]]:
+ for param_name, _, module_name in [
+ ParamInfo(param_name, module, module_name)
+ for (
+ param_name,
+ module,
+ module_name,
+ _,
+ _,
+ _,
+ ) in self.flat_param._shared_param_infos
+ ]:
+ yield (param_name, module_name)
+
+ @property
+ def _fqns_in_shard(self) -> List[str]:
+ """Returns the FQNs of the parameters present in this rank's shard."""
+ fqns_in_shard: List[str] = []
+ for fqn, shard_param_info in zip(
+ self.flat_param._fqns, self.flat_param._shard_param_infos
+ ):
+ if shard_param_info.in_shard:
+ fqns_in_shard.append(fqn)
+ return fqns_in_shard
+
+ @property
+ def sharded_grad(self) -> Optional[Tensor]:
+ """Returns the handle's sharded gradient."""
+ flat_param = self.flat_param
+ grad: Optional[Tensor]
+
+ if hasattr(flat_param, "_saved_grad"):
+ # In the post-backward hook, the sharded gradient is still in
+ # `_saved_grad_shard`.
+ grad = flat_param._saved_grad.to(self.full_prec_dtype)
+ else:
+ # If in IDLE or in FORWARD states, then there may be an
+ # (accumulated) gradient. If accessed in IDLE, then this should
+ # be due to re-registering the original parameters (e.g. in state
+ # dict load).
+ _p_assert(
+ flat_param.grad is None
+ or self._training_state
+ in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE),
+ "Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` "
+ "unless in IDLE or FORWARD",
+ )
+ grad = None
+ return grad
+
+ #######################
+ # CHECKS & INVARIANTS #
+ #######################
+ def _check_on_compute_device(self, tensor: Tensor):
+ _p_assert(
+ tensor.device == self.device,
+ f"Expects tensor to be on the compute device {self.device}",
+ )
+
+ def _check_on_cpu(self, tensor: Tensor):
+ _p_assert(
+ tensor.device == torch.device("cpu"),
+ f"Expects tensor to be on CPU but got {tensor.device}",
+ )
+
+ @staticmethod
+ def _check_storage_freed(tensor: Tensor):
+ storage_size: int = tensor._typed_storage()._size()
+ _p_assert(
+ storage_size == 0,
+ f"Expects storage to be freed but got storage with size {storage_size}",
+ )
+
+ @staticmethod
+ def _is_storage_freed(tensor: Tensor) -> bool:
+ return tensor is not None and tensor._typed_storage()._size() == 0
+
+ @staticmethod
+ def _check_storage_allocated(tensor: Tensor):
+ storage_size: int = tensor._typed_storage()._size()
+ _p_assert(storage_size > 0, "Expects storage to be allocated")
+
+ def _check_unsharded(self, tensor: Tensor):
+ msg_prefix = "Expects tensor to be unsharded "
+ _p_assert(tensor is not None, msg_prefix + "but got `None`")
+ unsharded_size = self.flat_param._unpadded_unsharded_size
+ _p_assert(
+ tensor.size() == unsharded_size,
+ msg_prefix + f"with size {unsharded_size} but got {tensor.size()} with storage {tensor.untyped_storage().size()}",
+ )
+
+ def _check_padded_unsharded(self, tensor: Tensor):
+ msg_prefix = "Expects tensor to be unsharded and padded"
+ _p_assert(tensor is not None, msg_prefix + "but got `None`")
+ unsharded_size = self.flat_param._padded_unsharded_size
+ _p_assert(
+ tensor.size() == unsharded_size,
+ msg_prefix + f"with size {unsharded_size} but got {tensor.size()} with storage {tensor.untyped_storage().size()}",
+ )
+
+ def _check_sharded(self, tensor: Tensor):
+ msg_prefix = "Expects tensor to be sharded "
+ _p_assert(tensor is not None, msg_prefix + "but got `None`")
+ sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
+ _p_assert(
+ tensor.size() == sharded_size,
+ msg_prefix + f"with size {sharded_size} but got {tensor.size()} with storage {tensor.untyped_storage().size()}",
+ )
+
+ ##############
+ # PROPERTIES #
+ ##############
+
+ @property
+ def _skipped_use_sharded_views(self) -> bool:
+ return self._unsharded_flat_param_for_skipped_views is not None
+ #================== debug =========================
+
+ def _named_module_parameters(self):
+ #! 获取模型的parameter, 动态重建的参数
+ for i, (param_name, module, module_name) in enumerate(
+ self.flat_param._param_infos
+ ):
+ _p_assert(
+ hasattr(module, param_name),
+ f"{self.flat_param._fqns[i]} is missing",
+ )
+ param = getattr(module, param_name)
+ yield f"{module_name}.{param_name}", param
+
+ def _get_orig_param_by_name(self, total_name):
+ flat_param = self.flat_param
+ for param, (param_name, _, module_name) in zip(
+ flat_param._params, flat_param._param_infos
+ ):
+ if total_name == f"{module_name}.{param_name}":
+ return param
+ return None
+
+ def _get_module_param_by_name(self, total_name):
+ flat_param = self.flat_param
+ for param_name, module, module_name in flat_param._param_infos:
+ if total_name == f"{module_name}{param_name}":
+ return getattr(module, param_name)
+ return None
+
+ def __param_list(self):
+ self._use_unsharded_grad_views()
+ for param in self.flat_param._params:
+ yield param
+ yield param
+
+ yield param
+
+
+ def _shard_grad_list(self):
+ for param in self.flat_param._params:
+ yield param.grad
+
+
+def _unsafe_setattr_param(
+ module: nn.Module, param_name: str, param: nn.Parameter
+) -> None:
+ module._parameters[param_name] = param
+ # This bypasses any overrides in case `module` is an instance of an
+ # `nn.Module` subclass
+ super(nn.Module, module).__setattr__(param_name, param)
+
+
+def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None:
+ module._parameters.pop(param_name, None)
+ # This bypasses any overrides in case `module` is an instance of an
+ # `nn.Module` subclass
+ super(nn.Module, module).__setattr__(param_name, tensor)
+
+
+def _safe_setattr_tensor_or_param(
+ module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter]
+):
+ # Call `delattr()` and `setattr()` to go through `nn.Module` checks
+ if hasattr(module, param_name):
+ delattr(module, param_name)
+ setattr(module, param_name, tensor_or_param)
+
+
+def _convert_to_params(
+ tensors: List[Union[torch.Tensor, nn.Parameter]]
+) -> List[nn.Parameter]:
+ return [t if isinstance(t, nn.Parameter) else nn.Parameter(t, requires_grad=t.requires_grad) for t in tensors]
+
+
+def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor:
+ return (
+ param_or_tensor.detach()
+ if isinstance(param_or_tensor, nn.Parameter)
+ else param_or_tensor
+ )
+
+
+def _get_aligned_numel(unsharded_dtype: torch.dtype):
+ # NOTE: This alignment constraint comes from TorchInductor.
+ ALIGNMENT = 16 # bytes
+ unsharded_dtype_size = _get_dtype_size(unsharded_dtype)
+ aligned_numel = ALIGNMENT // unsharded_dtype_size
+ return aligned_numel
+
+
+@functools.lru_cache(8)
+def _get_dtype_size(dtype):
+ return torch.empty((), dtype=dtype).element_size()
+
+
+def _construct_padding_tensor(
+ padding_numel: int, dtype: torch.dtype, requires_grad: bool, device: torch.device
+):
+ # NOTE: Set the padding value as a magic number for debuggability. The
+ # value itself should never be used in any user-facing computation.
+ return (
+ # torch.ones(
+ torch.zeros(
+ (padding_numel,), dtype=dtype, requires_grad=requires_grad, device=device
+ )
+ )
+
+
+def log0(msg):
+ if dist.get_rank() == 0:
+ logger.info(msg)
+
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/fsdp.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/fsdp.py
new file mode 100644
index 0000000000000000000000000000000000000000..94e8dd9b7aa35654f82bf5ee6c47b39f656a28b1
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/fsdp.py
@@ -0,0 +1,364 @@
+import traceback
+from contextlib import contextmanager
+from typing import (
+ Any,
+ Callable,
+ Generator,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
+
+import torch
+import torch.nn as nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+ _CHECKPOINT_WRAPPED_MODULE,
+ ActivationWrapper,
+)
+from torch.distributed.utils import _p_assert
+from megatron.core import mpu
+import mindspeed.core.distributed.layerzero.zero3._traversal_utils as traversal_utils
+from mindspeed.core.distributed.layerzero.zero3._common_utils import (
+ _ZeRO3State,
+ ZERO3_PREFIX,
+ ZERO3_WRAPPED_MODULE,
+ TrainingState,
+)
+from mindspeed.core.distributed.layerzero.zero3._init_utils import (
+ _init_buffer_state,
+ _init_core_state,
+ _init_device_handle,
+ _init_ignored_module_states,
+ _init_param_handle_from_module,
+ _init_prefetching_state,
+ _init_process_group_state,
+ _init_runtime_state,
+ ProcessGroupType,
+)
+
+from mindspeed.core.distributed.layerzero.zero3._wrap_utils import _auto_wrap
+from mindspeed.core.distributed.layerzero.zero3.api import (
+ BackwardPrefetch,
+ BackwardReduceScatter,
+ MixedPrecision,
+)
+from mindspeed.core.distributed.layerzero.zero3.flat_param import FlatParameter, FlatParamHandle
+from mindspeed.core.distributed.layerzero.zero3.wrap import ModuleWrapPolicy
+from mindspeed.core.distributed.layerzero.runtime._forward import (
+ _post_forward,
+ _post_forward_reshard,
+ _pre_forward,
+ _pre_forward_backward_unshard,
+)
+from mindspeed.core.distributed.layerzero.runtime._root_forward import _zero3_root_pre_forward
+from mindspeed.core.distributed.layerzero.runtime._utils import (
+ _get_zero3_root_states,
+ _is_zero3_root,
+ _cast_forward_outputs,
+)
+from mindspeed.core.distributed.layerzero.runtime._initialize import _lazy_init
+from mindspeed.core.distributed.layerzero import constants
+
+
+__all__ = [
+ "LayerZeRO3",
+]
+FLAT_PARAM = "_flat_param"
+
+
+class LayerZeRO3(nn.Module, _ZeRO3State):
+
+ def __init__(
+ self,
+ module: nn.Module,
+ process_group: ProcessGroupType = None,
+ tp_zero_process_group: ProcessGroupType = None,
+ auto_wrap_policy: Optional[Union[Callable, ModuleWrapPolicy]] = None,
+ backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
+ backward_reduce_scatter: Optional[BackwardReduceScatter] = BackwardReduceScatter.BACKWARD_PRE,
+ mixed_precision: Optional[MixedPrecision] = None,
+ offload_grads: bool = False,
+ ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
+ param_init_fn: Optional[Callable[[nn.Module], None]] = None,
+ device_id: Optional[Union[int, torch.device]] = None,
+ forward_prefetch: bool = True,
+ limit_all_gathers: bool = True,
+ ignored_states: Union[
+ Optional[Iterable[torch.nn.Parameter]
+ ], Optional[Iterable[torch.nn.Module]]
+ ] = None,
+ ):
+ torch._C._log_api_usage_once("layerzero")
+ super().__init__()
+
+ _init_ignored_module_states(
+ self, module, ignored_modules, ignored_states)
+ _init_device_handle(self, module, self._ignored_params, device_id)
+ _init_process_group_state(self, process_group)
+
+ if auto_wrap_policy is not None:
+ root_kwargs = {
+ "process_group": (self.zero3_process_group, self.zero1_process_group),
+ "tp_zero_process_group": tp_zero_process_group,
+ "backward_prefetch": backward_prefetch,
+ "backward_reduce_scatter": backward_reduce_scatter,
+ "mixed_precision": mixed_precision,
+ "offload_grads": offload_grads,
+ "param_init_fn": param_init_fn,
+ "device_id": device_id,
+ "forward_prefetch": forward_prefetch,
+ "limit_all_gathers": limit_all_gathers,
+ "ignored_states": self._ignored_params,
+ }
+ _auto_wrap(
+ module,
+ auto_wrap_policy,
+ self._ignored_modules,
+ self._ignored_params,
+ root_kwargs,
+ LayerZeRO3,
+ )
+
+ backward_prefetch_limit = 1
+ forward_prefetch_limit = 1
+ _init_core_state(
+ self,
+ mixed_precision,
+ limit_all_gathers,
+ backward_prefetch_limit,
+ forward_prefetch_limit,
+ offload_grads,
+ )
+ _init_runtime_state(self)
+
+ _init_prefetching_state(self, backward_prefetch,
+ forward_prefetch, backward_reduce_scatter)
+ _init_buffer_state(self, module)
+ _init_param_handle_from_module(
+ self,
+ module,
+ device_id,
+ param_init_fn,
+ )
+ self._zero3_wrapped_module = module
+
+ @property
+ def module(self) -> nn.Module:
+ """
+ Returns the wrapped module (like :class:`DistributedDataParallel`).
+ """
+ # FSDP's `.module` must refer to the innermost wrapped module when
+ # composing with other module wrappers in order for state dict to work
+ if isinstance(self._zero3_wrapped_module, ActivationWrapper):
+ return getattr(self._zero3_wrapped_module, _CHECKPOINT_WRAPPED_MODULE)
+ return self._zero3_wrapped_module
+
+ @property
+ def _has_params(self) -> bool:
+ """Returns whether this FSDP instance manages any parameters."""
+ return hasattr(self, "_handle") and self._handle is not None
+
+ @property
+ def _flat_param(self) -> Optional[FlatParameter]:
+ return self._handle.flat_param if self._handle else None
+
+ def __getattr__(self, name: str) -> Any:
+ """Forward missing attributes to the wrapped module."""
+ try:
+ return super().__getattr__(name) # defer to nn.Module's logic
+ except AttributeError:
+ return getattr(self._zero3_wrapped_module, name)
+
+ def __getitem__(self, key: int) -> Any:
+ """Forward indexing calls in case the module is an ``nn.Sequential``."""
+ if hasattr(self, ZERO3_WRAPPED_MODULE):
+ # type: ignore[operator]
+ return self._zero3_wrapped_module.__getitem__(key)
+ return super().__getitem__(key)
+
+ def check_is_root(self) -> bool:
+ return _is_zero3_root(self, self)
+
+ @staticmethod
+ def zero3_modules(
+ module: nn.Module,
+ root_only: bool = False,
+ ) -> List["LayerZeRO3"]:
+ """
+ Returns all nested ZeRO3 instances, possibly including ``module`` itself
+ and only including ZeRO3 root modules if ``root_only=True``.
+
+ Args:
+ module (torch.nn.Module): Root module, which may or may not be an
+ ``FSDP`` module.
+ root_only (bool): Whether to return only FSDP root modules.
+ (Default: ``False``)
+
+ Returns:
+ List[FullyShardedDataParallel]: FSDP modules that are nested in
+ the input ``module``.
+ """
+ if root_only:
+ return _get_zero3_root_states(module)
+ return traversal_utils._get_zero3_states(module)
+
+ def _mixed_precision_enabled_for_buffers(self) -> bool:
+ """
+ Returns if the user explicitly enabled buffer mixed precision.
+
+ NOTE: Unlike parameters and gradient reduction, buffer mixed precision
+ is applied at the FSDP instance level, not the ``FlatParameter`` level,
+ which may be different for the composable code path.
+ """
+ return self.mixed_precision.buffer_dtype is not None
+
+ def _reset_lazy_init(self) -> None:
+ """
+ Reset instance so :func:`_lazy_init` will run on the next forward.
+ """
+ self._is_root: Optional[bool] = None
+
+ def forward(self, *args: Any, **kwargs: Any) -> Any:
+ """
+ Runs the forward pass for the wrapped module, inserting FSDP-specific
+ pre- and post-forward sharding logic.
+ """
+ handle = self._handle
+ with torch.autograd.profiler.record_function(
+ "LayerZeRO3.forward"
+ ):
+ args, kwargs = _zero3_root_pre_forward(self, self, args, kwargs)
+ unused = None
+ args, kwargs = _pre_forward(
+ self,
+ handle,
+ _pre_forward_backward_unshard,
+ self._zero3_wrapped_module,
+ args,
+ kwargs,
+ )
+ if handle:
+ _p_assert(
+ handle.flat_param.device == self.compute_device,
+ "Expected `FlatParameter` to be on the compute device "
+ f"{self.compute_device} but got {handle.flat_param.device}",
+ )
+ with torch.autograd.profiler.record_function("Wrapped Module Forward"):
+ output = self._zero3_wrapped_module(*args, **kwargs)
+ output = _post_forward(
+ self, handle, _post_forward_reshard, self, unused, output
+ )
+ if constants.AUTO_CAST_OUTPUT and self._is_root:
+ if mpu.is_initialized():
+ if mpu.is_pipeline_last_stage():
+ output = _cast_forward_outputs(torch.float32, output)
+ else:
+ output = _cast_forward_outputs(torch.float32, output)
+ return output
+
+ def named_buffers(
+ self,
+ *args,
+ **kwargs,
+ ) -> Iterator[Tuple[str, torch.Tensor]]:
+ """
+ Overrides :meth:`named_buffers()` to intercept buffer names and
+ remove all occurrences of the FSDP-specific flattened buffer prefix
+ when inside the :meth:`summon_full_params` context manager.
+ """
+ should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
+ for buffer_name, buffer in super().named_buffers(*args, **kwargs):
+ if should_clean_name:
+ # Remove any instances of the FSDP-specific prefix; there can
+ # be multiple in the case of nested FSDP modules
+ buffer_name = buffer_name.replace(ZERO3_PREFIX, "")
+ yield (buffer_name, buffer)
+
+ def named_modules(
+ self,
+ *args,
+ **kwargs,
+ ) -> Iterator[Tuple[str, torch.Tensor]]:
+ """
+ Overrides :meth:`named_buffers()` to intercept buffer names and
+ remove all occurrences of the FSDP-specific flattened buffer prefix
+ when inside the :meth:`summon_full_params` context manager.
+ """
+ should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
+ for module_name, module in super().named_modules(*args, **kwargs):
+ if should_clean_name:
+ # Remove any instances of the FSDP-specific prefix; there can
+ # be multiple in the case of nested FSDP modules
+ module_name = module_name.replace(ZERO3_PREFIX, "")
+ yield (module_name, module)
+
+ def named_parameters(
+ self,
+ *args,
+ **kwargs,
+ ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+ """
+ Overrides :meth:`named_parameters()` to intercept parameter names and
+ remove all occurrences of the FSDP-specific flattened parameter prefix
+ when inside the :meth:`summon_full_params` context manager.
+ """
+ should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
+ for param_name, param in super().named_parameters(*args, **kwargs):
+ if should_clean_name:
+ # Remove any instances of the FSDP-specific prefix; there can
+ # be multiple in the case of nested FSDP modules
+ param_name = param_name.replace(ZERO3_PREFIX, "")
+ yield (param_name, param)
+
+ def _assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
+ """Assert we are in the given state."""
+ if isinstance(state, TrainingState):
+ state = [state]
+ if self.training_state not in state:
+ msg = (
+ f"expected to be in states {state} but current state "
+ f"is {self.training_state}"
+ )
+ # In case we are failing in the context of autograd hook, asserting
+ # may not generate useful msg. So, let's print it to be sure.
+ if self.zero3_rank == 0:
+ print(f"Asserting FSDP instance is: {self}")
+ print(f"ERROR: {msg}")
+ traceback.print_stack()
+ raise ValueError(msg)
+
+ @contextmanager
+ def no_sync(self) -> Generator:
+ _lazy_init(self, self)
+ if not self._is_root:
+ raise RuntimeError(
+ "`no_sync()` on inner LayerZeRO instances is not supported. Please call `no_sync()` on root LayerZeRO module."
+ )
+ self._assert_state(TrainingState.IDLE)
+ old_flags = []
+ for m in self.modules():
+ if isinstance(m, LayerZeRO3):
+ old_flags.append((m, m._sync_gradients))
+ m._sync_gradients = False
+ try:
+ yield
+ finally:
+ for m, old_flag in old_flags:
+ if m._sync_gradients:
+ raise ValueError(
+ "`_sync_gradients` was incorrectly set to `True` while in the `no_sync()` context manager"
+ )
+ m._sync_gradients = old_flag
+
+ def zero1_parameters(self, recurse: bool = True):
+ # for name, param in chain(handle. for handle in self._all_handles):
+ for param in self.parameters(recurse):
+ if param.requires_grad:
+ yield param
+
+ def zero_grad_buffer(self):
+ '''This method is to used for accomendate with Megatron'''
+ pass
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/wrap.py b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/wrap.py
new file mode 100644
index 0000000000000000000000000000000000000000..07ceb6a30094df5b9831f579865b573258ccaae7
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/layerzero/zero3/wrap.py
@@ -0,0 +1,62 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ Set,
+ Type,
+)
+
+import torch.nn as nn
+from torch.distributed.fsdp.wrap import (
+ _post_order_apply,
+ _construct_wrap_fn,
+ always_wrap_policy,
+ _Policy,
+ _module_wrap_policy,
+ ModuleWrapPolicy,
+ CustomPolicy,
+ _run_mixed_precision_override_policy,
+ _or_policy,
+ _recursive_wrap,
+ _wrap_module_cls_individually,
+ lambda_auto_wrap_policy,
+ transformer_auto_wrap_policy,
+ size_based_auto_wrap_policy,
+ enable_wrap,
+ wrap
+)
+__all__ = [
+ "always_wrap_policy",
+ "lambda_auto_wrap_policy",
+ "transformer_auto_wrap_policy",
+ "size_based_auto_wrap_policy",
+ "enable_wrap",
+ "wrap",
+ "CustomPolicy",
+ "ModuleWrapPolicy",
+]
+
+
+def _run_tensor_parallel_pg_override_policy(
+ root_module: nn.Module,
+ module_classes: Iterable[Type[nn.Module]],
+ ignored_modules: Set[nn.Module],
+ root_kwargs: Dict[str, Any],
+ target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
+):
+ module_classes_tuple = tuple(set(module_classes))
+ for module in root_module.modules():
+ if module in ignored_modules:
+ continue
+ elif isinstance(module, module_classes_tuple):
+ # This policy overrides any existing policy
+ if module not in target_module_to_kwargs:
+ # Only inherit from the root kwargs if not already specified
+ target_module_to_kwargs[module] = root_kwargs
+ target_module_to_kwargs[module]["process_group"] = root_kwargs["tp_zero_process_group"]
+ return target_module_to_kwargs
diff --git a/model/train/yoco_moe/mindspeed/core/distributed/param_and_grad_buffer.py b/model/train/yoco_moe/mindspeed/core/distributed/param_and_grad_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f804f842209b3d4d6768df44ff1fbe7c8a08e5b8
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/distributed/param_and_grad_buffer.py
@@ -0,0 +1,277 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import os
+import math
+import logging
+from enum import Enum
+from functools import wraps
+from typing import Dict, List, Optional
+from logging import getLogger
+from contextlib import nullcontext
+import torch
+from megatron.training import get_args
+from megatron.core.distributed.param_and_grad_buffer import BufferType
+from megatron.core import parallel_state
+from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
+from megatron.core.utils import log_on_each_pipeline_stage
+
+
+logger = getLogger(__name__)
+
+
+def pipe_register_grad_ready(self, param: torch.nn.Parameter):
+ assert (self.ddp_config.overlap_grad_reduce), 'register_grad_ready() should only be called when overlap_grad_reduce is True'
+ from mindspeed.moe.pipe_experts import FLAG_GRAD_REDUCE
+ if self.is_last_microbatch and FLAG_GRAD_REDUCE:
+ bucket = self.param_to_bucket[param]
+ bucket.register_grad_ready(param)
+
+
+def reuse_fp32_param_param_and_grad_buffer_init_wrapper(init_func):
+ @wraps(init_func)
+ def reuse_fp32_param_param_and_grad_buffer_init(*args, **kwargs):
+ global_args = get_args()
+ math_ceil = math.ceil
+ if global_args.reuse_fp32_param and global_args.use_distributed_optimizer:
+ def ceil_even(x):
+ return math_ceil(math_ceil(x) / 2) * 2
+ math.ceil = ceil_even
+ init_func(*args, **kwargs)
+ if global_args.reuse_fp32_param and global_args.use_distributed_optimizer:
+ math.ceil = math_ceil
+ return reuse_fp32_param_param_and_grad_buffer_init
+
+
+def param_and_grad_buffer_init_pad(
+ self,
+ ddp_config: DistributedDataParallelConfig,
+ param_dtype: torch.dtype,
+ grad_dtype: torch.dtype,
+ params: List[torch.nn.Parameter],
+ data_parallel_group: torch.distributed.ProcessGroup,
+ bucket_size: int,
+ param_to_name: Dict[torch.nn.Parameter, str],
+ gradient_scaling_factor: float,
+):
+ self.ddp_config = ddp_config
+
+ # Check that params are unique.
+ unique_params = set()
+ for param in params:
+ assert param not in unique_params
+ unique_params.add(param)
+ del unique_params
+
+ # Store attributes that will be needed later.
+ self.param_dtype = param_dtype
+ self.grad_dtype = grad_dtype
+ self.data_parallel_group = data_parallel_group
+ self.data_parallel_world_size = torch.distributed.get_world_size(
+ group=self.data_parallel_group
+ )
+ self.gradient_scaling_factor = gradient_scaling_factor
+ self.is_last_microbatch = True
+
+ # Data structures to store underlying buckets and relevant indexing data.
+ self.buckets = []
+ self.param_to_bucket = {} # Param -> bucket mapping.
+ self.param_index_map = {} # Param -> location in buffer mapping (used in dist. optimizer).
+
+ def _pad(number_to_be_padded: int, divisor: int) -> int:
+ return int(math.ceil(number_to_be_padded / divisor) * divisor)
+
+ def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int:
+ """
+ Pads end index of bucket if using distributed optimizer (to ensure uniform sharding).
+ """
+ if self.ddp_config.use_distributed_optimizer:
+ # We now ensure that all buckets start at a memory address that is 512-byte
+ # If using a distributed optimizer, pad the memory buffer to be
+ # multiple of data_parallel_world_size. (This padding is done
+ # due to a constraint with the reduce_scatter op, which requires
+ # all tensors have equal size.)
+ # 512-byte for Ascend, 256-byte for nv.
+
+ element_size = 4 if param_dtype == torch.float else 2
+ global_args = get_args()
+ align_size = global_args.param_and_grad_buffer_pad // element_size
+ return _pad(bucket_end_index, self.data_parallel_world_size * align_size)
+ return bucket_end_index
+
+ def _pad_start_of_param_if_needed(param_start_index: int) -> int:
+ """
+ Pads start index of param if using distributed optimizer (to ensure "good" alignment).
+ """
+ if self.ddp_config.use_distributed_optimizer:
+ # Ensure that params start at 128-byte aligned addresses (64 values
+ # since params are >= 16-bit precision).
+ return _pad(param_start_index, 64)
+ return param_start_index
+
+ # First, figure out how many elements should be in the underlying buffer storage.
+ # Note that if we need to split the buffer into smaller buckets, each of these
+ # might need to be padded as well (if using the distributed optimizer).
+ data_start_index = 0
+ bucket_data_start_index = data_start_index
+ bucket_params = set()
+ self.bucket_indices = []
+ per_bucket_numel_unpadded = []
+ bucket_id = 0
+
+ def _create_new_bucket(data_end_index: int) -> int:
+ """
+ Create the bucket_id'th bucket with collected bucket_params, starting at
+ bucket_data_start_index.
+ """
+ nonlocal bucket_data_start_index, bucket_params, bucket_id
+ per_bucket_numel_unpadded.append(data_end_index - bucket_data_start_index)
+ data_end_index = _pad_end_of_bucket_if_needed(data_end_index)
+ # Update bucket metadata.
+ self.bucket_indices.append((bucket_data_start_index, data_end_index))
+ bucket_data_start_index = data_end_index
+ # Re-set bucket_params and increment bucket_id for next bucket.
+ bucket_params = set()
+ bucket_id += 1
+ # Return the potentially padded data_end_index.
+ return data_end_index
+
+ for param in params[::-1]:
+ # Iterate through parameters in reverse order to roughly follow backprop order,
+ # and skip parameters that don't require gradients.
+ if not param.requires_grad:
+ continue
+ this_numel = param.data.nelement()
+ data_start_index = _pad_start_of_param_if_needed(data_start_index)
+ data_end_index = data_start_index + this_numel
+
+ def _does_param_require_new_bucket(param):
+ """
+ Split shared embedding parameters into separate bucket if using distributed
+ optimizer that makes use of reduce-scatters instead of all-reduces.
+ This ensures that the first and last pipeline stage partition optimizer state
+ for the shared embedding parameters the same way across DP replicas, allowing
+ the DP reduce-scatter to be before the embedding all-reduce.
+ """
+ return (
+ getattr(param, "shared_embedding", False)
+ and self.ddp_config.use_distributed_optimizer
+ )
+
+ # Create bucket with already collected parameters if current param needs its own bucket.
+ if _does_param_require_new_bucket(param) and len(bucket_params) > 0:
+ # We are creating a bucket for the already accumulated parameters, whose params
+ # end at the current data_start_index.
+ if self.ddp_config.use_distributed_optimizer:
+ # data_start_index should already be padded.
+ assert data_start_index % self.data_parallel_world_size == 0
+ _create_new_bucket(data_start_index)
+
+ self.param_index_map[param] = (
+ data_start_index,
+ data_end_index,
+ bucket_id,
+ )
+ bucket_params.add(param)
+
+ # If we have enough elements already or the current param is part of the shared embedding
+ # layer and needs a separate bucket, form a new bucket.
+ if (
+ bucket_size is not None
+ and (data_end_index - bucket_data_start_index) >= bucket_size
+ ) or _does_param_require_new_bucket(param):
+ data_end_index = _create_new_bucket(data_end_index)
+ data_start_index = data_end_index
+
+ # Add remaining params to a new bucket.
+ if len(bucket_params) > 0:
+ data_end_index = _create_new_bucket(data_end_index)
+
+ # Next, create underlying storage for buffer (with numel elements that includes
+ # padding as necessary).
+ self.numel = data_end_index
+ self.numel_unpadded = sum(per_bucket_numel_unpadded)
+ assert self.numel_unpadded <= self.numel
+ if self.ddp_config.use_distributed_optimizer:
+ assert self.numel % self.data_parallel_world_size == 0
+ else:
+ assert self.numel == self.numel_unpadded
+
+ self.param_data = None
+ # Only re-map param tensors if using distributed optimizer.
+ if self.ddp_config.use_distributed_optimizer:
+ self.param_data = torch.zeros(
+ self.numel,
+ dtype=self.param_dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ self.grad_data = torch.zeros(
+ self.numel,
+ dtype=self.grad_dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+
+ # Finally, map param.data and param.main_grad fields to buffers.
+ bucket_params = set()
+ bucket_data_start_index = 0
+ cur_bucket_id = 0
+ for param in params[::-1]:
+ if not param.requires_grad:
+ continue
+ data_start_index, data_end_index, bucket_id = self.param_index_map[param]
+
+ # Assign param.data to appropriate segment of self.param_data.
+ if self.param_data is not None:
+ old_param_data = param.data
+ param.data = self._get(
+ param.data.shape, data_start_index, buffer_type=BufferType.PARAM
+ )
+ assert old_param_data._base is None
+ # Copy tensor values (from initialization or checkpoint).
+ param.data.detach().copy_(old_param_data)
+ del old_param_data
+
+ param.main_grad = self._get(
+ param.data.shape, data_start_index, buffer_type=BufferType.GRAD
+ )
+ if bucket_id != cur_bucket_id:
+ bucket_data_end_index = _pad_end_of_bucket_if_needed(data_start_index)
+ self._set_bucket(
+ bucket_params=bucket_params,
+ start_index=bucket_data_start_index,
+ end_index=bucket_data_end_index,
+ numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
+ bucket_id=cur_bucket_id,
+ )
+ bucket_data_start_index = bucket_data_end_index
+ bucket_params = set()
+ assert cur_bucket_id + 1 == len(self.buckets)
+ assert bucket_id == cur_bucket_id + 1
+ cur_bucket_id = bucket_id
+ bucket_params.add(param)
+
+ # Add remaining params to a new bucket.
+ if len(bucket_params) > 0:
+ bucket_data_end_index = _pad_end_of_bucket_if_needed(data_end_index)
+ self._set_bucket(
+ bucket_params=bucket_params,
+ start_index=bucket_data_start_index,
+ end_index=bucket_data_end_index,
+ numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
+ bucket_id=cur_bucket_id,
+ )
+
+ # Log buckets for all PP stages.
+ log_strs = []
+ log_strs.append(
+ f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}'
+ )
+ for index, bucket in enumerate(self.buckets):
+ numel = 0
+ for param in bucket.params:
+ numel += param.data.nelement()
+ log_strs.append(f'Params for bucket {index + 1} ({numel} elements):')
+ for param in bucket.params:
+ log_strs.append(f'\t{param_to_name[param]}')
+ log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/fusions/__init__.py b/model/train/yoco_moe/mindspeed/core/fusions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/fusions/fused_bias_swiglu.py b/model/train/yoco_moe/mindspeed/core/fusions/fused_bias_swiglu.py
new file mode 100644
index 0000000000000000000000000000000000000000..496312c566575998a331873f0a0e3e8b61d6a21f
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/fusions/fused_bias_swiglu.py
@@ -0,0 +1,17 @@
+import torch_npu
+
+
+def fused_swiglu(x):
+ return torch_npu.npu_swiglu(x, dim=-1)
+
+
+class SwiGLUFunction:
+ @staticmethod
+ def apply(x, *args):
+ return fused_swiglu(x)
+
+
+class BiasSwiGLUFunction:
+ @staticmethod
+ def apply(x, bias, *args):
+ return fused_swiglu(x + bias)
diff --git a/model/train/yoco_moe/mindspeed/core/fusions/fused_layer_norm.py b/model/train/yoco_moe/mindspeed/core/fusions/fused_layer_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..89ea8614dd3780fdaf45e0d9543ead2a09b93fa5
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/fusions/fused_layer_norm.py
@@ -0,0 +1,25 @@
+import os
+from functools import wraps
+
+import torch
+
+
+class FusedLayerNormAffineFunction:
+ @staticmethod
+ def apply(input_, weight, bias, normalized_shape, eps):
+ return torch.nn.functional.layer_norm(input_, normalized_shape, weight, bias, eps)
+
+ @staticmethod
+ def forward(*args, **kwargs):
+ return FusedLayerNormAffineFunction.apply(*args, **kwargs)
+
+
+class FastLayerNormFN:
+ @staticmethod
+ def apply(input_, weight, bias, eps):
+ normalized_shape = torch.Size(weight.numel())
+ return torch.nn.functional.layer_norm(input_, normalized_shape, weight, bias, eps)
+
+
+def fused_layer_norm_affine(input_, weight, bias, normalized_shape, eps):
+ return torch.nn.functional.layer_norm(input_, normalized_shape, weight, bias, eps)
diff --git a/model/train/yoco_moe/mindspeed/core/fusions/fused_softmax.py b/model/train/yoco_moe/mindspeed/core/fusions/fused_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ec0a1256689725e4c77e1ba76d253a7c1a249f0
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/fusions/fused_softmax.py
@@ -0,0 +1,52 @@
+import torch
+import torch_npu
+from megatron.core.transformer.enums import AttnMaskType
+
+
+class ScaledUpperTriangMaskedSoftmax:
+ @staticmethod
+ def apply(input_, scale):
+ # npu_scaled_masked_softmax must be 4D
+ size = input_.size()
+ input_ = input_.view((1,) + tuple(size)).contiguous()
+ dummy_mask = torch.zeros(input_.size()).bool().npu()
+ output = torch_npu.npu_scaled_masked_softmax(input_, dummy_mask, scale, True)
+ return output.view(size).contiguous()
+
+
+class ScaledMaskedSoftmax:
+ @staticmethod
+ def apply(input_, mask, scale):
+ return torch_npu.npu_scaled_masked_softmax(input_, mask, scale, False)
+
+
+class ScaledSoftmax:
+ @staticmethod
+ def apply(input_, scale):
+ dummy_mask = torch.zeros(input_.size()).bool().npu()
+ return torch_npu.npu_scaled_masked_softmax(input_, dummy_mask, scale, False)
+
+
+def is_kernel_available(self, mask, b, np, sq, sk):
+ return (
+ self.scaled_masked_softmax_fusion # user want to fuse
+ and self.input_in_float16 # input must be fp16
+ and 32 < sk <= 4096 # sk must be 32 ~ 4096
+ and sq % 16 == 0 # sq must be divisor of 16
+ and sk % 16 == 0 # sk must be divisor of 16
+ )
+
+
+def forward_fused_softmax(self, input_, mask):
+ b, np, sq, sk = input_.size()
+ scale = self.scale if self.scale is not None else 1.0
+
+ if self.attn_mask_type == AttnMaskType.causal:
+ assert sq == sk, "causal mask is only for self attention"
+ # if use upper triangular mask
+ return torch_npu.npu_scaled_masked_softmax(input_, mask, scale, True)
+ else:
+ if mask is not None:
+ return torch_npu.npu_scaled_masked_softmax(input_, mask, scale, False)
+ else:
+ return ScaledSoftmax.apply(input_, scale)
diff --git a/model/train/yoco_moe/mindspeed/core/fusions/npu_moe_token_permute.py b/model/train/yoco_moe/mindspeed/core/fusions/npu_moe_token_permute.py
new file mode 100644
index 0000000000000000000000000000000000000000..01ac70be45ce220dd2bcc4dd9e770bb1dd9525f5
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/fusions/npu_moe_token_permute.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+from functools import wraps
+
+import torch
+
+from mindspeed.ops.npu_moe_token_permute import npu_moe_token_permute
+
+
+def permute_wrapper(fn):
+ @wraps(fn)
+ def wrapper(
+ tokens: torch.Tensor,
+ indices: torch.Tensor,
+ num_out_tokens: int = None,
+ padded_mode: bool = False
+ ) -> torch.Tensor:
+ return npu_moe_token_permute(tokens, indices, num_out_tokens, padded_mode)
+
+ return wrapper
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/fusions/npu_moe_token_unpermute.py b/model/train/yoco_moe/mindspeed/core/fusions/npu_moe_token_unpermute.py
new file mode 100644
index 0000000000000000000000000000000000000000..6141f817613e5a96e60f23b2c877b3631a969829
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/fusions/npu_moe_token_unpermute.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+from functools import wraps
+
+import torch
+
+from mindspeed.ops.npu_moe_token_unpermute import npu_moe_token_unpermute
+
+
+def unpermute_wrapper(fn):
+ @wraps(fn)
+ def wrapper(
+ permuted_tokens: torch.Tensor,
+ sorted_indices: torch.Tensor,
+ probs: torch.Tensor = None,
+ padded_mode: bool = False,
+ restore_shape: torch.Size = None,
+) -> torch.Tensor:
+ dtype = permuted_tokens.dtype
+ if probs is not None and probs.dtype != permuted_tokens.dtype:
+ # make sure permuted_tokens has the same dtype with probs.
+ permuted_tokens = permuted_tokens.to(probs.dtype)
+ return npu_moe_token_unpermute(
+ permuted_tokens, sorted_indices, probs, padded_mode=padded_mode, restore_shape=restore_shape).to(dtype)
+
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/fusions/rms_norm.py b/model/train/yoco_moe/mindspeed/core/fusions/rms_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd5dcb401ea29491fe02bc97c484905dd622d021
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/fusions/rms_norm.py
@@ -0,0 +1,51 @@
+import os
+from functools import wraps
+import torch
+import torch_npu
+from torch import nn
+from megatron.legacy.model.rms_norm import RMSNorm
+from megatron.training import get_args
+
+from mindspeed.core.tensor_parallel.mapping import reduce_from_tensor_model_parallel_region_nd
+
+
+def rms_norm_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ fn(self, *args, **kwargs)
+ _args = get_args()
+ self.use_fused_rmsnorm = _args.use_fused_rmsnorm
+ self.use_nd_matmul = _args.use_nd_matmul
+ dim = args[0] if len(args) > 0 else kwargs.get('dim')
+ if self.use_nd_matmul:
+ if self.use_fused_rmsnorm:
+ raise RuntimeError('nd_matmul does not support fused_rmsnorm temporarily')
+ self.tensor_model_parallel_size = _args.tensor_model_parallel_size
+ self.weight = torch.nn.Parameter(
+ torch.ones(dim // self.tensor_model_parallel_size)
+ )
+ return wrapper
+
+
+def rms_norm_forward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, x):
+ if int(os.getenv('NPU_ASD_ENABLE', '0')):
+ from torch_npu.utils import register_asd_hook
+ register_asd_hook(x, self.weight)
+ if self.use_fused_rmsnorm:
+ return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0]
+ return fn(self, x)
+ return wrapper
+
+
+def rms_norm_norm_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, x):
+ if self.use_nd_matmul:
+ pow_mean = x.pow(2).mean(-1, keepdim=True)
+ all_pow_mean = reduce_from_tensor_model_parallel_region_nd(pow_mean)
+ pow_mean = torch.div(all_pow_mean, self.tensor_model_parallel_size)
+ return x * torch.rsqrt(pow_mean + self.eps)
+ return fn(self, x)
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/fusions/rotary_pos_embedding.py b/model/train/yoco_moe/mindspeed/core/fusions/rotary_pos_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/memory/__init__.py b/model/train/yoco_moe/mindspeed/core/memory/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/__init__.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_apply.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_apply.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dd6ec95aaad92c7e8b8cefec2ea5b0e959e8d25
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_apply.py
@@ -0,0 +1,159 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+from typing import List
+from copy import deepcopy
+from megatron.training import print_rank_0
+from .adaptive_memory_tool import SingletonBase, LayerAction, ModuleAction, ContextKey as Key
+from .adaptive_memory_solver import AdaptMemGraphSolver
+from .adaptive_memory_swap_manager import SwapManager, transformer_layer_register_post_forward_hook, transformer_layer_register_pre_backward_hook
+from .adaptive_memory_profiling import RecomputeHook, AdaptiveMemoryProfiling
+
+
+class AdaptMemApplyManager(metaclass=SingletonBase):
+
+ def __init__(self):
+ self.no_adapt_modules = [] # modules which don't join policy selections
+ self.cur_module_index = 0 # DFS index
+
+ # optype 0 save_to_cache, 1 apply_to_context
+ def apply_op_to_context(self, adapt_policy_list: list, origin_context: dict):
+ if len(adapt_policy_list) == 0:
+ print_rank_0("adapt_policy_list Empty!")
+ return origin_context
+ context = deepcopy(origin_context)
+ # 1.get all layers by order
+ ordered_layers = []
+ self.get_ordered_layers(context, ordered_layers, True)
+ # 2.handle policy list
+ idx = 0
+ self.get_ordered_modules(ordered_layers[0][Key.SUBMODULES], [], 0)
+ for policy in adapt_policy_list:
+ n = policy[0]
+ adapt_nodes = []
+ if policy[1] == LayerAction.FULL_RECOMPUTE:
+ status = ModuleAction.RECOMPUTE
+ adapt_nodes = [status for _ in range(len(policy[2:]))]
+ for i in range(idx, idx + n):
+ ordered_layers[i][ModuleAction.RECOMPUTE.name] = True
+ elif policy[1] == LayerAction.FULL_SWAP:
+ status = ModuleAction.SWAP
+ adapt_nodes = [status for _ in range(len(policy[2:]))]
+ elif policy[1] == LayerAction.ADAPTIVE:
+ adapt_nodes = policy[2:]
+ for i in range(idx, idx + n):
+ self.apply_op_to_layer(ordered_layers[i], adapt_nodes, i)
+ idx += n
+
+ return context
+
+ def apply_op_to_layer(self, ordered_layer, adapt_nodes: list, layer_index: int):
+ if len(adapt_nodes) == 0:
+ # don't need any operations if adapt_nodes is empty
+ return
+ # get all modules of the current layer through DFS
+ ordered_module: List[dict] = []
+ if Key.SUBMODULES not in ordered_layer:
+ return
+ self.cur_module_index = 0
+ if layer_index == 0:
+ self.no_adapt_modules.clear()
+ self.get_ordered_modules(ordered_layer[Key.SUBMODULES], ordered_module, layer_index)
+
+ for i, nodes in enumerate(adapt_nodes):
+ if i >= len(ordered_module):
+ break
+ if Key.IS_FUNCTION in ordered_module[i]:
+ func_action = nodes
+ # add location infos for autofrad.function
+ AdaptMemGraphSolver().add_func_locations(layer_index, ordered_module[i][Key.NAME], func_action)
+ continue
+ if nodes == ModuleAction.RECOMPUTE:
+ ordered_module[i][ModuleAction.RECOMPUTE.name] = True
+ elif nodes == ModuleAction.SWAP:
+ ordered_module[i][ModuleAction.SWAP.name] = True
+
+ def get_ordered_layers(self, model: dict, ordered_layers: list, is_root_layer: bool = False):
+ # root module may have multiple layers due to vpp parallel
+ if is_root_layer:
+ if Key.SUBMODULES not in model:
+ return
+ for sub_model in model[Key.SUBMODULES]:
+ self.get_ordered_layers(sub_model, ordered_layers)
+ return
+
+ if Key.IS_ADAPT_LAYER in model:
+ for sub_layer in model[Key.SUBMODULES]:
+ ordered_layers.append(sub_layer)
+ if Key.SUBMODULES not in model:
+ return
+ for sub_model in model[Key.SUBMODULES]:
+ self.get_ordered_layers(sub_model, ordered_layers)
+
+ def get_ordered_modules(self, layer: dict, ordered_modules: list, layer_index: int):
+ for sub_layer in layer:
+ # The first layer judges through ['memory']
+ if layer_index == 0:
+ if Key.MEMORY in sub_layer:
+ ordered_modules.append(sub_layer)
+ else:
+ # use the DFS index as the unique identifier
+ self.no_adapt_modules.append(self.cur_module_index)
+ else:
+ if self.cur_module_index not in self.no_adapt_modules:
+ ordered_modules.append(sub_layer)
+
+ self.cur_module_index += 1
+ if Key.SUBMODULES in sub_layer:
+ self.get_ordered_modules(sub_layer[Key.SUBMODULES], ordered_modules, layer_index)
+
+ def apply_hook_to_model(self, models, context, pre_context, is_root_layer: bool = False):
+ if is_root_layer and isinstance(models, list):
+ layer_idx = 0
+ for model in models:
+ self.apply_hook_to_model(model, get_cur_layer_context(context, layer_idx), context)
+ layer_idx += 1
+ return
+ # pass autograd.function
+ if Key.IS_FUNCTION in context:
+ if Key.SUBMODULES in context:
+ for i in range(0, len(context[Key.SUBMODULES])):
+ self.apply_hook_to_model(models, context[Key.SUBMODULES][i], context)
+ return
+ # apply hooks for recompute models
+ if context.get(ModuleAction.RECOMPUTE.name, False):
+ models.no_checkpoint_adaptive_recompute_forward = models.forward
+ models.forward = RecomputeHook().hook_checkpoint_forward(models.forward)
+ RecomputeHook().recompute_modules.append(models)
+ print_rank_0('recompute hooked on %s' % models._get_name())
+ return
+ # apply hooks for swap modules
+ if context.get(ModuleAction.SWAP.name, False):
+ SwapManager().hook_prefetch_forward(models, '')
+ print_rank_0('swap hooked on %s' % models._get_name())
+ return
+ # apply hooks for oom swap
+ if Key.ALLOWED_ADAPT in context:
+ transformer_layer_register_post_forward_hook(models)
+ transformer_layer_register_pre_backward_hook(models)
+ SwapManager().hook_oom_rescue_forward(models)
+ print_rank_0('oom rescue hooked on %s' % models._get_name())
+
+ module_idx = 0
+ for name, module in models.named_children():
+ self.apply_hook_to_model(module, context[Key.SUBMODULES][module_idx], context)
+ module_idx += 1
+
+ def apply_new_adapt_policy(self, adapt_policy_list, context, models):
+ AdaptMemGraphSolver().func_locations.clear()
+ new_context = self.apply_op_to_context(adapt_policy_list, context)
+ self.apply_hook_to_model(models, new_context, "", True)
+
+
+# get layer by idx in root module
+def get_cur_layer_context(context, idx):
+ current_context = {}
+ for k, v in context.items():
+ if k == Key.SUBMODULES:
+ current_context[k] = [v[idx]]
+ continue
+ current_context[k] = v
+ return current_context
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_cache.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..88cdb022aa516ece9aa488cb4bc0066b1b8f39af
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_cache.py
@@ -0,0 +1,277 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+import os
+import stat
+import sys
+import json
+import hashlib
+from typing import List
+from pathlib import Path
+
+import torch
+import torch_npu
+from megatron.training import get_args, print_rank_0
+from megatron.core import parallel_state
+
+import mindspeed
+from .adaptive_memory_tool import SingletonBase, ModuleAction, LayerAction
+
+
+class AdaptiveLayerMemPolicy:
+ def __init__(self, recompute=None, swap=None, memory=0.0, time=sys.maxsize, adapt_type=LayerAction.ADAPTIVE):
+ self.recompute: List[str] = recompute or []
+ self.swap: List[str] = swap or []
+ self.memory: float = memory
+ self.time = time
+ self.adapt_type = adapt_type
+
+ def get_modules_by_tag(self, tag):
+ if ModuleAction.RECOMPUTE == tag:
+ return self.recompute
+ elif ModuleAction.SWAP == tag:
+ return self.swap
+ else:
+ msg = f"unknown layer policy tag name:{tag}"
+ raise ValueError(msg)
+
+ @staticmethod
+ def parse_from_json(src_json):
+ alp = AdaptiveLayerMemPolicy(memory=src_json["memory"], time=src_json["time"], recompute=[], swap=[])
+ alp.recompute = [str(r) for r in src_json["recompute"]]
+ alp.swap = [str(r) for r in src_json["swap"]]
+ return alp
+
+ def identity(self) -> str:
+ self.sort_modules()
+ modules = ",".join(self.recompute) + ":" + ",".join(self.swap)
+ return hashlib.md5(modules.encode('utf-8')).hexdigest()
+
+ def sort_modules(self):
+ self.recompute.sort()
+ self.swap.sort()
+
+ def __eq__(self, other):
+ if not isinstance(other, AdaptiveLayerMemPolicy):
+ return False
+ if len(self.recompute) != len(other.recompute) or len(self.swap) != len(other.swap):
+ return False
+
+ # sort values before compare
+ self.sort_modules()
+ other.sort_modules()
+
+ return self.recompute == other.recompute and self.swap == other.swap
+
+ def __repr__(self):
+ result = {'recompute': self.recompute, 'swap': self.swap, 'memory': self.memory, 'time': self.time, 'adapt_type': self.adapt_type}
+ return str(result)
+
+
+class AdaptiveModelMemPolicy:
+ def __init__(self, policy_type, polices, memory=0.0, time=sys.maxsize):
+ self.policy_type: str = policy_type
+ self.polices: List[AdaptiveLayerMemPolicy] = polices
+ self.memory: float = memory
+ self.time = time
+
+ def __post_init__(self):
+ if self.policy_type not in ["normal", "oom"]:
+ raise ValueError(f"unknown policy type:{self.policy_type}, {self.__repr__()}")
+
+ def __repr__(self):
+ return str(self.polices)
+
+ def to_json(self):
+ return json.dumps(self, default=lambda x: x.__dict__, sort_keys=True)
+
+ @staticmethod
+ def parse_from_json(src_json):
+ amp = AdaptiveModelMemPolicy(policy_type=src_json["policy_type"], polices=[])
+ amp.polices = [AdaptiveLayerMemPolicy.parse_from_json(p) for p in src_json["polices"]]
+ return amp
+
+ def __eq__(self, other):
+ if not isinstance(other, AdaptiveModelMemPolicy):
+ return False
+ if self.policy_type != other.policy_type or len(self.polices) != len(other.polices):
+ return False
+
+ cur_hash = sorted([x.identity() for x in self.polices])
+ other_hash = sorted([x.identity() for x in other.polices])
+ return cur_hash == other_hash
+
+
+class PolicyCacheManager(metaclass=SingletonBase):
+
+ def __init__(self):
+ self.local_file_name_list = []
+ self.normal_policy_cache: List[AdaptiveModelMemPolicy] = []
+ self.oom_policy_cache: List[AdaptiveModelMemPolicy] = []
+
+ def load_cache_file(self):
+ self.local_file_name_list = self._buildup_filename()
+ self.load_stage_cache_file()
+
+ def load_stage_cache_file(self):
+ cur_pp_rank = parallel_state.get_pipeline_model_parallel_rank()
+ if not os.path.isfile(self.local_file_name_list[cur_pp_rank]):
+ print_rank_0(f"load history oom policy False!!!!!!!!: {self.local_file_name_list[cur_pp_rank]}")
+ return
+
+ with open(self.local_file_name_list[cur_pp_rank], "r") as f:
+ for line in f:
+ json_format = json.loads(line)
+ policy: AdaptiveModelMemPolicy = AdaptiveModelMemPolicy.parse_from_json(json_format)
+ self.oom_policy_cache.append(policy)
+ print_rank_0(f"load history oom policy Success!!!!!!!!: {self.local_file_name_list[cur_pp_rank]}")
+
+ @staticmethod
+ def _get_version_file(src_path, key, version_file_name):
+ version_path = src_path[:src_path.index(key) + len(key)]
+ return os.path.join(version_path, version_file_name)
+
+ def _get_software_version(self):
+ torch_version: str = torch.__version__
+ torch_npu_version: str = torch_npu.__version__
+
+ library_path = os.environ.get("LD_LIBRARY_PATH").split(":")
+ ascend_toolkit_path = next((x for x in library_path if "ascend-toolkit" in x), None)
+ driver_path = next((x for x in library_path if "driver" in x), None)
+ if ascend_toolkit_path is None or driver_path is None:
+ return {}
+
+ ascend_toolkit_version_file = self._get_version_file(ascend_toolkit_path, "ascend-toolkit", "version.cfg")
+ driver_version_file = self._get_version_file(driver_path, "driver", "version.info")
+ if not os.path.isfile(ascend_toolkit_version_file) or not os.path.isfile(driver_version_file):
+ return {}
+
+ with open(ascend_toolkit_version_file, "r") as f:
+ f.readline()
+ ascend_version = f.readline()
+
+ with open(driver_version_file, "r") as f:
+ driver_version = f.readline()
+
+ return {
+ "torch": torch_version,
+ "torch_npu": torch_npu_version,
+ "ascend_toolkit": ascend_version,
+ "driver": driver_version
+ }
+
+ def _scan_dir_recursively(self, dir_name, md5s):
+ with os.scandir(dir_name) as it:
+ for entry in it:
+ if entry.is_dir(follow_symlinks=False):
+ self._scan_dir_recursively(entry.path, md5s)
+ elif entry.is_file(follow_symlinks=False):
+ if not entry.path.endswith(".py"):
+ return
+ md5_instance = hashlib.md5()
+ with open(entry.path, "rb") as f:
+ md5_instance.update(f.read())
+ md5s.append(md5_instance.hexdigest())
+
+ def _get_source_code_hash(self):
+ mindspeed_path, = mindspeed.__path__
+ md5s = []
+ self._scan_dir_recursively(mindspeed_path, md5s)
+ sorted(md5s)
+ md5_instance = hashlib.md5()
+ for x in md5s:
+ md5_instance.update(x.encode('utf-8'))
+ return md5_instance.hexdigest()
+
+ def _buildup_filename(self):
+ args = get_args()
+ gbs = args.global_batch_size
+ mbs = args.micro_batch_size
+ seq_len = args.seq_length
+ hidden = args.hidden_size
+ tp = 1 if not args.tensor_model_parallel_size else args.tensor_model_parallel_size
+ cp = 1 if not args.context_parallel_size else args.context_parallel_size
+ sp = 1 if not args.sequence_parallel else tp
+ ep = 1 if not args.expert_model_parallel_size else args.expert_model_parallel_size
+ pp = 1 if not args.pipeline_model_parallel_size else args.pipeline_model_parallel_size
+ world_size = args.world_size
+ dp = world_size // tp // cp // pp
+
+ arguments = {
+ "global_batch_size": gbs,
+ "micro_batch_size": mbs,
+ "sequence_len": seq_len,
+ "hidden": hidden,
+ "tp": tp, "cp": cp, "sp": sp, "ep": ep, "dp": dp,
+ "world_size": world_size,
+ "source_hash": self._get_source_code_hash()
+ }
+ software_versions = self._get_software_version()
+ arguments.update(software_versions)
+ args_content = json.dumps(arguments, sort_keys=True)
+ args_md5 = hashlib.md5(args_content.encode('utf-8')).hexdigest()
+
+ mindspeed_home = os.path.dirname(os.path.dirname(mindspeed.__file__))
+ adaptive_home = os.path.join(mindspeed_home, "adaptive_mem")
+ Path(adaptive_home).mkdir(parents=True, exist_ok=True)
+ file_abs_name_list = []
+
+ for i in range(pp):
+ file_name = f"b{mbs}_s{seq_len}_h{hidden}_tp{tp}_cp{cp}_w{world_size}_sp{sp}_ep{ep}_dp{dp}_stage{i}_{args_md5}.policy"
+ file_abs_name = os.path.join(adaptive_home, file_name)
+ file_abs_name_list.append(file_abs_name)
+
+ return file_abs_name_list
+
+ def _persistence(self):
+ cur_pp_rank = parallel_state.get_pipeline_model_parallel_rank()
+ cur_device_ranks = torch.cuda.device_count()
+ total_ranks = torch.distributed.get_world_size()
+ pp = 1 if not get_args().pipeline_model_parallel_size else get_args().pipeline_model_parallel_size
+ rank_per_pp = total_ranks // pp
+ # 不同节点的rank0需要存policy 以及 相同节点不同pp stage中的rank0需要存一下policy
+ if torch.distributed.get_rank() % cur_device_ranks == 0 or (
+ torch.distributed.get_rank() % rank_per_pp == 0 and torch.distributed.get_rank() % cur_device_ranks != 0):
+ flags = os.O_WRONLY | os.O_CREAT
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ with os.fdopen(os.open(self.local_file_name_list[cur_pp_rank], flags, mode), 'w') as fout:
+ fout.write("")
+ for p in self.oom_policy_cache:
+ fout.write(p.to_json() + "\n")
+
+
+ def add_normal_policy_cache(self, policy):
+ if policy in self.normal_policy_cache:
+ return
+
+ self.normal_policy_cache.append(policy)
+
+ def add_oom_policy_cache(self, policy):
+ if policy in self.oom_policy_cache:
+ return
+
+ self.oom_policy_cache.append(policy)
+ self._persistence()
+
+ def delete_normal_policy_cache(self, policy):
+ if policy not in self.normal_policy_cache:
+ return
+
+ self.normal_policy_cache.remove(policy)
+
+ def check_in_cache(self, policy: AdaptiveModelMemPolicy):
+ if policy is None:
+ raise ValueError(f"unexpect policy")
+
+ in_normal = next((x for x in self.normal_policy_cache if x == policy), None) is not None
+ return in_normal or next((x for x in self.oom_policy_cache if x == policy), None) is not None
+
+ def check_in_normal_cache(self, policy: AdaptiveModelMemPolicy):
+ if policy is None:
+ raise ValueError(f"unexpect policy")
+
+ return next((x for x in self.normal_policy_cache if x == policy), None) is not None
+
+ def check_in_oom_cache(self, policy: AdaptiveModelMemPolicy):
+ if policy is None:
+ raise ValueError(f"unexpect policy")
+
+ return next((x for x in self.oom_policy_cache if x == policy), None) is not None
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_function.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e68161f9f51aa7fa495f98c0636d980f477aab4
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_function.py
@@ -0,0 +1,136 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+
+from copy import copy
+from typing import List, Any
+
+import torch
+from megatron.core.tensor_parallel.random import checkpoint
+
+from megatron.training import print_rank_0, get_args
+from megatron.core.num_microbatches_calculator import get_num_microbatches
+from megatron.core.tensor_parallel.random import get_cuda_rng_tracker
+from megatron.core import parallel_state as ps
+
+from mindspeed.core.tensor_parallel.random import _set_cuda_rng_state
+from .adaptive_memory_profiling import AdaptiveMemoryProfiling
+from .adaptive_memory_solver import AdaptMemGraphSolver
+from .adaptive_memory_prefetch import AdaptiveMemoryPrefetch, pre_forward_func
+from .adaptive_memory_tool import AdaptiveStepMgr, SingletonBase, ModuleAction, BYTES_PER_MB, ContextKey as Key
+from .adaptive_memory_tool import FuncLocationMgr, ForwardCounter
+from .adaptive_memory_swap_manager import SwapManager
+
+
+class FunctionCtxMgr(metaclass=SingletonBase):
+ def __init__(self):
+ self._ctx_dict = {}
+ self._child_dict = {}
+
+ def update_ctx(self, func_name, new_ctx, child_name):
+ if func_name not in self._ctx_dict:
+ self._ctx_dict[func_name] = new_ctx
+ self._ctx_dict[func_name][Key.FORWARD_CNT] = 1
+ self._ctx_dict[func_name][Key.AVG_TIME] = new_ctx[Key.PRE_TOTAL_TIME]
+ self._ctx_dict[func_name][Key.IS_FUNCTION] = True
+ else:
+ target_ctx = self._ctx_dict[func_name]
+ target_ctx[Key.FORWARD_CNT] += 1
+ target_ctx[Key.PRE_TOTAL_TIME] += new_ctx[Key.PRE_TOTAL_TIME]
+ target_ctx[Key.AVG_TIME] = target_ctx[Key.PRE_TOTAL_TIME] / target_ctx[Key.FORWARD_CNT]
+
+ if func_name not in self._child_dict:
+ self._child_dict[func_name] = child_name
+
+ def ctx_iter(self):
+ for key in self._ctx_dict.keys():
+ yield self._ctx_dict.get(key), self._child_dict.get(key)
+
+
+class FunctionProfilingWrapper:
+ def __init__(self, function):
+ self._function = function
+ self._ctx = {Key.NAME: function.__name__}
+
+ self.start_event = torch.npu.Event(enable_timing=True)
+ self.end_evnet = torch.npu.Event(enable_timing=True)
+
+ def _pre_process(self, *args):
+ self._ctx[Key.PREFIX_NAME] = FuncLocationMgr().get_latest_name()
+ self._ctx[Key.DEEP] = len(self._ctx[Key.PREFIX_NAME].split("."))
+ self._ctx[Key.IS_MODLUE_OF_LAYER0] = True
+ FuncLocationMgr().set_function_in_stack()
+
+ self._ctx[Key.INPUT] = AdaptiveMemoryProfiling().cal_input_output_size(args) / BYTES_PER_MB
+ self._ctx[Key.MEMORY] = torch.npu.memory_allocated() - self._ctx[Key.INPUT]
+ self.start_event.record()
+
+ def _post_process(self, outputs):
+ self.end_evnet.record()
+ torch.npu.synchronize()
+ self._ctx[Key.PRE_TOTAL_TIME] = self.start_event.elapsed_time(self.end_evnet)
+ self._ctx[Key.OUTPUT] = AdaptiveMemoryProfiling().cal_input_output_size(outputs) / BYTES_PER_MB
+ self._ctx[Key.MEMORY] = (torch.npu.memory_allocated() - self._ctx[Key.MEMORY]) / BYTES_PER_MB
+
+ child_name = FuncLocationMgr().get_function_location(self._ctx[Key.PREFIX_NAME])
+ FunctionCtxMgr().update_ctx(self._function.__name__, self._ctx, child_name)
+
+ def run_profiling(self, *args, **kwargs):
+ self._pre_process(args)
+ outputs = self._function.apply(*args, **kwargs)
+ self._post_process(outputs)
+ return outputs
+
+
+def pack_hook(tensor):
+ return SwapManager().prefetch_pack(tensor)
+
+
+def unpack_hook(swap_tensor):
+ return SwapManager().prefetch_unpack(swap_tensor)
+
+
+def pre_profiling_process(module_name):
+ pre_forward_func(module_name, False)
+
+
+def post_profiling_process(module_name):
+ AdaptiveMemoryPrefetch().sync_d2h_for_recording_time(module_name, True)
+
+
+def wrap_swap_profiling(function, module_name, *args):
+ pre_profiling_process(module_name)
+ with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
+ outputs = function.apply(*args)
+ post_profiling_process(module_name)
+ return outputs
+
+
+def wrap_function(function, *args):
+ with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
+ return function.apply(*args)
+
+
+def adapt_mem_func_wrapper(fc_class, *args):
+ if not issubclass(fc_class, torch.autograd.Function):
+ raise TypeError("adapt_mem_func_wrapper only support subclass of torch.autograd.Function")
+ cnt = ForwardCounter().get_count()
+ is_first_layer = FuncLocationMgr().is_first_layer
+ if AdaptiveStepMgr().is_recompute_profiling_step() and is_first_layer:
+ if fc_class.__name__ not in AdaptiveMemoryPrefetch().function_list:
+ AdaptiveMemoryPrefetch().function_list.append(fc_class.__name__)
+ return FunctionProfilingWrapper(fc_class).run_profiling(*args)
+ elif AdaptiveStepMgr().is_swap_profiling_step() and is_first_layer: # recording swap profiling
+ if FunctionCtxMgr()._ctx_dict.get(fc_class.__name__)[Key.DEEP] == AdaptiveMemoryPrefetch().function_swap_profiling_deep:
+ module_full_name = FunctionCtxMgr()._ctx_dict.get(fc_class.__name__)[Key.PREFIX_NAME] + "." + fc_class.__name__
+ return wrap_swap_profiling(fc_class, module_full_name, *args)
+ elif AdaptiveStepMgr().is_swap_profiling_done() and not AdaptiveMemoryPrefetch().is_stable_apply and is_first_layer:
+ if fc_class.__name__ in AdaptiveMemoryPrefetch().prefetch_function_list:
+ return wrap_function(fc_class, *args)
+ elif AdaptiveStepMgr().is_all_profiling_done() and AdaptiveMemoryPrefetch().is_stable_apply: # do one of prefetch/recompute/swap
+ action = AdaptMemGraphSolver().get_func_action(fc_class.__name__, cnt - 1)
+ if action == ModuleAction.RECOMPUTE:
+ def fc_class_apply():
+ return fc_class.apply(*args)
+ return checkpoint(fc_class_apply, False)
+ elif action == ModuleAction.SWAP:
+ return wrap_function(fc_class, *args)
+ return fc_class.apply(*args) # do default function.apply
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_opt.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_opt.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ffeb4b3ee357e3a5d4fb11bd3fa93907ca6e992
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_opt.py
@@ -0,0 +1,212 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+from functools import wraps
+
+import acl
+import torch
+import torch.nn
+
+from megatron.training import print_rank_0
+from megatron.core.num_microbatches_calculator import get_num_microbatches
+from .adaptive_memory_profiling import AdaptiveMemoryProfiling, RecomputeHook
+from .adaptive_memory_solver import AdaptMemGraphSolver
+from .adaptive_memory_policy import AdaptMemPolicyManager
+from .adaptive_memory_prefetch import AdaptiveMemoryPrefetch
+from .adaptive_memory_tool import AdaptiveStepMgr, SingletonBase, ContextKey as Key, ForwardCounter
+from .adaptive_memory_function import FunctionCtxMgr
+from .adaptive_memory_swap_manager import SwapManager, LayerProfilingHook
+from .adaptive_memory_apply import AdaptMemApplyManager
+from .cpu_binder import bind_cpus
+from .adaptive_memory_cache import PolicyCacheManager
+
+
+class AdaptiveMemoryOpt(metaclass=SingletonBase):
+
+ @staticmethod
+ def reset_all_adapt_mem_hooks():
+ if not AdaptiveStepMgr().is_recompute_profiling_step():
+ AdaptiveMemoryProfiling().reset_profiling_all_hooks()
+
+ if AdaptiveMemoryOpt.is_policy_stable():
+ AdaptiveMemoryOpt.reset_final_rescue_hooks()
+
+ @staticmethod
+ def is_policy_stable():
+ # current policy run 10 more steps unchanged is a stable policy
+ return AdaptiveStepMgr().get_cur_step() >= AdaptMemGraphSolver().remove_swap_manager_hook_step
+
+ @staticmethod
+ def reset_final_rescue_hooks():
+ SwapManager().reset_oom_rescue_hooked_modules()
+
+ @staticmethod
+ def reset_adapt_mem_modules():
+ RecomputeHook().reset_recompute_modules() # clear recompute modules
+ AdaptiveMemoryProfiling().reset_profiling_all_hooks() # clear profiling all hook
+ AdaptiveMemoryPrefetch().reset_adaptive_prefetch_all_hooks() # clear adaptive prefetch all hook
+ SwapManager().reset_all_for_oom_rescue() # clear all hook and tensor in oom rescue
+
+ def set_adapt_mem_hook(self, models):
+ torch.npu.synchronize()
+ AdaptiveMemoryProfiling().record_time()
+ context = AdaptiveMemoryProfiling().context
+ # reset auto_function list
+ if not AdaptiveMemoryPrefetch().is_stable_apply:
+ AdaptiveMemoryPrefetch().function_swap_profiling_deep = 0
+ AdaptiveMemoryPrefetch().prefetch_function_list = []
+ AdaptiveMemoryPrefetch().prefetch_module_dict.clear()
+
+ if AdaptiveStepMgr().is_recompute_profiling_step():
+ if AdaptiveStepMgr().is_last_recompute_profiling_step():
+ # insert function profiling to context
+ for ctx, child in FunctionCtxMgr().ctx_iter():
+ AdaptiveMemoryProfiling().insert_func_profiling(ctx, child)
+ # update params when has function
+ if len(FunctionCtxMgr()._ctx_dict):
+ update_swap_profiling_step_and_deep_list()
+
+ # clear recompute profiling hook
+ AdaptiveMemoryProfiling().reset_profiling_hooks()
+ AdaptiveMemoryPrefetch().reset_adaptive_prefetch_all_hooks()
+ # apply layer profiling hook for following steps
+ LayerProfilingHook().apply_layer_profiling_hook(AdaptiveMemoryProfiling().layer0_module)
+ return
+
+ if AdaptiveStepMgr().is_layer_profiling_step():
+ if AdaptiveStepMgr().is_last_layer_profiling_step():
+ SwapManager().forward_time = LayerProfilingHook().get_single_layer_time()
+ LayerProfilingHook().reset_layer_profiling_hook()
+ LayerProfilingHook().forward_time_list.clear()
+ print_rank_0(f'forward time is {SwapManager().forward_time}')
+ config = AdaptiveMemoryPrefetch().config
+ AdaptiveMemoryPrefetch().register_recursive_apply_prefetch(config, models, context)
+ return
+
+ # update swap profiling stats
+ if AdaptiveStepMgr().is_swap_profiling_step():
+ AdaptiveMemoryPrefetch().update_ctx(models, context)
+
+
+ if AdaptiveStepMgr().is_swap_profiling_done() and not AdaptiveMemoryPrefetch().is_stable_apply:
+ AdaptiveMemoryPrefetch().adaptive_select_module(models, context)
+ if not AdaptiveMemoryPrefetch().is_stable_apply:
+ return
+
+ if AdaptMemGraphSolver().need_prepare_solver:
+ # reduce max_device_memory and generate all policy combinations at first solver step
+ AdaptMemGraphSolver().reduce_device_memory(context[Key.DEVICE_MEMORY])
+ AdaptMemPolicyManager().prepare_policy(context)
+ AdaptMemGraphSolver().prepare_solver(context)
+
+ AdaptMemGraphSolver().check_cur_adapt_policy()
+ print_rank_0("==================== ADAPTIVE-MEMORY Report START====================")
+ adapt_policy_list = AdaptMemGraphSolver().solve_adapt_mem_policy()
+ print_rank_0("==================== ADAPTIVE-MEMORY Report End ====================")
+ if adapt_policy_list is not None:
+ self.reset_adapt_mem_modules()
+ AdaptMemApplyManager().apply_new_adapt_policy(adapt_policy_list, context, models)
+ print_rank_0(f"ADAPTIVE MEMORY OPTIMIZATION apply policy done")
+
+ def hook_adapt_mem_step(self, step_func, models):
+ def custom_adapt_mem_step(*args, **kwargs):
+ try:
+ result = step_func(*args, **kwargs) # cur step is done after calling step_func
+ if AdaptMemPolicyManager().is_stable_mem_policy() or AdaptiveStepMgr().is_skipping_step():
+ return result
+
+ AdaptiveMemoryProfiling().update_whole_model_memory()
+ AdaptMemPolicyManager().update_hccl_memory()
+ self.set_adapt_mem_hook(models)
+
+ return result
+ finally:
+ AdaptiveStepMgr().incr_step() # incr step num after step_func and adapting
+
+ return custom_adapt_mem_step
+
+
+def addup_allowed_mem_adapt_module(module):
+ AdaptiveMemoryProfiling().addup_allowed_mem_adapt_profiling_module(module)
+
+
+def layer_beginning_callback_forward(module, *args, **kwargs):
+ ForwardCounter().incr_cnt()
+
+
+def register_custom_hooks(modules):
+ for module in modules:
+ _register_one_module(module)
+
+
+def _register_one_module(module):
+ allowed_list = AdaptiveMemoryProfiling().get_allowed_adapt_module()
+ if any(isinstance(module, a) for a in allowed_list):
+ module.register_forward_pre_hook(layer_beginning_callback_forward)
+
+ for name, child in module.named_children():
+ if isinstance(child, torch.nn.ModuleList):
+ for idx, sub_child in enumerate(child):
+ _register_one_module(sub_child)
+ else:
+ _register_one_module(child)
+
+
+def cal_swap_profiling_step(num_micro_batches):
+ swap_depth = AdaptiveMemoryPrefetch().prefetch_deep_end - AdaptiveMemoryPrefetch().prefetch_deep_start + 1
+ swap_profiling_times = 4
+ swap_profiling_steps = swap_profiling_times // num_micro_batches
+ if swap_profiling_times % num_micro_batches != 0:
+ swap_profiling_steps += 1
+ return swap_profiling_steps * swap_depth * AdaptiveMemoryPrefetch().each_depth_run_times
+
+
+def cal_profiling_step(num_micro_batches):
+ recompute_profiling_times = 4
+ min_profiling_steps = 5
+ recompute_profiling_steps = recompute_profiling_times // num_micro_batches
+ if recompute_profiling_times % num_micro_batches != 0:
+ recompute_profiling_steps += 1
+ return max(min_profiling_steps, recompute_profiling_steps)
+
+
+def init_profiling_steps():
+ num_micro_batches = get_num_microbatches()
+ # cal profiling step
+ recompute_profiling_steps = cal_profiling_step(num_micro_batches)
+ # cal swap profiling step
+ swap_profiling_steps = cal_swap_profiling_step(num_micro_batches)
+ # init step
+ AdaptiveStepMgr().init_steps(recompute_profiling_steps, swap_profiling_steps)
+ print_rank_0(f"init profiling steps, recompute:{recompute_profiling_steps}, swap:{swap_profiling_steps}")
+
+
+def update_swap_profiling_step_and_deep_list():
+ # update swap profiling step
+ swap_profiling_steps = cal_swap_profiling_step(get_num_microbatches())
+ # update deep_list
+ AdaptiveMemoryPrefetch().solve_prefetch_config()
+ AdaptiveStepMgr().init_steps(AdaptiveStepMgr().recompute_profiling_steps, swap_profiling_steps)
+ print_rank_0(f"update profiling steps, recompute:{AdaptiveStepMgr().recompute_profiling_steps}, swap:{swap_profiling_steps}, "
+ f"prefetch_deep_list:{AdaptiveMemoryPrefetch().prefetch_deep_list}, prefetch_hook_interval:{AdaptiveMemoryPrefetch().prefetch_hook_interval}")
+
+
+def setup_adapt_memory_optimizer_wrapper(setup_model_and_optimizer):
+ @wraps(setup_model_and_optimizer)
+ def wrapper(*args, **kwargs):
+ models, optimizer, opt_param_scheduler = setup_model_and_optimizer(*args, **kwargs)
+
+ optimizer.step = AdaptiveMemoryOpt().hook_adapt_mem_step(optimizer.step, models)
+ AdaptiveMemoryProfiling().construct_and_register_profiling_hooks(models)
+
+ init_profiling_steps()
+ register_custom_hooks(models)
+
+ AdaptiveMemoryPrefetch().solve_prefetch_config()
+ # 绑核
+ if "910B" in acl.get_soc_name() or "910A" in acl.get_soc_name():
+ bind_cpus(torch.cuda.device_count(), torch.cuda.current_device(), 0)
+ # 加载历史策略
+ PolicyCacheManager().load_cache_file()
+
+ return models, optimizer, opt_param_scheduler
+
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_policy.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fc3a4ad1e24b56431a5a62979e2e66aa5a99a3c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_policy.py
@@ -0,0 +1,185 @@
+import sys
+from copy import deepcopy
+
+import acl
+import torch
+from megatron.training import print_rank_0
+
+from .adaptive_memory_cache import AdaptiveLayerMemPolicy
+from .adaptive_memory_prefetch import AdaptiveMemoryPrefetch
+from .adaptive_memory_tool import AdaptiveStepMgr, SingletonBase, ModuleAction, LayerAction, ContextKey as Key
+
+
+class AdaptMemPolicyManager(metaclass=SingletonBase):
+
+ def __init__(self):
+ self.hccl_memory = 0
+
+ # policy combinations
+ self.policy_combinations = []
+ self.without_adapt_mem = 0.0
+ self.full_recompute_comb = None
+ self.full_swap_comb = None
+ self.without_adaptive_comb = None
+ # solve policy
+ self.adapt_modules_num = 0
+ self.total_adapt_memory = 0.0
+ self.module_layers_name = []
+ # adaptive prefetch
+ self.prefetch_parents_comb = []
+ self.memory_interval = 1
+
+ def prepare_policy(self, model_context):
+ self.traversal_model_context(model_context)
+ for comb in self.policy_combinations:
+ comb.memory = comb.memory + self.without_adapt_mem
+ # select policy that contains prefetch_parents_comb
+ self.select_policy_with_prefetch_parents_comb()
+
+ def traversal_model_context(self, context):
+ for layer_context in context.get(Key.SUBMODULES, []):
+ # 统计一下做自适应的总动态内存
+ if Key.IS_ADAPT_LAYER in layer_context and Key.MEMORY in context:
+ self.total_adapt_memory += context[Key.MEMORY]
+ if Key.ALLOWED_ADAPT in layer_context and Key.MEMORY in layer_context:
+ self.generate_full_combinations(layer_context, self.policy_combinations, "", 0, False)
+ return
+ else:
+ self.traversal_model_context(layer_context)
+
+ def generate_full_combinations(self, ctx, pre_policy_comb, pre_allow_adapt_n, idx, without_layer):
+ new_policy_comb = []
+ cycle_policy_comb = pre_policy_comb.copy()
+ if pre_allow_adapt_n:
+ ctx[Key.PREFIX_NAME] = remove_content_before_target(ctx[Key.PREFIX_NAME], pre_allow_adapt_n)
+ # root layers
+ if idx == 0:
+ self.build_initial_combinations(ctx)
+ pre_allow_adapt_n = ctx[Key.PREFIX_NAME] + '.'
+ self.prefetch_parents_comb = self.generate_prefetch_policy_combinations(pre_allow_adapt_n)
+ elif Key.MEMORY in ctx:
+ self.adapt_modules_num += 1
+ self.module_layers_name.append(ctx[Key.PREFIX_NAME] + "." + ctx[Key.NAME])
+ new_policy_comb.extend(self.build_combinations(ctx, pre_policy_comb, ModuleAction.SWAP, without_layer))
+ if ctx[Key.MEMORY] > ctx[Key.INPUT] + ctx[Key.OUTPUT] + self.memory_interval and not ctx[Key.IS_SWAP]:
+ new_policy_comb.extend(self.build_combinations(ctx, pre_policy_comb, ModuleAction.RECOMPUTE, without_layer))
+ if check_all_sub_same_mem(ctx):
+ same_mep_comb = pre_policy_comb.copy()
+ for sub in ctx.get(Key.SUBMODULES, []):
+ same_mep_comb = self.generate_full_combinations(sub, same_mep_comb, pre_allow_adapt_n, idx + 1, True)
+ new_policy_comb.extend(same_mep_comb)
+ cycle_policy_comb.extend(same_mep_comb)
+ else:
+ for sub in ctx.get(Key.SUBMODULES, []):
+ tmp_combs = self.generate_full_combinations(sub, cycle_policy_comb, pre_allow_adapt_n, idx + 1, False)
+ cycle_policy_comb.extend(tmp_combs)
+ new_policy_comb.extend(tmp_combs)
+ return new_policy_comb
+
+ def build_initial_combinations(self, context):
+ self.without_adapt_mem = context[Key.MEMORY]
+ self.full_recompute_comb = AdaptiveLayerMemPolicy(recompute=[context[Key.NAME]], swap=[],
+ memory=context[Key.INPUT] + context[Key.OUTPUT] - self.without_adapt_mem,
+ time=context[Key.AVG_TIME],
+ adapt_type=LayerAction.FULL_RECOMPUTE)
+ self.full_swap_comb = AdaptiveLayerMemPolicy(recompute=[], swap=[context[Key.NAME]],
+ memory=-context[Key.MODULE_SWAP_AVG_MEMORY],
+ time=context[Key.MODULE_SWAP_AVG_TIME], adapt_type=LayerAction.FULL_SWAP)
+ self.without_adaptive_comb = AdaptiveLayerMemPolicy(recompute=[], swap=[],
+ memory=0, time=0,
+ adapt_type=LayerAction.NONE)
+ self.policy_combinations.append(self.full_recompute_comb)
+ self.policy_combinations.append(self.without_adaptive_comb)
+
+ def generate_prefetch_policy_combinations(self, pre_allow_adapt_n):
+ prefetch_policy = AdaptiveLayerMemPolicy(time=0)
+ for module_name in AdaptiveMemoryPrefetch().need_swap_module_name:
+ suffix_name = remove_content_before_target(module_name, pre_allow_adapt_n)
+ prefetch_policy.swap.append(suffix_name)
+ return prefetch_policy
+
+
+ def build_combinations(self, context, pre_policy_combs, adapter_tag, without_cur_layer):
+ new_policy_combs = []
+ cur_policy_combs = pre_policy_combs.copy()
+ for policy_comb in cur_policy_combs:
+ new_policy_combs.append(self.build_one_combination(context, policy_comb, adapter_tag))
+ if without_cur_layer:
+ return new_policy_combs
+ single_policy_comb = self.build_one_combination(context, AdaptiveLayerMemPolicy(time=0), adapter_tag)
+ new_policy_combs.append(single_policy_comb)
+ return new_policy_combs
+
+ def build_one_combination(self, context, pre_policy_comb, adapter_tag):
+ layer_name = context[Key.PREFIX_NAME] + '.' + context[Key.NAME]
+ layer_list = pre_policy_comb.get_modules_by_tag(adapter_tag).copy()
+ policy_comb = AdaptiveLayerMemPolicy()
+ layer_list.append(layer_name)
+ if ModuleAction.RECOMPUTE == adapter_tag:
+ policy_comb.swap = pre_policy_comb.swap.copy()
+ policy_comb.recompute = layer_list
+ policy_comb.memory = pre_policy_comb.memory - context[Key.MEMORY] + context[Key.INPUT] + context[Key.OUTPUT]
+ policy_comb.time = pre_policy_comb.time + context[Key.AVG_TIME]
+ if ModuleAction.SWAP == adapter_tag:
+ policy_comb.recompute = pre_policy_comb.recompute.copy()
+ policy_comb.swap = layer_list
+ # if the module has swap information
+ if Key.MODULE_SWAP_AVG_MEMORY in context:
+ policy_comb.memory = pre_policy_comb.memory - context[Key.MODULE_SWAP_AVG_MEMORY]
+ if context[Key.IS_SWAP]:
+ # if swap doesn't waste time
+ policy_comb.time = pre_policy_comb.time
+ else:
+ policy_comb.time = pre_policy_comb.time + context[Key.MODULE_SWAP_AVG_TIME]
+ else:
+ policy_comb.memory = pre_policy_comb.memory
+ policy_comb.time = pre_policy_comb.time
+ self.policy_combinations.append(policy_comb)
+ return policy_comb
+
+ def select_policy_with_prefetch_parents_comb(self):
+ new_policy_comb = []
+ for policy_comb in self.policy_combinations:
+ if policy_comb.adapt_type != LayerAction.ADAPTIVE:
+ new_policy_comb.append(policy_comb)
+ elif self.is_contained_prefetch_parents_comb(self.prefetch_parents_comb.swap, policy_comb.swap):
+ new_policy_comb.append(policy_comb)
+ self.policy_combinations = new_policy_comb
+
+ def is_contained_prefetch_parents_comb(self, prefetch_parents_list, swap_list):
+ prefetch_parents_list_copy = prefetch_parents_list.copy()
+ prefetch_parents_list_copy.sort()
+ swap_list_copy = swap_list.copy()
+ swap_list_copy.sort()
+ return prefetch_parents_list_copy == swap_list_copy
+
+
+ def update_hccl_memory(self):
+ free, all_memory, _ = acl.rt.get_mem_info(1)
+ cur_hccl_memory = (all_memory - free - torch.npu.memory_reserved()) / 1024 / 1024
+ self.hccl_memory = max(cur_hccl_memory, self.hccl_memory)
+
+ def is_stable_mem_policy(self):
+ if not AdaptiveStepMgr().is_all_profiling_done():
+ return False
+ if not AdaptiveMemoryPrefetch().is_stable_apply:
+ return False
+ from .adaptive_memory_solver import AdaptMemGraphSolver
+ if not AdaptMemGraphSolver().is_stable_policy():
+ return False
+ return True
+
+
+def remove_content_before_target(path: str, prefix: str):
+ if path.startswith(prefix):
+ return path[len(prefix):]
+ else:
+ return path
+
+
+def check_all_sub_same_mem(context):
+ submodules = [child for child in context.get(Key.SUBMODULES, []) if Key.MEMORY in child]
+ for i in range(len(submodules) - 1):
+ if submodules[i][Key.MEMORY] != submodules[i + 1][Key.MEMORY]:
+ return False
+ return True
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_prefetch.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_prefetch.py
new file mode 100644
index 0000000000000000000000000000000000000000..f07c6846bfa7032230c1530dfcb60634c90beb72
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_prefetch.py
@@ -0,0 +1,438 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+import re
+import torch
+from megatron.training import print_rank_0, get_args
+from .adaptive_memory_tool import SingletonBase, FuncLocationMgr, broadcast_obj
+from .adaptive_memory_tool import AdaptiveStepMgr, ContextKey as Key
+from .adaptive_memory_swap_manager import SwapManager, transformer_layer_register_post_forward_hook, \
+ transformer_layer_register_pre_backward_hook, LayerProfilingHook
+
+
+class AdaptiveMemoryPrefetch(metaclass=SingletonBase):
+ def __init__(self):
+ self.modules_hooks = []
+ self.is_stable_apply = False
+ self.is_first_select_module = False
+ self.config = {
+ "pre_layer_full_name": "",
+ "cur_layer_name": "module",
+ }
+ self.chunk_num = 0
+ self.forward_time = 0
+ self.swap_time = 0
+ self.not_need_swap_module = []
+ self.need_swap_module_full_name = []
+ self.need_swap_module_name = []
+ self.need_swap_module_ctx = []
+ self.prefetch_module_dict = {}
+ self.abnormal_scenario_module_list = ["input_norm", "self_attention", "post_attention_norm"]
+ # 统计数据
+ self.prefetch_hook_interval = None
+ self.prefetch_deep_list = []
+ self.prefetch_deep_start = 0
+ self.prefetch_deep_end = 0
+ self.each_depth_run_times = 2
+ self.layer_list = []
+ self.swap_event_dict = {}
+ self.swap_memory_in_module_dict = {}
+ self.prefetch_module_event_dict = {}
+ # auto_function
+ self.function_swap_profiling_deep = 0
+ self.function_list = []
+ self.prefetch_function_list = []
+
+ def reset_prefetch_hooks(self):
+ SwapManager().reset_prefetch_hooked_modules()
+
+ def reset_module_hooks(self):
+ for hook_handle in self.modules_hooks:
+ hook_handle.remove()
+ self.modules_hooks.clear()
+
+ def reset_adaptive_prefetch_all_hooks(self):
+ self.reset_prefetch_hooks()
+ self.reset_module_hooks()
+ SwapManager().reset_post_layer_forward_and_pre_layer_backward_hooks()
+ LayerProfilingHook().reset_layer_profiling_hook()
+
+ def set_forward_time(self):
+ self.forward_time = SwapManager().forward_time
+
+ def _get_list_layers_context(self, ctx, idx):
+ current_ctx = {}
+ for k, v in ctx.items():
+ if k == Key.SUBMODULES:
+ current_ctx[k] = [v[idx]]
+ continue
+ current_ctx[k] = v
+ return current_ctx
+
+ def is_parent_module(self, key, keys):
+ if self.need_swap_module_name[key][-1] not in keys:
+ return True
+ else:
+ if not self.need_swap_module_name[self.need_swap_module_name[key][-1]][0]:
+ return True
+ else:
+ return False
+
+ # get prefetch config
+ def solve_prefetch_config(self):
+ self.prefetch_deep_list = [num for num in range(self.prefetch_deep_start, self.prefetch_deep_end + 1) for _ in range(self.each_depth_run_times)]
+ self.prefetch_hook_interval = len(self.prefetch_deep_list)
+ self.set_chunk_num()
+
+ def set_chunk_num(self):
+ all_args = get_args()
+ pp_size = all_args.pipeline_model_parallel_size or 1
+ vpp_size = all_args.virtual_pipeline_model_parallel_size or 1
+ num_prefetch = all_args.num_layers // pp_size
+ self.layer_list = [str(num) for num in range(0, num_prefetch)]
+ if vpp_size > 1:
+ if vpp_size <= num_prefetch:
+ self.chunk_num = vpp_size
+ else:
+ self.chunk_num = num_prefetch
+ else:
+ self.chunk_num = 1
+
+ def get_deep_index(self):
+ step = AdaptiveStepMgr().skip_steps + AdaptiveStepMgr().recompute_profiling_steps
+ return (AdaptiveStepMgr().get_cur_step() - step) % self.prefetch_hook_interval
+
+ # profiling for layer0
+ def prefetch_profiling_register(self, ctx, models, cur_layer_full_name):
+ if self.prefetch_deep_list[self.get_deep_index()] == ctx[Key.DEEP] and ctx.get(Key.IS_MODLUE_OF_LAYER0, False):
+ prefetch_register_forward_hook_for_recording_time(models, cur_layer_full_name)
+ prefetch_register_pre_forward_hook(models, cur_layer_full_name)
+ # register pack/unpack
+ print_rank_0(f"cur_step()={AdaptiveStepMgr().get_cur_step()}, is_recording=True, prefetch swap hook success: {cur_layer_full_name}")
+ SwapManager().hook_prefetch_forward(models, cur_layer_full_name)
+
+ if ctx.get(Key.IS_LAYER0_OF_MODULE0, False):
+ print_rank_0(f"cur_step()={AdaptiveStepMgr().get_cur_step()}, is_recording=True, prefetch forward and backward hook success: {cur_layer_full_name}")
+ prefetch_register_pre_forward_hook(models, cur_layer_full_name, True)
+ transformer_layer_register_post_forward_hook(models, True)
+ transformer_layer_register_pre_backward_hook(models)
+
+ def prefetch_profiling_register_for_function(self, ctx, cur_layer_full_name):
+ if self.prefetch_deep_list[self.get_deep_index()] == ctx[Key.DEEP]:
+ self.function_swap_profiling_deep = ctx[Key.DEEP]
+ print_rank_0(f"cur_step()={AdaptiveStepMgr().get_cur_step()}, {self.function_swap_profiling_deep=}, is_recording=True, prefetch swap hook success: {cur_layer_full_name}")
+
+ def prefetch_register(self, ctx, models, cur_layer_full_name):
+ if ctx.get(Key.IS_LAYER0_OF_MODULE0, False):
+ print_rank_0(f"is_recording=False, prefetch forward and backward hook success: cur_step()={AdaptiveStepMgr().get_cur_step()}, {cur_layer_full_name}")
+ transformer_layer_register_post_forward_hook(models)
+ transformer_layer_register_pre_backward_hook(models)
+ from .adaptive_memory_profiling import AdaptiveMemoryProfiling
+ LayerProfilingHook().apply_layer_profiling_hook(models)
+ if cur_layer_full_name in self.need_swap_module_name:
+ print_rank_0(f"is_recording=False, prefetch swap hook success: cur_step()={AdaptiveStepMgr().get_cur_step()}, {cur_layer_full_name}")
+ SwapManager().hook_prefetch_forward(models, cur_layer_full_name)
+ ctx[Key.IS_SWAP] = True
+ elif Key.AVG_TIME in ctx and Key.IS_MODLUE_OF_LAYER0 in ctx:
+ ctx[Key.IS_SWAP] = False
+
+ def prefetch_register_for_function(self, ctx, cur_layer_full_name):
+ if cur_layer_full_name in self.need_swap_module_name:
+ if ctx[Key.NAME] not in self.prefetch_function_list:
+ print_rank_0(f"is_recording=False, prefetch swap hook success: cur_step()={AdaptiveStepMgr().get_cur_step()}, {cur_layer_full_name}")
+ self.prefetch_function_list.append(ctx[Key.NAME])
+ ctx[Key.IS_SWAP] = True
+ else:
+ ctx[Key.IS_SWAP] = False
+
+
+ def register_recursive_apply_prefetch(self, config, models, ctx, is_prefetch_prof=True):
+ pre_layer_full_name = config["pre_layer_full_name"]
+ cur_layer_name = config["cur_layer_name"]
+ if cur_layer_name == Key.MODULE and isinstance(models, list):
+ idx = 0
+ for model in models:
+ if idx < self.chunk_num:
+ self.register_recursive_apply_prefetch(config, model, self._get_list_layers_context(ctx, idx), is_prefetch_prof)
+ idx += 1
+ return
+
+ # deal auto_function
+ if ctx.get(Key.IS_FUNCTION, False):
+ cur_layer_full_name = pre_layer_full_name + "." + ctx[Key.NAME]
+ if is_prefetch_prof:
+ # function profiling
+ self.prefetch_profiling_register_for_function(ctx, cur_layer_full_name)
+ else:
+ # function prefetch
+ self.prefetch_register_for_function(ctx, cur_layer_full_name)
+
+ config = {
+ "pre_layer_full_name": cur_layer_full_name,
+ "cur_layer_name": cur_layer_name,
+ }
+ self.register_recursive_apply_prefetch(config, models, ctx[Key.SUBMODULES][0], is_prefetch_prof)
+ return
+ cur_layer_full_name = pre_layer_full_name + '.' + cur_layer_name
+
+ if is_prefetch_prof:
+ self.prefetch_profiling_register(ctx, models, cur_layer_full_name)
+ else:
+ self.prefetch_register(ctx, models, cur_layer_full_name)
+
+ pre_layer_full_name = ctx[Key.PREFIX_NAME] + "." + ctx[Key.NAME]
+ idx = 0
+ for name, module in models.named_children():
+ config = {
+ "pre_layer_full_name": pre_layer_full_name,
+ "cur_layer_name": name,
+ }
+ self.register_recursive_apply_prefetch(config, module, ctx[Key.SUBMODULES][idx], is_prefetch_prof)
+ idx += 1
+
+ def _get_swappable_child_ctx(self, module_ctx):
+ res_ctxs, res_names = [], []
+ for child_ctx in module_ctx.get(Key.SUBMODULES, []):
+ if Key.AVG_TIME in child_ctx:
+ res_ctxs.append(child_ctx)
+ res_names.append(child_ctx[Key.PREFIX_NAME] + '.' + child_ctx[Key.NAME])
+ else:
+ sub_res_ctxs, sub_res_names = self._get_swappable_child_ctx(child_ctx)
+ res_ctxs.extend(sub_res_ctxs)
+ res_names.extend(sub_res_names)
+ return res_ctxs, res_names
+
+ def adjust_need_swap_module(self):
+ if len(self.need_swap_module_name) > 0:
+ last_module_ctx = self.need_swap_module_ctx.pop()
+ self.need_swap_module_name.pop()
+ child_module_ctxs, child_module_names = self._get_swappable_child_ctx(last_module_ctx)
+ self.need_swap_module_ctx.extend(child_module_ctxs)
+ self.need_swap_module_name.extend(child_module_names)
+
+ def is_no_module_to_swap(self):
+ return len(self.need_swap_module_name) == 0
+
+ def record_prefetch_time(self, context):
+ if len(list(self.prefetch_module_event_dict.keys())) == 0:
+ return
+ first_key = list(self.prefetch_module_event_dict.keys())[0]
+ if Key.PREFIX_NAME in context and Key.NAME in context and first_key == context[Key.PREFIX_NAME] + "." + context[Key.NAME]:
+ cur_event_list = self.prefetch_module_event_dict.pop(first_key)
+ for event_list in cur_event_list:
+ start, end = event_list[0], event_list[1]
+ cur_time = start.elapsed_time(end)
+ if Key.MODULE_FORWARD_TOTAL_TIME in context:
+ context[Key.MODULE_FORWARD_CNT] += 1
+ context[Key.MODULE_FORWARD_TOTAL_TIME] += cur_time
+ context[Key.MODULE_FORWARD_AVG_TIME] = context[Key.MODULE_FORWARD_TOTAL_TIME] / context[Key.MODULE_FORWARD_CNT]
+ else:
+ context[Key.MODULE_FORWARD_CNT] = 1
+ context[Key.MODULE_FORWARD_TOTAL_TIME] = cur_time
+ context[Key.MODULE_FORWARD_AVG_TIME] = cur_time
+ if Key.SUBMODULES not in context:
+ return
+ for submodule in context[Key.SUBMODULES]:
+ self.record_prefetch_time(submodule)
+
+ def record_swap_time(self, context):
+ if len(list(self.swap_event_dict.keys())) == 0:
+ return
+ first_key = list(self.swap_event_dict.keys())[0]
+ if Key.PREFIX_NAME in context and Key.NAME in context and first_key == context[Key.PREFIX_NAME] + "." + context[Key.NAME]:
+ cur_event_list = self.swap_event_dict.pop(first_key)
+ for event_list in cur_event_list:
+ start, end = event_list[0], event_list[1]
+ cur_time = start.elapsed_time(end)
+ if Key.MODULE_SWAP_TOTAL_TIME in context:
+ context[Key.MODULE_SWAP_CNT] += 1
+ context[Key.MODULE_SWAP_TOTAL_TIME] += cur_time
+ context[Key.MODULE_SWAP_AVG_TIME] = context[Key.MODULE_SWAP_TOTAL_TIME] / context[Key.MODULE_SWAP_CNT]
+ else:
+ context[Key.MODULE_SWAP_CNT] = 1
+ context[Key.MODULE_SWAP_TOTAL_TIME] = cur_time
+ context[Key.MODULE_SWAP_AVG_TIME] = cur_time
+ if Key.SUBMODULES not in context:
+ return
+ for submodule in context[Key.SUBMODULES]:
+ self.record_swap_time(submodule)
+
+ def record_swap_memory(self, context):
+ if len(list(self.swap_memory_in_module_dict.keys())) == 0:
+ return
+ first_key = list(self.swap_memory_in_module_dict.keys())[0]
+ if Key.PREFIX_NAME in context and Key.NAME in context and first_key == context[Key.PREFIX_NAME] + "." + context[Key.NAME]:
+ memory = self.swap_memory_in_module_dict.pop(first_key)
+ if Key.MODULE_SWAP_TOTAL_MEMORY in context:
+ context[Key.MODULE_SWAP_TOTAL_MEMORY] += memory
+ context[Key.MODULE_SWAP_AVG_MEMORY] = context[Key.MODULE_SWAP_TOTAL_MEMORY] / context[Key.MODULE_SWAP_CNT]
+ else:
+ context[Key.MODULE_SWAP_TOTAL_MEMORY] = memory
+ context[Key.MODULE_SWAP_AVG_MEMORY] = context[Key.MODULE_SWAP_TOTAL_MEMORY] / context[Key.MODULE_SWAP_CNT]
+ if Key.SUBMODULES not in context:
+ return
+ for submodule in context[Key.SUBMODULES]:
+ self.record_swap_memory(submodule)
+
+ def deal_not_need_swap_module(self, context):
+
+ if context.get(Key.IS_MODLUE_OF_LAYER0, False) and Key.IS_SWAP not in context:
+ context[Key.IS_SWAP] = False
+
+ if Key.IS_SWAP in context and not context[Key.IS_SWAP]:
+ self.not_need_swap_module.append(context[Key.PREFIX_NAME] + "." + context[Key.NAME])
+
+ if Key.SUBMODULES not in context:
+ return
+
+ for submodule in context[Key.SUBMODULES]:
+ self.deal_not_need_swap_module(submodule)
+
+ def clear_dict(self):
+ self.prefetch_module_event_dict.clear()
+ self.swap_event_dict.clear()
+ self.swap_memory_in_module_dict.clear()
+
+ def update_ctx(self, models, context):
+ if self.get_deep_index() % self.each_depth_run_times == 0:
+ self.record_prefetch_time(context)
+ self.record_swap_time(context)
+ self.record_swap_memory(context)
+ # 清除所有钩子
+ self.reset_adaptive_prefetch_all_hooks()
+ # 重新挂hook
+ if not AdaptiveStepMgr().is_swap_profiling_done():
+ self.register_recursive_apply_prefetch(self.config, models, context)
+ # 清空dict
+ self.clear_dict()
+
+ def init_swap_modules(self, context):
+ if Key.IS_LAYER0_OF_MODULE0 in context:
+ for child_ctx in context[Key.SUBMODULES]:
+ if Key.AVG_TIME in child_ctx:
+ self.need_swap_module_name.append(child_ctx[Key.PREFIX_NAME] + '.' + child_ctx[Key.NAME])
+ self.need_swap_module_ctx.append(child_ctx)
+ return
+ for child_ctx in context.get(Key.SUBMODULES, []):
+ self.init_swap_modules(child_ctx)
+
+ def adaptive_select_module(self, models, context):
+ if len(self.need_swap_module_name) == 0:
+ # 估计需要swap的module
+ self.set_forward_time()
+ self.init_swap_modules(context)
+ self.need_swap_module_name = broadcast_obj(self.need_swap_module_name)
+
+ if self.is_first_select_module and SwapManager().is_need_adjust_module():
+ # 微调swap module
+ print_rank_0(f"start adjust swap module, forward time is {LayerProfilingHook().get_single_layer_time()}")
+ self.adjust_need_swap_module()
+ if self.is_no_module_to_swap():
+ # 处理异常场景
+ self.is_stable_apply = True
+ elif self.is_first_select_module and not SwapManager().is_need_adjust_module():
+ print_rank_0(f"swap is stable, step={AdaptiveStepMgr().get_cur_step()}, "
+ f"forward time is {LayerProfilingHook().get_single_layer_time()}")
+ self.is_stable_apply = True
+
+ self.is_first_select_module = True
+ # 移除preftech的所有hook
+ self.reset_adaptive_prefetch_all_hooks()
+ # 重新挂preftech的钩子
+ self.register_recursive_apply_prefetch(self.config, models, context, False)
+ # 清空
+ self.clear_dict()
+ LayerProfilingHook().forward_time_list.clear()
+
+ def sync_d2h_for_recording_time(self, module_name, is_function=False):
+ # 每个module前向结束后插入end_event
+ module_forward_end_event = torch.npu.Event(enable_timing=True)
+ module_forward_end_event.record()
+ self.prefetch_module_event_dict[module_name][-1].append(module_forward_end_event)
+
+ torch.cuda.current_stream().wait_stream(SwapManager().prefetch_stream)
+ end_pack_event = None
+ if AdaptiveStepMgr().is_swap_profiling_step():
+ end_pack_event = torch.npu.Event(enable_timing=True)
+ end_pack_event.record()
+
+ for swap_tensor in SwapManager().swap_tensor_in_module:
+ # 更新每个tensor的pack_module_name
+ swap_tensor.pack_module_name = SwapManager().swap_tensors[-1].layer_name
+ if swap_tensor is SwapManager().swap_tensor_in_module[0]:
+ swap_tensor.first_tensor = True
+ swap_tensor.end_pack_event = end_pack_event
+
+ # record swap info
+ for swap_tensor in SwapManager().swap_tensor_in_module:
+ # cal tensor memory (MB)
+ tensor_memory = (swap_tensor.tensor.numel() * swap_tensor.tensor.element_size()) / (1024 * 1024)
+
+ if swap_tensor.pack_module_name == module_name:
+ self.recording_swap_momery_in_module(swap_tensor, swap_tensor.pack_module_name, tensor_memory)
+ self.recording_swap_time_in_module(swap_tensor, swap_tensor.pack_module_name, is_function)
+ else:
+ self.recording_swap_momery_in_module(swap_tensor, module_name, tensor_memory)
+ self.recording_swap_time_in_module(swap_tensor, module_name, is_function)
+
+ # reset swap_tensor_in_module
+ SwapManager().swap_tensor_in_module = []
+
+ def is_module_in_need_swap_module_name(self, module_name):
+ if module_name in self.need_swap_module_name:
+ return module_name
+ return None
+
+ # Records the memory swapped to the cpu in module
+ def recording_swap_momery_in_module(self, swap_tensor, key, tensor_memory):
+ has_key = key in AdaptiveMemoryPrefetch().swap_memory_in_module_dict.keys()
+ if not has_key:
+ AdaptiveMemoryPrefetch().swap_memory_in_module_dict[key] = tensor_memory
+ else:
+ if not swap_tensor.is_slice_tensor:
+ AdaptiveMemoryPrefetch().swap_memory_in_module_dict[key] += tensor_memory
+
+ # Records the time swapped to the cpu in module
+ def recording_swap_time_in_module(self, swap_tensor, key, is_function):
+ has_key = key in AdaptiveMemoryPrefetch().swap_event_dict.keys()
+ if not has_key and swap_tensor.first_tensor:
+ AdaptiveMemoryPrefetch().swap_event_dict[key] = [[swap_tensor.start_pack_event, swap_tensor.end_pack_event]]
+ elif has_key and swap_tensor.first_tensor:
+ if is_function:
+ AdaptiveMemoryPrefetch().swap_event_dict[key].append([swap_tensor.start_pack_event, swap_tensor.end_pack_event])
+ else:
+ AdaptiveMemoryPrefetch().swap_event_dict[swap_tensor.pack_module_name].append([swap_tensor.start_pack_event, swap_tensor.end_pack_event])
+
+
+def forward_post_hook_func_for_recording_time(module_name):
+ def custom_func(module, *args, **kargs):
+ AdaptiveMemoryPrefetch().sync_d2h_for_recording_time(module_name)
+
+ return custom_func
+
+
+def pre_forward_func(module_name, is_mark_first_layer):
+ if is_mark_first_layer:
+ FuncLocationMgr().is_first_layer = True
+ else:
+ module_forward_start_event = torch.npu.Event(enable_timing=True)
+ if module_name not in AdaptiveMemoryPrefetch().prefetch_module_event_dict.keys():
+ AdaptiveMemoryPrefetch().prefetch_module_event_dict[module_name] = [[module_forward_start_event]]
+ else:
+ AdaptiveMemoryPrefetch().prefetch_module_event_dict[module_name].append([module_forward_start_event])
+ module_forward_start_event.record()
+
+
+def pre_forward_hook_func(module_name, is_mark_first_layer):
+ def custom_func(module, *args, **kargs):
+ pre_forward_func(module_name, is_mark_first_layer)
+
+ return custom_func
+
+
+def prefetch_register_forward_hook_for_recording_time(module, name):
+ post_hook = module.register_forward_hook(forward_post_hook_func_for_recording_time(name))
+ AdaptiveMemoryPrefetch().modules_hooks.append(post_hook)
+
+
+def prefetch_register_pre_forward_hook(module, name, is_mark_first_layer=False):
+ pre_hook = module.register_forward_pre_hook(pre_forward_hook_func(name, is_mark_first_layer))
+ AdaptiveMemoryPrefetch().modules_hooks.append(pre_hook)
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_profiling.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_profiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..1299641c4547ab280017c914ecb8d4e4c36afd12
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_profiling.py
@@ -0,0 +1,327 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+
+from collections.abc import Iterable
+from copy import deepcopy
+
+import re
+import acl
+import torch
+import torch.nn
+
+from megatron.training import print_rank_0, get_args
+from megatron.core import tensor_parallel
+
+from .adaptive_memory_tool import AdaptiveStepMgr, BYTES_PER_MB, SingletonBase, ContextKey as Key
+from .adaptive_memory_tool import FuncLocationMgr
+from .adaptive_memory_prefetch import AdaptiveMemoryPrefetch
+
+
+class RecomputeHook(metaclass=SingletonBase):
+ def __init__(self):
+ self.recompute_modules = []
+
+ @staticmethod
+ def hook_checkpoint_forward(forward_func):
+ def custom_forward(*args, **kwargs):
+ def inside_forward(*new_args):
+ origin_args = new_args[:len(args)]
+ origin_kwargs = dict(zip(kwargs.keys(), new_args[len(args):]))
+ return forward_func(*origin_args, **origin_kwargs)
+ new_args = args + tuple(kwargs.values())
+ return tensor_parallel.checkpoint(inside_forward, False, *new_args)
+ return custom_forward
+
+ def reset_recompute_modules(self):
+ for m in self.recompute_modules:
+ m.forward = m.no_checkpoint_adaptive_recompute_forward
+ self.recompute_modules.clear()
+
+
+class AdaptiveMemoryProfiling(metaclass=SingletonBase):
+
+ def __init__(self):
+ # saved module data and structure
+ self.context = {'name': 'root', 'deep': 0, 'prefix_name': '', 'submodules': []}
+ # save modules hook
+ self.profiling_hooks = []
+ # record allowed memory adaptation module
+ self.allowed_adapt_module = []
+ # time events, used to calculate time
+ self.time_event_list = []
+ # save origin modules
+ self.checkpointed_modules = []
+ self.layer0_module = None
+ self.layer0_ctx = None
+
+ def addup_allowed_mem_adapt_profiling_module(self, module):
+ if not issubclass(module, torch.nn.Module):
+ raise TypeError("Allowed adapt module must be subclass of torch.nn.Module")
+ self.allowed_adapt_module.append(module)
+
+ @staticmethod
+ def _tag_module(ctx, current_ctx, current_is_adapt_module, upper_is_nn_module_list):
+ if current_is_adapt_module:
+ current_ctx[Key.ALLOWED_ADAPT] = True
+ if upper_is_nn_module_list:
+ ctx[Key.IS_MODULE_LIST] = True
+ ctx[Key.IS_ADAPT_LAYER] = True
+ else:
+ current_ctx[Key.IS_ADAPT_LAYER] = True
+
+ return False
+
+ return True
+
+ def record_time(self):
+ while self.time_event_list:
+ self._record_submodule_forward_time(self.context)
+
+ def update_whole_model_memory(self):
+ _, all_memory, _ = acl.rt.get_mem_info(1)
+ self.context[Key.USED_MEM] = torch.npu.memory_allocated() / BYTES_PER_MB
+ self.context[Key.DEVICE_MEMORY] = all_memory / BYTES_PER_MB
+
+ def reset_profiling_all_hooks(self):
+ self.reset_profiling_hooks()
+ self.reset_profiling_recompute_hook()
+
+ def reset_profiling_hooks(self):
+ for ph in self.profiling_hooks:
+ ph.remove()
+ self.profiling_hooks.clear()
+
+ def reset_profiling_recompute_hook(self):
+ for m in self.checkpointed_modules:
+ m.forward = m.no_checkpoint_forward
+ self.checkpointed_modules.clear()
+
+ def insert_func_profiling(self, ctx, child_name):
+ self._find_adapt_layer(self.context, ctx, child_name)
+
+ def _find_adapt_layer(self, ctx, new_ctx, child):
+ if ctx.get(Key.ALLOWED_ADAPT, False):
+ self._insert_ctx(ctx, new_ctx, child)
+ return
+ for sub in ctx.get(Key.SUBMODULES, []):
+ self._find_adapt_layer(sub, new_ctx, child)
+
+ @staticmethod
+ def _is_parent_child_relation(parent_ctx, child_ctx):
+ if parent_ctx[Key.DEEP] + 1 != child_ctx[Key.DEEP]:
+ return False
+
+ part1 = f"{parent_ctx[Key.PREFIX_NAME]}.{parent_ctx[Key.NAME]}".split(".")
+ part2 = child_ctx[Key.PREFIX_NAME].split(".")
+ if len(part1) != len(part2):
+ return False
+
+ # compare ctx parent cross chunks and layers, the prefix differ only with the index in torch.nn.ModuleList
+ def compare(p1, p2):
+ return re.sub(r'\d+$', '#', p1) == re.sub(r'\d+$', '#', p2)
+
+ return all(compare(x, y) for x, y in zip(part1, part2))
+
+ @staticmethod
+ def _clone_to_insert_ctx(parent_ctx, new_ctx):
+ cur_prefix_name = f"{parent_ctx[Key.PREFIX_NAME]}.{parent_ctx[Key.NAME]}"
+ to_insert_ctx = deepcopy(new_ctx)
+ if to_insert_ctx[Key.PREFIX_NAME] != cur_prefix_name:
+ to_insert_ctx[Key.PREFIX_NAME] = cur_prefix_name
+ del to_insert_ctx[Key.INPUT]
+ del to_insert_ctx[Key.MEMORY]
+ del to_insert_ctx[Key.PRE_TOTAL_TIME]
+ del to_insert_ctx[Key.OUTPUT]
+ del to_insert_ctx[Key.FORWARD_CNT]
+ del to_insert_ctx[Key.AVG_TIME]
+ del to_insert_ctx[Key.IS_MODLUE_OF_LAYER0]
+ return to_insert_ctx
+
+ def _insert_ctx(self, ctx, new_ctx, child_name):
+ if self._is_parent_child_relation(ctx, new_ctx):
+ to_insert_ctx = self._clone_to_insert_ctx(ctx, new_ctx)
+ if child_name:
+ idx = next(idx for idx, tmp in enumerate(ctx[Key.SUBMODULES]) if tmp[Key.NAME] == child_name)
+ child_ctx = ctx[Key.SUBMODULES][idx]
+ self._update_children_ctx(child_ctx, to_insert_ctx[Key.PREFIX_NAME], to_insert_ctx[Key.NAME])
+ to_insert_ctx[Key.SUBMODULES] = [child_ctx]
+ ctx[Key.SUBMODULES][idx] = to_insert_ctx
+ else:
+ siblings = ctx.get(Key.SUBMODULES, [])
+ siblings.append(to_insert_ctx)
+ ctx[Key.SUBMODULES] = siblings
+ return True
+
+ for sub in ctx.get(Key.SUBMODULES, []):
+ if self._insert_ctx(sub, new_ctx, child_name):
+ return True
+ return False
+
+ def _update_children_ctx(self, ctx, parent, func_name):
+ old_prefix_name = ctx[Key.PREFIX_NAME]
+ new_prefix_name = old_prefix_name[0:len(parent)] + "." + func_name + old_prefix_name[len(parent):]
+ ctx[Key.PREFIX_NAME] = new_prefix_name
+ ctx[Key.DEEP] += 1
+ AdaptiveMemoryPrefetch().prefetch_deep_end = max(AdaptiveMemoryPrefetch().prefetch_deep_end, ctx[Key.DEEP])
+
+ for sub in ctx.get(Key.SUBMODULES, []):
+ self._update_children_ctx(sub, parent, func_name)
+
+ def get_allowed_adapt_module(self):
+ return self.allowed_adapt_module
+
+ def is_layer0(self, ctx):
+ if ctx[Key.NAME] == "0" and "expert" not in ctx[Key.PREFIX_NAME]:
+ return True
+ return False
+
+ def forward_pre_hook(self, prefix, name, ctx):
+ """ Hook, which will be registered before the FWD to add context parameters and add timer start event """
+ def hook(module, *args, **kwargs):
+ FuncLocationMgr().push_name(prefix, name)
+ if Key.IS_LAYER0_OF_MODULE0 in ctx:
+ FuncLocationMgr().is_first_layer = True
+
+ if AdaptiveStepMgr().is_skipping_step():
+ return
+
+ if AdaptiveStepMgr().is_last_recompute_profiling_step():
+ ctx[Key.INPUT] = self.cal_input_output_size(args) / BYTES_PER_MB
+ mem_alloc = torch.npu.memory_allocated()
+ ctx[Key.MEMORY] = mem_alloc / BYTES_PER_MB - ctx[Key.INPUT]
+ else:
+ # 通过Key.MEMORY来判断此module是否被执行
+ ctx[Key.INPUT] = 0
+ ctx[Key.MEMORY] = 0
+
+
+ if AdaptiveStepMgr().is_recompute_profiling_step() and not AdaptiveStepMgr().is_last_recompute_profiling_step():
+ start_event = torch.npu.Event(enable_timing=True)
+ self.time_event_list.append([start_event])
+ start_event.record()
+
+ return hook
+
+ def forward_post_hook(self, prefix, name, ctx):
+ """ Hook, which will be registered in the FWD to calculate context parameters and add timer stop event """
+ def hook(module, args, output):
+ FuncLocationMgr().pop_name(prefix, name)
+ if Key.IS_LAYER0_OF_MODULE0 in ctx:
+ FuncLocationMgr().is_first_layer = False
+
+ if AdaptiveStepMgr().is_recompute_profiling_step() and not AdaptiveStepMgr().is_last_recompute_profiling_step():
+ end_event = torch.npu.Event(enable_timing=True)
+ end_event.record()
+ for item in reversed(self.time_event_list):
+ if len(item) == 1:
+ item.append(end_event)
+ break
+
+ if AdaptiveStepMgr().is_last_recompute_profiling_step():
+ ctx[Key.OUTPUT] = self.cal_input_output_size(output) / BYTES_PER_MB
+ ctx[Key.MEMORY] = torch.npu.memory_allocated() / BYTES_PER_MB - ctx[Key.MEMORY]
+
+ return hook
+
+ def construct_ctx_recursively(self, deep, prefix_name, model, ctx, allowed_adapting):
+ """ Function, recursively construct context to save profiling data in the future """
+ next_allowed_adapting = allowed_adapting
+ for name, module in model.named_children():
+ if Key.SUBMODULES not in ctx:
+ ctx[Key.SUBMODULES] = []
+ current_ctx = {Key.NAME: name, Key.DEEP: deep, Key.PREFIX_NAME: prefix_name}
+ ctx[Key.SUBMODULES].append(current_ctx)
+ if self.is_layer0(current_ctx):
+ AdaptiveMemoryPrefetch().prefetch_deep_start = current_ctx[Key.DEEP]
+ if current_ctx[Key.DEEP] > AdaptiveMemoryPrefetch().prefetch_deep_end:
+ AdaptiveMemoryPrefetch().prefetch_deep_end = current_ctx[Key.DEEP]
+ if allowed_adapting:
+ for allowed_adapt_module in self.allowed_adapt_module:
+ module_flag = isinstance(module, allowed_adapt_module)
+ model_flag = isinstance(model, torch.nn.ModuleList)
+ next_allowed_adapting = self._tag_module(ctx, current_ctx, module_flag, model_flag)
+ next_name = (prefix_name + '.' + name) if prefix_name != '' else name
+ next_deep = deep + 1
+ self.construct_ctx_recursively(next_deep, next_name, module, current_ctx, next_allowed_adapting)
+
+ def register_hook_recursively(self, model, ctx, in_first_module=False, in_first_layer=False, start_index=0):
+ """ Function, recursively register hooks to get profiling data on needed modules """
+ for module in model.children():
+ if Key.SUBMODULES not in ctx:
+ continue
+
+ current_ctx = ctx[Key.SUBMODULES][start_index]
+ name = current_ctx[Key.NAME]
+ prefix_name = current_ctx[Key.PREFIX_NAME]
+
+ # whole first module or in layer 0
+ if prefix_name in (Key.MODULE, Key.MODULE + '0') or in_first_layer:
+ if prefix_name not in (Key.MODULE, Key.MODULE + '0'):
+ current_ctx[Key.IS_MODLUE_OF_LAYER0] = True
+ self._register_hook(module, prefix_name, name, current_ctx)
+ self.register_hook_recursively(module, current_ctx, in_first_module, in_first_layer)
+ # whole layer 0
+ elif Key.ALLOWED_ADAPT in current_ctx and in_first_module and start_index == 0:
+ self.layer0_ctx = current_ctx
+ self.layer0_module = module
+ current_ctx[Key.IS_LAYER0_OF_MODULE0] = True
+ current_ctx[Key.IS_MODLUE_OF_LAYER0] = True
+ self._register_hook(module, prefix_name, name, current_ctx)
+ self.register_hook_recursively(module, current_ctx, in_first_module, True)
+ # encoder
+ elif isinstance(module, torch.nn.ModuleList) and Key.IS_ADAPT_LAYER in current_ctx and in_first_module:
+ self._register_hook(model, ctx[Key.PREFIX_NAME], ctx[Key.NAME], ctx)
+ self.register_hook_recursively(module, current_ctx, in_first_module, in_first_layer)
+ # recompute layer hook
+ elif Key.IS_MODULE_LIST in ctx and Key.ALLOWED_ADAPT in current_ctx:
+ module.no_checkpoint_forward = module.forward
+ module.forward = RecomputeHook().hook_checkpoint_forward(module.forward)
+ self.checkpointed_modules.append(module)
+ # do not hook, and check next one
+ else:
+ self.register_hook_recursively(module, current_ctx, in_first_module, in_first_layer)
+
+ start_index += 1
+
+ def cal_input_output_size(self, args):
+ size = 0
+ if isinstance(args, torch.Tensor):
+ size += args.numel() * args.element_size()
+ elif isinstance(args, Iterable):
+ for arg in args:
+ size += self.cal_input_output_size(arg)
+
+ return size
+
+ def _register_hook(self, module, prefix_name, name, current_ctx):
+ pre_hook = module.register_forward_pre_hook(self.forward_pre_hook(prefix_name, name, current_ctx))
+ post_hook = module.register_forward_hook(self.forward_post_hook(prefix_name, name, current_ctx))
+ self.profiling_hooks.append(pre_hook)
+ self.profiling_hooks.append(post_hook)
+
+ def _record_submodule_forward_time(self, context):
+ if Key.MEMORY in context and Key.IS_FUNCTION not in context:
+ cur_event_list = self.time_event_list.pop(0)
+ start, end = cur_event_list[0], cur_event_list[1]
+ cur_time = start.elapsed_time(end)
+ if Key.PRE_TOTAL_TIME in context:
+ context[Key.FORWARD_CNT] += 1
+ context[Key.PRE_TOTAL_TIME] += cur_time
+ context[Key.AVG_TIME] = context[Key.PRE_TOTAL_TIME] / context[Key.FORWARD_CNT]
+ else:
+ context[Key.FORWARD_CNT] = 1
+ context[Key.PRE_TOTAL_TIME] = cur_time
+ context[Key.AVG_TIME] = cur_time
+ if Key.SUBMODULES not in context:
+ return
+ for sub_layer in context[Key.SUBMODULES]:
+ self._record_submodule_forward_time(sub_layer)
+
+ def construct_and_register_profiling_hooks(self, models):
+ """ Function, used to construct and register hooks into first model to get profiling data in the future """
+ if isinstance(models, Iterable):
+ for idx, model in enumerate(models):
+ self.construct_ctx_recursively(1, Key.MODULE + str(idx), model, self.context, True)
+ self.register_hook_recursively(model, self.context, idx == 0, start_index=idx)
+ else:
+ self.construct_ctx_recursively(1, Key.MODULE, models, self.context, True)
+ self.register_hook_recursively(models, self.context)
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_solver.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bbf7f731ba169029617418a0036b636a8c1248e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_solver.py
@@ -0,0 +1,438 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+import sys
+import time
+import pickle
+from copy import deepcopy
+from typing import List
+
+import torch
+import numpy as np
+from megatron.training import print_rank_0, get_args
+from megatron.core.num_microbatches_calculator import get_num_microbatches
+from megatron.core import parallel_state as ps
+from .adaptive_memory_cache import AdaptiveModelMemPolicy, PolicyCacheManager
+from .adaptive_memory_policy import AdaptMemPolicyManager
+from .adaptive_memory_swap_manager import SwapManager
+from .adaptive_memory_tool import SingletonBase, LayerAction, ModuleAction, FuncLocation, ContextKey as Key, BYTES_PER_MB
+from .adaptive_memory_tool import AdaptiveStepMgr, broadcast_obj
+
+
+class AdaptMemGraphSolver(metaclass=SingletonBase):
+ def __init__(self):
+ self.num_warmup_bs_in_chunks = self.get_chunk_num_warmup_micro_batches()
+ self.adapt_mem_policy = {}
+ self.static_memory = 0
+ self.best_layer_policy_comb = []
+ self.func_locations: List[FuncLocation] = []
+ self.need_prepare_solver = True
+
+ self.device_memory = sys.maxsize
+ self.cur_adapt_policy = None
+ self.swap_size = 0
+ self.record_swap_out_size = 0
+ self.last_num_alloc_retries = torch.npu.memory_stats()["num_alloc_retries"]
+ self.remove_swap_manager_hook_step = 0
+ self.cur_device_memory = -1
+ self.flag_find_target_memory = False
+ self.first_non_oom_device_memory = 0
+ self.min_dichotomy_value = 1
+ self.dichotomy_memory_left = 0
+ self.dichotomy_memory_right = 0
+ self.alloc_retries_times = 0 # 记录当前策略alloc失败的次数
+ self.is_stable_for_non_oom_policy = 1 # 判断非oom的策略是否稳定==>1:稳定、0:不稳定
+
+ def prepare_solver(self, model_context):
+ self.need_prepare_solver = False
+ self.static_memory = self.get_static_mem(model_context)
+ self.dichotomy_memory_left = self.static_memory
+
+ @staticmethod
+ def get_chunk_num_warmup_micro_batches():
+ num_warmup_bs_in_chunks = []
+ pp = ps.get_pipeline_model_parallel_world_size()
+ vpp = ps.get_virtual_pipeline_model_parallel_world_size() or 1
+ pp_rank = ps.get_pipeline_model_parallel_rank()
+ num_micro_batches = get_num_microbatches()
+ if pp <= 1 or None in (num_micro_batches, pp_rank, vpp):
+ return [1]
+ elif vpp == 1:
+ num_warmup_bs = pp - pp_rank - 1
+ num_warmup_bs += 1
+ num_warmup_bs_in_chunks.append(num_warmup_bs)
+ else:
+ total_num_micro_batches = num_micro_batches * vpp
+ num_warmup_bs = (pp - pp_rank - 1) * 2
+ num_warmup_bs += (vpp - 1) * pp
+ num_warmup_bs += 1
+ num_warmup_bs = min(num_warmup_bs, total_num_micro_batches)
+ remain_batch_num = (num_warmup_bs - pp * vpp)
+ for i in range(vpp):
+ if i == 0:
+ num_warmup_bs_in_chunks.append(pp + max(0, remain_batch_num))
+ elif i == vpp - 1:
+ num_warmup_bs_in_chunks.append(pp + min(0, remain_batch_num))
+ else:
+ num_warmup_bs_in_chunks.append(pp)
+
+ print_rank_0(f"layer_num:{get_args().num_layers}")
+ print_rank_0(f"pp:{pp}")
+ print_rank_0(f"vpp:{vpp}")
+ print_rank_0(f"pp_rank:{pp_rank}")
+ print_rank_0(f"num_micro_batches:{num_micro_batches}")
+ print_rank_0(f"num_warmup_bs_in_chunks:{num_warmup_bs_in_chunks}")
+ print_rank_0(f"layer_num_per_ppstage:{get_args().num_layers // ps.get_pipeline_model_parallel_world_size()}")
+
+ return num_warmup_bs_in_chunks
+
+ @staticmethod
+ def tensor_all_reduce(num_list, op):
+ # all reduce the "num_list" between tp ranks in group
+ reduce_tensor = torch.tensor(num_list, device=torch.npu.current_device())
+ if ps.get_tensor_model_parallel_world_size() > 1:
+ torch.distributed.all_reduce(reduce_tensor, op=op, group=ps.get_tensor_model_parallel_group())
+ # all reduce the "num_list" between dp ranks in group
+ if ps.get_data_parallel_world_size(True) > 1:
+ torch.distributed.all_reduce(reduce_tensor, op=op, group=ps.get_data_parallel_group(True))
+ result = reduce_tensor.cpu().numpy().tolist()
+ del reduce_tensor
+ return result
+
+ def is_stable_policy(self):
+ if AdaptiveStepMgr().get_cur_step() > self.remove_swap_manager_hook_step != 0:
+ return True
+ total_swap_out_size = SwapManager().oom_rescue_total_swap_out_size
+ self.swap_size = (total_swap_out_size - self.record_swap_out_size) // BYTES_PER_MB
+ self.check_num_alloc_retries()
+ num_list = [
+ int(total_swap_out_size), int(AdaptMemPolicyManager().hccl_memory), int(self.swap_size),
+ int(self.flag_find_target_memory), int(self.alloc_retries_times)
+ ]
+ size_tensor = self.tensor_all_reduce(num_list, torch.distributed.ReduceOp.MAX)
+ total_swap_out_size = size_tensor[0]
+ AdaptMemPolicyManager().hccl_memory = size_tensor[1]
+ self.swap_size = size_tensor[2]
+ self.flag_find_target_memory = bool(size_tensor[3])
+ self.alloc_retries_times = size_tensor[4]
+ SwapManager().oom_rescue_total_swap_out_size = total_swap_out_size
+
+ if self.swap_size <= 0 and self.flag_find_target_memory:
+ return True
+ self.record_swap_out_size = total_swap_out_size
+ return False
+
+ def check_num_alloc_retries(self):
+ num_alloc_retries = torch.npu.memory_stats()["num_alloc_retries"]
+ # if policy is normal and stable
+ if num_alloc_retries == self.last_num_alloc_retries:
+ return
+ retries_times = num_alloc_retries - self.last_num_alloc_retries
+ self.last_num_alloc_retries = num_alloc_retries
+ # policy tag oom if policy is unstable
+ if self.swap_size == 0 and (retries_times > 1 or self.is_stable_for_non_oom_policy == 0):
+ self.swap_size = 1
+ # if policy is oom or unstable
+ if self.swap_size > 0:
+ return
+
+ self.alloc_retries_times += 1
+ if self.alloc_retries_times > 1:
+ print_rank_0("this is a unstable policy, try select another one.")
+ self.swap_size = 1
+
+ def reduce_device_memory(self, device_memory):
+ cur_min_memory = min(self.device_memory, device_memory)
+ self.device_memory, = self.tensor_all_reduce([int(cur_min_memory)], torch.distributed.ReduceOp.MIN)
+ print_rank_0(f"reduce device memory from {device_memory} to {self.device_memory}")
+
+ def check_cur_adapt_policy(self):
+ if not self.cur_adapt_policy:
+ return
+
+ policy_cache_manager = PolicyCacheManager()
+ flag_in_oom_list = policy_cache_manager.check_in_oom_cache(self.cur_adapt_policy)
+ flag_in_normal_list = policy_cache_manager.check_in_normal_cache(self.cur_adapt_policy)
+ if self.swap_size > 0:
+ if not flag_in_oom_list:
+ policy_cache_manager.add_oom_policy_cache(deepcopy(self.cur_adapt_policy))
+ if flag_in_normal_list:
+ policy_cache_manager.delete_normal_policy_cache(self.cur_adapt_policy)
+ return
+ if flag_in_oom_list or self.alloc_retries_times != 0:
+ return
+ if not flag_in_normal_list:
+ policy_cache_manager.add_normal_policy_cache(deepcopy(self.cur_adapt_policy))
+
+ def solve_adapt_mem_policy(self):
+ flag_is_known_policy = True
+ cur_step = AdaptiveStepMgr().get_cur_step()
+ self.remove_swap_manager_hook_step = cur_step + 1
+ adapt_policy_list = None
+ while flag_is_known_policy:
+ torch.npu.synchronize()
+ self.cur_device_memory = self.dichotomy_find_memory()
+ print_rank_0(f"cur_device_memory:{self.cur_device_memory}")
+ if self.is_stable_for_non_oom_policy != 0: # 对于不稳定的策略不生成新策略,使用旧策略再测试一遍
+ adapt_policy_list = self.get_mem_policy(self.cur_device_memory)
+ self.cur_adapt_policy = AdaptiveModelMemPolicy("normal", self.best_layer_policy_comb)
+ if self.flag_find_target_memory:
+ self.remove_swap_manager_hook_step = cur_step + 10
+ print_rank_0(
+ f"success to find the target value of the current round of search: {self.cur_device_memory}")
+ break
+ # OOM policy
+ policy_cache_manager = PolicyCacheManager()
+ if policy_cache_manager.check_in_oom_cache(self.cur_adapt_policy):
+ self.swap_size = max(self.swap_size, 1)
+ continue
+ # no OOM policy
+ if policy_cache_manager.check_in_normal_cache(self.cur_adapt_policy):
+ self.swap_size = 0
+ continue
+ flag_is_known_policy = False
+
+ return adapt_policy_list
+
+ def get_dichotomy_value(self):
+ return (self.dichotomy_memory_left + self.dichotomy_memory_right) // 2
+
+ def dichotomy_find_memory(self):
+ # last policy is instability
+ if self.flag_find_target_memory:
+ self.dichotomy_memory_left = self.first_non_oom_device_memory
+ self.dichotomy_memory_right = self.cur_device_memory
+ self.flag_find_target_memory = False
+ if self.cur_device_memory == -1:
+ return self.device_memory
+
+ # OOM
+ if self.swap_size > 0:
+ print_rank_0(f"current policy is OOM, policy device memory: {self.cur_device_memory}")
+ self.is_stable_for_non_oom_policy = 1
+ self.alloc_retries_times = 0
+ self.dichotomy_memory_right = self.cur_device_memory
+ if self.first_non_oom_device_memory >= self.cur_device_memory:
+ self.first_non_oom_device_memory = 0
+ if self.dichotomy_memory_right <= self.static_memory:
+ raise ValueError("out of Memory!!!!!!!!!!")
+ elif self.dichotomy_memory_right <= self.dichotomy_memory_left:
+ self.dichotomy_memory_left = self.static_memory
+ return self.get_dichotomy_value()
+
+ # check non oom policy
+ if self.alloc_retries_times != 0 and self.is_stable_for_non_oom_policy == 1:
+ print_rank_0(f"current policy may be an unstable, policy device memory: {self.cur_device_memory}")
+ self.is_stable_for_non_oom_policy = 0
+ self.alloc_retries_times = 0
+ return self.cur_device_memory
+
+ self.is_stable_for_non_oom_policy = 1
+ self.alloc_retries_times = 0
+ self.dichotomy_memory_left = self.cur_device_memory
+ if self.first_non_oom_device_memory == 0:
+ self.first_non_oom_device_memory = self.cur_device_memory
+ if self.dichotomy_memory_right - self.dichotomy_memory_left <= self.min_dichotomy_value:
+ self.flag_find_target_memory = True
+ return self.dichotomy_memory_left
+
+ return self.get_dichotomy_value()
+
+ @staticmethod
+ def get_pp_layer_num():
+ return get_args().num_layers // ps.get_pipeline_model_parallel_world_size()
+
+ @staticmethod
+ def get_layer_num_per_chunk():
+ vpp = ps.get_virtual_pipeline_model_parallel_world_size() or 1
+ return AdaptMemGraphSolver.get_pp_layer_num() // vpp
+
+ def get_static_mem(self, model_context):
+ single_chunk_memory = 0
+ num_of_chunk = len(model_context[Key.SUBMODULES])
+ if num_of_chunk > 0 and Key.MEMORY in model_context[Key.SUBMODULES][0]:
+ single_chunk_memory = model_context[Key.SUBMODULES][0][Key.MEMORY]
+ # 不能被节省的动态内存
+ mem_space_cannot_be_saved = (single_chunk_memory - AdaptMemPolicyManager().total_adapt_memory) * num_of_chunk
+ # 静态内存 = 模型总内存 + 不能被节省的动态内存
+ static_mem_size = model_context[Key.USED_MEM] + mem_space_cannot_be_saved
+ print_rank_0(f"static_memory:{static_mem_size}")
+ return static_mem_size
+
+ def get_mem_policy(self, device_memory):
+ print_rank_0("Using the knapsack algorithm to find the optimal strategy")
+ self.adapt_mem_policy.clear()
+ self.knapsack_best(device_memory)
+ adapt_mem_policy_list = self.get_adapt_mem_policy_list()
+ print_rank_0(f"adapt_mem_policy_list:{adapt_mem_policy_list}")
+ if torch.distributed.is_initialized():
+ # 把self.recompute_policy字典转换为recompute_policy_list列表,方便广播到其他卡上
+ adapt_mem_policy_list = broadcast_obj(adapt_mem_policy_list)
+ self.best_layer_policy_comb = broadcast_obj(self.best_layer_policy_comb)
+ return adapt_mem_policy_list
+
+ def get_max_goods_value(self, idx, ans, device_memory):
+ i, j, k = idx
+ pre_step_ans = ans[i - 1][j - k]
+ if k == 0:
+ return deepcopy(pre_step_ans)
+
+ goods_value = ans[i][j]
+ # calculate memory
+ memory = pre_step_ans.memory
+ pre_layer_num = len(pre_step_ans.polices)
+ for index in range(k):
+ cur_layer_index = pre_layer_num + index
+ cur_layer_chunk_rank = cur_layer_index // self.get_layer_num_per_chunk()
+ cur_layer_bs = self.num_warmup_bs_in_chunks[cur_layer_chunk_rank]
+ cur_layer_memory_cost = cur_layer_bs * AdaptMemPolicyManager().policy_combinations[i].memory
+ memory += cur_layer_memory_cost
+ # calculate cost
+ comb_time = pre_step_ans.time + k * AdaptMemPolicyManager().policy_combinations[i].time
+ # calculate device_memory
+ if pre_step_ans.time == sys.maxsize:
+ comb_time = k * AdaptMemPolicyManager().policy_combinations[i].time
+ max_free_memory = max(device_memory - self.static_memory, 0)
+
+ if max_free_memory >= memory and comb_time <= goods_value.time and (len(pre_step_ans.polices) + k) == j:
+ goods_value.memory = memory
+ goods_value.time = comb_time
+ goods_value.polices.clear()
+ goods_value.polices.extend(pre_step_ans.polices)
+ goods_value.polices.extend(AdaptMemPolicyManager().policy_combinations[i] for _ in range(k))
+
+ return goods_value
+
+ def add_func_locations(self, layer_idx, func_name, action):
+ self.func_locations.append(FuncLocation(layer_idx, func_name, action))
+
+ def get_cur_layer_idx(self, count):
+ pp = ps.get_pipeline_model_parallel_world_size()
+ vpp = ps.get_virtual_pipeline_model_parallel_world_size() or 1
+ total_layers = get_args().num_layers
+ if vpp > 1:
+ layers_per_chunk = total_layers // pp // vpp
+
+ # calc count belong to chunk and layer idx
+ remain = count % (pp * vpp * layers_per_chunk)
+ cur_chunk_idx = remain // (pp * layers_per_chunk) # 当前chunk id
+ cur_layer_idx = remain % (pp * layers_per_chunk) % layers_per_chunk # 当前layer在chunk内的id
+ global_layer_idx = cur_chunk_idx * layers_per_chunk + cur_layer_idx
+ return global_layer_idx
+ elif pp > 1:
+ layers_per_pp = total_layers // pp
+ global_layer_idx = count % layers_per_pp
+ return global_layer_idx
+ else:
+ global_layer_idx = count % total_layers
+ return global_layer_idx
+
+ def get_func_action(self, function_name, count) -> ModuleAction:
+ pp = ps.get_pipeline_model_parallel_world_size()
+ total_layers = get_args().num_layers
+ layers_per_pp = total_layers // pp
+
+ all_same_func_loc = [x for x in self.func_locations if x.func_name == function_name]
+ if len(all_same_func_loc) != layers_per_pp:
+ raise AssertionError("get_func_action error.")
+ global_layer_idx = self.get_cur_layer_idx(count)
+ if global_layer_idx != all_same_func_loc[global_layer_idx].layer_idx:
+ raise AssertionError("get_func_action error.")
+ return all_same_func_loc[global_layer_idx].action
+
+ def get_mem_layer_policy(self, combination_num, layer_num, ans):
+ apm = AdaptMemPolicyManager()
+ layer_full_recompute_memory = 0
+ for index in range(layer_num):
+ cur_layer_index = index
+ cur_layer_chunk_rank = cur_layer_index // self.get_layer_num_per_chunk()
+ cur_layer_memory_cost = self.num_warmup_bs_in_chunks[cur_layer_chunk_rank] * apm.full_recompute_comb.memory
+ layer_full_recompute_memory += cur_layer_memory_cost
+
+ layer_full_recompute_time = layer_num * apm.full_recompute_comb.time
+
+ self.best_layer_policy_comb = [apm.full_recompute_comb for _ in range(layer_num)]
+
+ size = layer_num - len(ans[combination_num][layer_num].polices)
+ pre_layer_num = len(ans[combination_num][layer_num].polices)
+ memory = ans[combination_num][layer_num].memory
+ for index in range(size):
+ cur_layer_index = pre_layer_num + index
+ cur_layer_chunk_rank = cur_layer_index // self.get_layer_num_per_chunk()
+ memory += self.num_warmup_bs_in_chunks[cur_layer_chunk_rank] * apm.full_recompute_comb.memory
+ comb_time = ans[combination_num][layer_num].time + size * apm.full_recompute_comb.time
+ best_policy_comb = deepcopy(ans[combination_num][layer_num].polices)
+ best_policy_comb.extend(size * [apm.full_recompute_comb])
+
+ if comb_time < layer_full_recompute_time:
+ self.best_layer_policy_comb.clear()
+ self.best_layer_policy_comb = best_policy_comb
+
+ print_rank_0(f"full_recompute_comb.time:{apm.full_recompute_comb.time}")
+ print_rank_0(f"full_recompute_comb.memory:{apm.full_recompute_comb.memory}")
+ print_rank_0(f"without_adaptive_comb.time:{apm.without_adaptive_comb.time}")
+ print_rank_0(f"without_adaptive_comb.memory:{apm.without_adaptive_comb.memory}")
+ print_rank_0(f"full_swap_comb.time:{apm.full_swap_comb.time}")
+ print_rank_0(f"full_swap_comb.memory:{apm.full_swap_comb.memory}")
+
+ for policy in self.best_layer_policy_comb:
+ policy_recompute = str(policy.recompute)
+ policy_swap = str(policy.swap)
+ if (policy_recompute, policy_swap) in self.adapt_mem_policy.keys():
+ self.adapt_mem_policy[policy_recompute, policy_swap] += 1
+ else:
+ self.adapt_mem_policy[policy_recompute, policy_swap] = 1
+ print_rank_0(f"adapt_mem_policy_dict:{self.adapt_mem_policy}")
+
+ def knapsack_best(self, device_memory):
+ start_time = time.time()
+ combination_num = len(AdaptMemPolicyManager().policy_combinations) - 1
+ if AdaptMemPolicyManager().policy_combinations[0] is not None:
+ combination_num = len(AdaptMemPolicyManager().policy_combinations)
+ # make combination index id begin for 1.
+ AdaptMemPolicyManager().policy_combinations.insert(0, None)
+ print_rank_0(f"combination_num:{combination_num}")
+
+ # init ans
+ def default_policy():
+ return AdaptiveModelMemPolicy("normal", [])
+
+ ans = [[default_policy() for _ in range(self.get_pp_layer_num() + 1)] for _ in range(combination_num + 1)]
+
+ # find max goods value
+ for i in range(1, combination_num + 1):
+ for j in range(self.get_pp_layer_num() + 1):
+ if i >= 2:
+ ans[i - 2][j].polices.clear()
+ for k in range(j + 1):
+ ans[i][j] = self.get_max_goods_value([i, j, k], ans, device_memory)
+ self.get_mem_layer_policy(combination_num, self.get_pp_layer_num(), ans)
+ end_time = time.time()
+ execution_time = end_time - start_time
+ print_rank_0(f"The execution time of the knapsack algorithm is {execution_time} seconds.")
+
+ def get_adapt_mem_policy_list(self):
+ adapt_mem_policy_list = []
+ apm = AdaptMemPolicyManager()
+ for key, times in self.adapt_mem_policy.items():
+ temp_adapt_mem_policy_list = [times]
+ key_recompute = eval(key[0])
+ key_swap = eval(key[1])
+ if key_recompute == apm.without_adaptive_comb.recompute and key_swap == apm.without_adaptive_comb.swap:
+ temp_adapt_mem_policy_list.append(LayerAction.NONE)
+ temp_adapt_mem_policy_list.extend([ModuleAction.NONE] * apm.adapt_modules_num)
+ elif key_recompute == apm.full_recompute_comb.recompute and key_swap == apm.full_recompute_comb.swap:
+ temp_adapt_mem_policy_list.append(LayerAction.FULL_RECOMPUTE)
+ temp_adapt_mem_policy_list.extend([ModuleAction.RECOMPUTE] * apm.adapt_modules_num)
+ elif key_recompute == apm.full_swap_comb.recompute and key_swap == apm.full_swap_comb.swap:
+ temp_adapt_mem_policy_list.append(LayerAction.FULL_SWAP)
+ temp_adapt_mem_policy_list.extend([ModuleAction.SWAP] * apm.adapt_modules_num)
+ else:
+ temp_adapt_mem_policy_list.append(LayerAction.ADAPTIVE)
+ for module_name in apm.module_layers_name:
+ if module_name in key_recompute:
+ temp_adapt_mem_policy_list.append(ModuleAction.RECOMPUTE)
+ elif module_name in key_swap:
+ temp_adapt_mem_policy_list.append(ModuleAction.SWAP)
+ else:
+ temp_adapt_mem_policy_list.append(ModuleAction.NONE)
+ adapt_mem_policy_list.append(temp_adapt_mem_policy_list)
+ return adapt_mem_policy_list
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_swap_manager.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_swap_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..66757425cd198c265a2e8c72d4957f7c028a67fa
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_swap_manager.py
@@ -0,0 +1,450 @@
+from enum import IntEnum
+from typing import List, Tuple
+
+import torch
+import torch_npu
+from numpy import mean
+from torch.cuda import Event
+from megatron.training import print_rank_0
+
+from .adaptive_memory_tool import SingletonBase, CpuTensorCache
+from .adaptive_memory_tool import FuncLocationMgr, broadcast_obj
+from .adaptive_memory_tool import AdaptiveStepMgr
+
+
+class SwappableTensorStat(IntEnum):
+ HOST = 0
+ DEVICE = 1
+ D2H = 2
+ H2D = 3
+
+
+class SwappableTensor:
+ def __init__(self, tensor, stream, is_prefetch=True):
+ self.stream = stream
+ self.tensor = tensor
+ self.h2d_event = None
+ self.is_prefetch = is_prefetch
+ self.tensor_cpu = None
+ self.storage_size = tensor.storage().size()
+ self.stat = SwappableTensorStat.DEVICE
+ self.data_ptr = tensor.data_ptr()
+ self.storage_data_ptr = tensor.storage().data_ptr()
+ self.is_slice_tensor = tensor.storage().size() != tensor.numel()
+ self.first_tensor = False
+ self.is_allowed_oom_rescue_swap = False
+ self.bro_tensors = None
+ self.cap_tensor = None # 和此tensor共享底层storage,并占用整个storage的tensor
+ # prefetch
+ self.start_pack_event = None
+ self.end_pack_event = None
+ self.layer_name = "" # 记录tensor在那个module被挂的hook
+ self.pack_module_name = None # 记录tensor在那个module被pack出去的
+ self.is_firt_same_ptr_tensor = True
+
+
+ def launch_d2h(self):
+ if self.stat != SwappableTensorStat.DEVICE:
+ return
+ forward_event = torch.npu.Event()
+ forward_event.record()
+ with torch.no_grad():
+ with torch_npu.npu.stream(self.stream):
+ self.stream.wait_event(forward_event)
+ if self.is_slice_tensor:
+ self.tensor_cpu.copy_(self.tensor, non_blocking=self.is_prefetch)
+ else:
+ self.tensor_cpu.storage().copy_(self.tensor.storage(), non_blocking=self.is_prefetch)
+ self.stat = SwappableTensorStat.D2H
+
+ def change_stat_to_host(self):
+ if self.stat != SwappableTensorStat.D2H:
+ return
+ self.stat = SwappableTensorStat.HOST
+
+ def launch_h2d(self):
+ if self.stat != SwappableTensorStat.HOST:
+ return
+ with torch.no_grad():
+ with torch_npu.npu.stream(self.stream):
+ if self.is_slice_tensor:
+ self.tensor.copy_(self.tensor_cpu, non_blocking=self.is_prefetch)
+ else:
+ self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=self.is_prefetch)
+ if self.h2d_event is not None:
+ self.h2d_event.record()
+ self.stat = SwappableTensorStat.H2D
+
+
+ def change_stat_to_device(self):
+ if self.stat != SwappableTensorStat.H2D:
+ return
+ self.stat = SwappableTensorStat.DEVICE
+
+
+
+class SwapManager(metaclass=SingletonBase):
+ def __init__(self):
+ self.swap_tensors = [] # 存储swap出去的tensor
+ self.cpu_tensors = {}
+ self.cpu_tensors_h2d_events = {}
+ self.prefetch_hooked_modules = []
+
+ self.oom_rescue_device_tensors = {}
+ self.oom_rescue_host_tensors = {}
+ self.oom_rescue_total_swap_out_size = 0
+ self.oom_rescue_hooked_modules = []
+
+ # recording
+ self.swap_tensor_in_module = []
+ self.layer_name = ""
+ self.post_layer_forward_and_pre_layer_backward_hooks = []
+ self.forward_time = 0
+
+ self.prefetch_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
+ self.oom_rescue_stream = torch_npu.npu.current_stream()
+
+ def get_mean_wait_ms(self, event_pairs):
+ time_list = []
+ for forward_time in event_pairs:
+ start, end = forward_time
+ cur_time = start.elapsed_time(end)
+ time_list.append(cur_time)
+ return mean(time_list)
+
+ def is_need_adjust_module(self, max_overhead_percentage=0.05):
+ result = (LayerProfilingHook().get_single_layer_time() - self.forward_time) / self.forward_time > max_overhead_percentage
+ result = broadcast_obj(result)
+ return result
+
+ def no_swap_tensor(self, ori_tensor):
+ if ori_tensor.numel() * ori_tensor.element_size() * 2 < 1024 * 1024:
+ return True
+ if ori_tensor.grad_fn is None:
+ return True
+ if ori_tensor.storage().size() == 0:
+ return True
+ ori_tensor_base = ori_tensor._base
+ if ori_tensor_base is not None and ori_tensor_base.dim() >= 5:
+ return True
+ if ori_tensor_base is not None and ori_tensor_base.grad_fn is None and ori_tensor_base.requires_grad:
+ return True
+ return False
+
+ def prefetch_pack(self, origin_tensor):
+ if self.no_swap_tensor(origin_tensor):
+ return origin_tensor
+ swap_tensor = SwappableTensor(origin_tensor, self.prefetch_stream)
+ if swap_tensor.is_slice_tensor:
+ swap_tensor.tensor_cpu = CpuTensorCache().get_cpu_tensor(origin_tensor.shape, origin_tensor.dtype)
+ swap_tensor.h2d_event = torch.npu.Event()
+ else:
+ if swap_tensor.storage_data_ptr not in self.cpu_tensors:
+ self.cpu_tensors[swap_tensor.storage_data_ptr] = CpuTensorCache().get_cpu_tensor(origin_tensor.shape, origin_tensor.dtype)
+ self.cpu_tensors_h2d_events[swap_tensor.storage_data_ptr] = torch.npu.Event()
+ swap_tensor.tensor_cpu = self.cpu_tensors[swap_tensor.storage_data_ptr]
+ swap_tensor.h2d_event = self.cpu_tensors_h2d_events[swap_tensor.storage_data_ptr]
+ else:
+ swap_tensor.tensor_cpu = self.cpu_tensors[swap_tensor.storage_data_ptr]
+ swap_tensor.h2d_event = self.cpu_tensors_h2d_events[swap_tensor.storage_data_ptr]
+ swap_tensor.stat = SwappableTensorStat.HOST
+ swap_tensor.layer_name = self.layer_name
+
+ # 在tensor开始pack的时候插入event
+ if AdaptiveStepMgr().is_swap_profiling_step():
+ start_pack_event = torch.npu.Event(enable_timing=True)
+ start_pack_event.record()
+ swap_tensor.start_pack_event = start_pack_event # 记录tensor开始swap的时间
+
+ swap_tensor.launch_d2h()
+ self.swap_tensors.append(swap_tensor)
+ if swap_tensor.stat == SwappableTensorStat.D2H:
+ self.swap_tensor_in_module.append(swap_tensor)
+ return swap_tensor
+
+ def prefetch_unpack(self, swap_tensor):
+ if isinstance(swap_tensor, torch.Tensor):
+ return swap_tensor
+
+ if swap_tensor.h2d_event:
+ torch.cuda.current_stream().wait_event(swap_tensor.h2d_event)
+ swap_tensor.change_stat_to_device()
+ CpuTensorCache().release_cpu_tensor(swap_tensor.tensor_cpu)
+ return swap_tensor.tensor
+
+ def _generate_prefetch_forward_hook(self, origin_forward, layer_name):
+ def custom_forward(*args, **kwargs):
+ self.layer_name = layer_name
+ with torch.autograd.graph.saved_tensors_hooks(self.prefetch_pack, self.prefetch_unpack):
+ return origin_forward(*args, **kwargs)
+ return custom_forward
+
+ def hook_prefetch_forward(self, module, layer_name):
+ module.no_prefetch_hook_forward = module.forward
+ self.prefetch_hooked_modules.append(module)
+ module.forward = self._generate_prefetch_forward_hook(module.forward, layer_name)
+
+ def reset_prefetch_hooked_modules(self):
+ for module in self.prefetch_hooked_modules:
+ module.forward = module.no_prefetch_hook_forward
+ self.prefetch_hooked_modules.clear()
+
+ def sync_d2h(self, layer_module, is_mark_first_layer):
+ if not self.swap_tensors:
+ return
+ # Wait until the prefetch is complete.
+ torch.cuda.current_stream().wait_stream(self.prefetch_stream)
+ storage_resized = set()
+ for swap_tensor in self.swap_tensors:
+ if swap_tensor.stat == SwappableTensorStat.D2H:
+ if swap_tensor.storage_data_ptr not in storage_resized:
+ swap_tensor.tensor.storage().resize_(0)
+ storage_resized.add(swap_tensor.storage_data_ptr)
+ swap_tensor.change_stat_to_host()
+
+ layer_module.microbatch_swap_tensors_queue.append(self.swap_tensors)
+ layer_module.microbatch_cpu_tensors_queue.append(self.cpu_tensors)
+
+ self.swap_tensors = []
+ self.cpu_tensors = {}
+ self.cpu_tensors_h2d_events = {}
+ self.swap_tensor_in_module = []
+ if is_mark_first_layer:
+ FuncLocationMgr().is_first_layer = False
+
+
+ def h2d(self, layer_module):
+ if not hasattr(layer_module, 'microbatch_swap_tensors_queue'):
+ return
+ if len(layer_module.microbatch_swap_tensors_queue) == 0 or len(layer_module.microbatch_swap_tensors_queue[-1]) == 0:
+ return
+ swap_tensors = layer_module.microbatch_swap_tensors_queue.pop(0)
+ cpu_tensors = layer_module.microbatch_cpu_tensors_queue.pop(0)
+ storage_resized = set()
+ self.prefetch_stream.wait_stream(torch.cuda.current_stream())
+ for swap_tensor in reversed(swap_tensors):
+ if swap_tensor.storage_data_ptr not in storage_resized:
+ swap_tensor.tensor.storage().resize_(swap_tensor.storage_size)
+ storage_resized.add(swap_tensor.storage_data_ptr)
+ if swap_tensor.storage_data_ptr in cpu_tensors:
+ cpu_tensors.pop(swap_tensor.storage_data_ptr)
+ elif not swap_tensor.is_slice_tensor:
+ swap_tensor.stat = SwappableTensorStat.DEVICE
+ swap_tensor.launch_h2d()
+
+ def change_oom_rescue_tensors_status_to_allowed_swap(self):
+ for wrapped_tensor in self.oom_rescue_device_tensors:
+ wrapped_tensor.is_allowed_oom_rescue_swap = True
+
+ def oom_rescue_pack(self, origin_tensor):
+ if self.no_swap_tensor(origin_tensor):
+ return origin_tensor
+ if origin_tensor.grad_fn is None:
+ return origin_tensor
+ wrapped_tensor = SwappableTensor(origin_tensor, self.oom_rescue_stream, is_prefetch=False)
+ self.oom_rescue_device_tensors[wrapped_tensor] = None
+ return wrapped_tensor
+
+ def oom_rescue_unpack(self, wrapped_tensor: SwappableTensor):
+ if isinstance(wrapped_tensor, torch.Tensor):
+ return wrapped_tensor
+ if wrapped_tensor in self.oom_rescue_host_tensors:
+ self.move_storage_in(wrapped_tensor)
+ self.oom_rescue_device_tensors.pop(wrapped_tensor)
+ wrapped_tensor.cap_tensor = None
+ if wrapped_tensor.bro_tensors is not None:
+ wrapped_tensor.bro_tensors.remove(wrapped_tensor)
+ wrapped_tensor.bro_tensors = None
+ return wrapped_tensor.tensor
+
+ def _generate_oom_rescue_forward_hook(self, origin_forward):
+ def custom_forward(*args, **kwargs):
+ with torch.autograd.graph.saved_tensors_hooks(self.oom_rescue_pack, self.oom_rescue_unpack):
+ return origin_forward(*args, **kwargs)
+ return custom_forward
+
+ def hook_oom_rescue_forward(self, module):
+ module.no_oom_rescue_hook_forward = module.forward
+ self.oom_rescue_hooked_modules.append(module)
+ module.forward = self._generate_oom_rescue_forward_hook(module.forward)
+
+ def reset_oom_rescue_hooked_modules(self):
+ for module in self.oom_rescue_hooked_modules:
+ module.forward = module.no_oom_rescue_hook_forward
+ self.oom_rescue_hooked_modules.clear()
+
+ def get_storage_cap_tensor(self, wrapped_tensor: SwappableTensor):
+ if wrapped_tensor.cap_tensor is not None:
+ return wrapped_tensor.cap_tensor
+ storage_tensor = torch.tensor([], dtype=wrapped_tensor.tensor.dtype, device=wrapped_tensor.tensor.device).set_(wrapped_tensor.tensor.storage())
+ wrapped_storage_tensor = SwappableTensor(storage_tensor, self.oom_rescue_stream, is_prefetch=False)
+ wrapped_storage_tensor.tensor_cpu = torch.empty(storage_tensor.shape, dtype=storage_tensor.dtype, pin_memory=True, device='cpu')
+ return wrapped_storage_tensor
+
+
+ def get_share_storage_tensors(self, wrapped_tensor: SwappableTensor):
+ result = set()
+ storage_data_ptr = wrapped_tensor.tensor.storage().data_ptr()
+ for wt in self.oom_rescue_device_tensors:
+ if wt.tensor.storage().data_ptr() == storage_data_ptr:
+ result.add(wt)
+ return result
+
+ def move_storage_out(self, wrapped_tensor: SwappableTensor):
+ if wrapped_tensor not in self.oom_rescue_device_tensors:
+ return 0, 0
+ storage_size = wrapped_tensor.storage_size * wrapped_tensor.tensor.element_size()
+ share_storage_tensors = wrapped_tensor.bro_tensors if wrapped_tensor.bro_tensors is not None else self.get_share_storage_tensors(wrapped_tensor)
+ cap_tensor = self.get_storage_cap_tensor(wrapped_tensor)
+ cap_tensor.launch_d2h()
+ cap_tensor.stat = SwappableTensorStat.HOST
+ for wt in share_storage_tensors:
+ wt.stat = SwappableTensorStat.HOST
+ wt.bro_tensors = share_storage_tensors
+ wt.cap_tensor = cap_tensor
+ self.oom_rescue_device_tensors.pop(wt)
+ self.oom_rescue_host_tensors[wt] = None
+ wrapped_tensor.tensor.storage().resize_(0)
+ return storage_size, len(share_storage_tensors)
+
+ def move_storage_in(self, wrapped_tensor: SwappableTensor):
+ wrapped_tensor.tensor.storage().resize_(wrapped_tensor.storage_size)
+ share_storage_tensors = wrapped_tensor.bro_tensors
+ wrapped_tensor.cap_tensor.launch_h2d()
+ wrapped_tensor.cap_tensor.stat = SwappableTensorStat.DEVICE
+ for wt in share_storage_tensors:
+ wt.stat = SwappableTensorStat.DEVICE
+ self.oom_rescue_host_tensors.pop(wt)
+ self.oom_rescue_device_tensors[wt] = None
+
+
+ def is_exist_tensor_allowed_swap(self):
+ for wt in self.oom_rescue_device_tensors:
+ if wt.is_allowed_oom_rescue_swap:
+ return True
+ return False
+
+ def is_exist_tensor_contiguous(self):
+ for wt in self.oom_rescue_device_tensors:
+ if wt.is_allowed_oom_rescue_swap and wt.tensor.is_contiguous():
+ return True
+ return False
+
+ def swap_out_by_size(self, size):
+ print_rank_0("Need size %d (%fMB)" % (size, size / 1024 / 1024))
+ if not self.is_exist_tensor_allowed_swap():
+ return False
+ swap_size = 0
+ swap_num = 0
+ only_swap_contiguous_tensor = self.is_exist_tensor_contiguous()
+ device_tensors = list(self.oom_rescue_device_tensors.keys())
+ for wrapped_tensor in device_tensors:
+ if swap_size >= size:
+ break
+ if not wrapped_tensor.is_allowed_oom_rescue_swap:
+ continue
+ if only_swap_contiguous_tensor and not wrapped_tensor.tensor.is_contiguous():
+ continue
+
+ storage_size, moved_tensor_count = self.move_storage_out(wrapped_tensor)
+ swap_size += storage_size
+ swap_num += moved_tensor_count
+
+ if swap_size != 0:
+ print_rank_0("swap tensor to CPU, tensor num: %d, release NPU memory size: %d (%fMB)" % (
+ swap_num, swap_size, swap_size / 1024 / 1024))
+ print_rank_0("tensor nums wrap manager for [device: %d, CPU: %d]" % (
+ len(self.oom_rescue_device_tensors), len(self.oom_rescue_host_tensors)))
+ self.oom_rescue_total_swap_out_size += swap_size
+ return True
+
+ def reset_oom_rescue_tensors(self):
+ self.oom_rescue_device_tensors.clear()
+ self.oom_rescue_host_tensors.clear()
+
+ def reset_all_for_oom_rescue(self):
+ self.reset_oom_rescue_tensors()
+ self.reset_oom_rescue_hooked_modules()
+
+ def reset_post_layer_forward_and_pre_layer_backward_hooks(self):
+ for hook_handle in self.post_layer_forward_and_pre_layer_backward_hooks:
+ hook_handle.remove()
+ self.post_layer_forward_and_pre_layer_backward_hooks.clear()
+
+
+def transformer_layer_register_post_forward_hook(module, is_mark_first_layer=False):
+ def post_forward_hook(module, *args, **kwargs):
+ if not torch.is_grad_enabled():
+ return
+ if not hasattr(module, 'microbatch_swap_tensors_queue'):
+ setattr(module, 'microbatch_swap_tensors_queue', [])
+ setattr(module, 'microbatch_cpu_tensors_queue', [])
+ SwapManager().sync_d2h(module, is_mark_first_layer)
+ SwapManager().change_oom_rescue_tensors_status_to_allowed_swap()
+ return
+
+ post_hook = module.register_forward_hook(post_forward_hook)
+ SwapManager().post_layer_forward_and_pre_layer_backward_hooks.append(post_hook)
+
+
+def transformer_layer_register_pre_backward_hook(module):
+ def post_forward_hook(module, args, output):
+ if not torch.is_grad_enabled():
+ return
+
+ def tensor_backward_hook(grad_output):
+ SwapManager().h2d(module)
+ if isinstance(output, tuple):
+ output = output[0]
+ output.register_hook(tensor_backward_hook)
+ pre_back_hook = module.register_forward_hook(post_forward_hook)
+ SwapManager().post_layer_forward_and_pre_layer_backward_hooks.append(pre_back_hook)
+
+
+class LayerProfilingHook(metaclass=SingletonBase):
+ def __init__(self):
+ self.hook_handles = []
+ self.forward_time_list = []
+ self.last_compute_forward_time = None
+
+ def _layer_register_pre_forward_hook(self, module):
+ def pre_forward_hook(module, args):
+ if AdaptiveStepMgr().is_layer_profiling_step() or AdaptiveStepMgr().is_all_profiling_done():
+ start_event = torch.npu.Event(enable_timing=True)
+ self.forward_time_list.append([start_event])
+ start_event.record()
+ else:
+ return
+ hook_handler = module.register_forward_pre_hook(pre_forward_hook)
+ self.hook_handles.append(hook_handler)
+
+
+ def _layer_register_post_forward_hook(self, module):
+ def post_forward_hook(module, args, output):
+ if AdaptiveStepMgr().is_layer_profiling_step() or AdaptiveStepMgr().is_all_profiling_done():
+ end_event = torch.npu.Event(enable_timing=True)
+ self.forward_time_list[-1].append(end_event)
+ end_event.record()
+ else:
+ return
+ hook_handler = module.register_forward_hook(post_forward_hook)
+ self.hook_handles.append(hook_handler)
+
+ def apply_layer_profiling_hook(self, module):
+ self._layer_register_pre_forward_hook(module)
+ self._layer_register_post_forward_hook(module)
+
+ def reset_layer_profiling_hook(self):
+ for hook_handler in self.hook_handles:
+ hook_handler.remove()
+ self.hook_handles.clear()
+
+ def get_single_layer_time(self):
+ total_time = 0
+ forward_cnt = len(self.forward_time_list)
+ for event_pair in self.forward_time_list:
+ start_event, end_event = event_pair
+ tmp_time = start_event.elapsed_time(end_event)
+ total_time += tmp_time
+ self.last_compute_forward_time = total_time / forward_cnt
+ return self.last_compute_forward_time
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_tool.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..85dc24740a3d6d176139f61f4f7995f0ede17958
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/adaptive_memory_tool.py
@@ -0,0 +1,219 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+
+from enum import Enum, IntEnum
+from typing import Set, Dict
+import pickle
+import torch
+from megatron.core import parallel_state as ps
+from megatron.training import print_rank_0
+
+BYTES_PER_MB = 1024 * 1024
+
+
+class LayerAction(IntEnum):
+ FULL_RECOMPUTE = 0
+ FULL_SWAP = 1
+ ADAPTIVE = 2
+ NONE = 3
+
+
+class ModuleAction(IntEnum):
+ RECOMPUTE = 0
+ SWAP = 1
+ NONE = 2
+
+
+class SingletonBase(type):
+ singleton_instances = {}
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls.singleton_instances:
+ instance = super().__call__(*args, **kwargs)
+ cls.singleton_instances[cls] = instance
+ return cls.singleton_instances[cls]
+
+
+class ContextKey(SingletonBase):
+ NAME = "name" # module name
+ DEEP = "deep" # module depth
+ PREFIX_NAME = "prefix_name" # module parent name
+ MODULE = "module" # module field
+ SUBMODULES = "submodules" # children modules
+ INPUT = "input" # input args total size
+ MEMORY = "memory" # current module's activation memory + input + output
+ OUTPUT = "output" # output total size
+ FORWARD_CNT = "forward_cnt" # forward called times
+ PRE_TOTAL_TIME = "pre_total_time" # forward called total time
+ AVG_TIME = "avg_time" # forward called avg time
+ ALLOWED_ADAPT = "allowed_adapt" # allowed adapted modules, user init and set name at startup
+ IS_FUNCTION = "is_function" # mark module if it is a torch.autograd.Function
+ IS_MODULE_LIST = "is_module_list" # mark module if it is a torch.nn.ModuleList
+ IS_ADAPT_LAYER = "is_adapt_layer" # mark self or parent if self is ALLOWED_ADAPT
+ USED_MEM = "used_mem" # model memory consumption, as allocated memory
+ DEVICE_MEMORY = "device_memory" # device total memory
+ # adaptive
+ MODULE_FORWARD_TOTAL_TIME = "module_forward_total_time" # module forward called total time
+ MODULE_FORWARD_AVG_TIME = "module_forward_avg_time" # module forward called avg time
+ MODULE_FORWARD_CNT = "module_forward_cnt" # module forward called times
+ MODULE_SWAP_TOTAL_TIME = "module_swap_total_time" # module forward called total swap time
+ MODULE_SWAP_AVG_TIME = "module_swap_avg_time" # module forward called avg swap time
+ MODULE_SWAP_CNT = "module_swap_cnt" # module forward called swap times
+ MODULE_SWAP_TOTAL_MEMORY = "module_swap_total_memory" # module forward swap total memory
+ MODULE_SWAP_AVG_MEMORY = "module_swap_avg_memory" # module forward swap avg memory
+ IS_SWAP = "is_swap" # mark module if it is swap
+ IS_LAYER0_OF_MODULE0 = "is_layer0_of_module0" # mark module if it is layer0 of module0
+ IS_MODLUE_OF_LAYER0 = "is_modlue_of_layer0" # mark module if it belongs to layer0 of module0
+
+
+class FuncLocationMgr(metaclass=SingletonBase):
+ def __init__(self):
+ self._module_names = []
+ self._function_in_stack = None
+ self._function_child = None
+ self.is_first_layer = False
+
+ def push_name(self, prefix, name):
+ self._module_names.append(f"{prefix}.{name}")
+ if self._function_in_stack and not self._function_child:
+ self._function_child = f"{prefix}.{name}"
+
+ def pop_name(self, prefix, name):
+ last_name = self._module_names.pop()
+ if f"{prefix}.{name}" != last_name:
+ raise ValueError(f"unexpected module name in stack, expect:{prefix}.{name}, find:{last_name}")
+
+ def get_latest_name(self):
+ return self._module_names[-1]
+
+ def set_function_in_stack(self):
+ self._function_in_stack = True
+
+ def get_function_location(self, parent):
+ if not self._function_child:
+ direct_child = ""
+ else:
+ first_child = self._function_child[len(parent):]
+ direct_child = first_child.split(".")[1]
+ self._function_child = None
+ self._function_in_stack = False
+ return direct_child
+
+
+class AdaptiveStepMgr(metaclass=SingletonBase):
+ def __init__(self):
+ self.cur_step = 1
+ self.skip_steps = 3
+ self.recompute_profiling_steps = 0
+ self.layer_profiling_steps = 5
+ self.swap_profiling_steps = 0
+ self.pre_steps = 0
+
+ def init_steps(self, recompute_profiling_steps, swap_profiling_steps):
+ self.recompute_profiling_steps = recompute_profiling_steps
+ self.swap_profiling_steps = swap_profiling_steps
+ self.pre_steps = self.skip_steps + recompute_profiling_steps + swap_profiling_steps + self.layer_profiling_steps
+
+ def get_cur_step(self):
+ return self.cur_step
+
+ def reset_step(self, step_num):
+ self.cur_step = step_num
+
+ def incr_step(self):
+ self.cur_step += 1
+
+ def is_skipping_step(self): # 两处调用,profiling时决定是否下发event,step里是否return
+ return self.cur_step <= self.skip_steps
+
+ def is_recompute_profiling_step(self):
+ pre_steps = self.skip_steps
+ return pre_steps < self.cur_step <= pre_steps + self.recompute_profiling_steps
+
+ def is_last_recompute_profiling_step(self):
+ return self.cur_step == (self.skip_steps + self.recompute_profiling_steps)
+
+ def is_layer_profiling_step(self):
+ pre_steps = self.skip_steps + self.recompute_profiling_steps
+ return pre_steps < self.cur_step <= pre_steps + self.layer_profiling_steps
+
+ def is_last_layer_profiling_step(self):
+ return self.cur_step == self.skip_steps + self.recompute_profiling_steps + self.layer_profiling_steps
+
+ def is_layer_profiling_done(self):
+ return self.cur_step >= self.skip_steps + self.recompute_profiling_steps + self.layer_profiling_steps
+
+ def is_all_profiling_done(self): # note: this called in step_func, should use > instead of >=
+ return self.cur_step > self.pre_steps
+
+ def is_swap_profiling_step(self):
+ pre_steps = self.skip_steps + self.recompute_profiling_steps + self.layer_profiling_steps
+ return pre_steps < self.cur_step <= self.pre_steps
+
+ def is_swap_profiling_done(self): # note: this called after step_func, should use >= instead of >
+ return self.cur_step >= self.pre_steps
+
+
+class ForwardCounter(metaclass=SingletonBase):
+ def __init__(self):
+ self._counter: int = 0
+
+ def get_count(self):
+ return self._counter
+
+ def incr_cnt(self):
+ self._counter += 1
+
+
+class FuncLocation:
+ def __init__(self, idx: int, func_name: str, action: ModuleAction):
+ self.layer_idx = idx
+ self.func_name = func_name
+ self.action = action
+
+
+class CpuTensorCache(metaclass=SingletonBase):
+ def __init__(self):
+ self.shape_to_tensor_list_map: Dict[(torch.Size, torch.dtype), Set[torch.Tensor]] = {}
+
+ def get_cpu_tensor(self, shape: torch.Size, dtype: torch.dtype):
+ tensor_set = self.shape_to_tensor_list_map.setdefault((shape, dtype), set())
+ if len(tensor_set) != 0:
+ cpu_tensor = tensor_set.pop()
+ else:
+ cpu_tensor = torch.empty(shape, dtype=dtype, pin_memory=True, device='cpu')
+ return cpu_tensor
+
+ def release_cpu_tensor(self, cpu_tensor):
+ tensor_set = self.shape_to_tensor_list_map.setdefault((cpu_tensor.shape, cpu_tensor.dtype), set())
+ tensor_set.add(cpu_tensor)
+
+
+def broadcast_in_mp_dp(tensor, src, mp, dp):
+ if mp > 1 and ps.get_tensor_model_parallel_src_rank() == src:
+ broadcast_src = ps.get_tensor_model_parallel_src_rank()
+ broadcast_group = ps.get_tensor_model_parallel_group()
+ torch.distributed.broadcast(tensor, src=broadcast_src, group=broadcast_group)
+ if dp > 1:
+ broadcast_src = ps.get_data_parallel_src_rank(True)
+ broadcast_group = ps.get_data_parallel_group(True)
+ torch.distributed.broadcast(tensor, src=broadcast_src, group=broadcast_group)
+
+
+def broadcast_obj(obj):
+ mp = ps.get_tensor_model_parallel_world_size()
+ dp = ps.get_data_parallel_world_size(True)
+
+ global_rank = torch.distributed.get_rank()
+ src = (global_rank // (mp * dp)) * dp * mp
+ obj_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
+ obj_shape_tensor = torch.tensor(obj_tensor.shape, device=torch.npu.current_device())
+ broadcast_in_mp_dp(obj_shape_tensor, src, mp, dp)
+ obj_len = obj_shape_tensor.cpu().tolist()
+ if global_rank == src:
+ obj_tensor_npu = obj_tensor.npu()
+ else:
+ obj_tensor_npu = torch.empty(obj_len, dtype=torch.uint8, device=torch.npu.current_device())
+ broadcast_in_mp_dp(obj_tensor_npu, src, mp, dp)
+ result = pickle.loads(obj_tensor_npu.cpu().numpy().tobytes())
+ del obj_tensor_npu
+ return result
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/cpu_binder.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/cpu_binder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b6708a4544171fe09d3acc62575f8fe6713a929
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_memory/cpu_binder.py
@@ -0,0 +1,109 @@
+import os
+import psutil
+from megatron.training import print_rank_0
+
+
+def _get_pcie_info(devices, keyword="PCIeBusInfo"):
+ device_pcie_tbl = dict()
+ for device in devices:
+ pcie_info = os.popen(f"npu-smi info -t board -i {device}").read().strip().split("\n")
+ for _ in pcie_info:
+ line = ''.join(_.split())
+ if line.startswith(keyword):
+ device_pcie_tbl[device] = line[len(keyword) + 1:]
+ break
+
+ return device_pcie_tbl
+
+
+def _get_numa_info(pcie_tbl, keyword="NUMAnode"):
+ device_numa_tbl = dict() # key is device id, value is numa id
+ numa_devices_tbl = dict() # key is numa id, value is device id list
+
+ for device, pcie_no in pcie_tbl.items():
+ numa_info = os.popen(f"lspci -s {pcie_no} -vvv").read().strip().split("\n")
+ for _ in numa_info:
+ line = ''.join(_.split())
+ if line.startswith(keyword):
+ numa_id = int(line[len(keyword) + 1:])
+ device_numa_tbl[device] = numa_id
+
+ devices = numa_devices_tbl.get(numa_id, None)
+ if devices is None:
+ numa_devices_tbl[numa_id] = list()
+
+ numa_devices_tbl[numa_id].append(device)
+ break
+
+ return device_numa_tbl, numa_devices_tbl
+
+
+def _get_cpu_info(numa_ids, keyword1="NUMAnode", keyword2="CPU(s)"):
+ cpu_idx_tbl = dict()
+ numa_keywords = [keyword1 + str(idx) + keyword2 for idx in numa_ids]
+ cpu_info = os.popen(f"lscpu").read().strip().split("\n")
+ for _ in cpu_info:
+ line = ''.join(_.split())
+ if any(line.startswith(word) for word in numa_keywords):
+ split_info = line.split(":")
+ cpu_id_ranges = split_info[-1].split(",")
+
+ ranges = list()
+ for range_str in cpu_id_ranges:
+ endpoints = range_str.split("-")
+ if len(endpoints) != 2:
+ raise Exception("lscpu command output error, please check !")
+
+ ranges += [cid for cid in range(int(endpoints[0]), int(endpoints[1]) + 1)]
+
+ numa_id = int(split_info[0].replace(keyword1, '').replace(keyword2, ''))
+ cpu_idx_tbl[numa_id] = ranges
+ return cpu_idx_tbl
+
+
+# 可以用export CPU_BINDING_NUM设置每个进程绑的核数;如果不设置CPU_BINDING_NUM,
+# 会根据ratio(numa利用率)进行计算,如果有64个核,0.5表示用一半,用32个核, 平分给亲和在这个numa上的npu
+def bind_cpus(world_size, rank_id, device_id, ratio=0.5):
+ devices = [_ for _ in range(device_id, device_id + world_size)]
+ # 获取npu和pcie的对应关系
+ device_pcie_tbl = _get_pcie_info(devices)
+ # 根据pcie信息获取npu和numa的对应关系
+ device_numa_tbl, numa_devices_tbl = _get_numa_info(device_pcie_tbl)
+ # 获取使用的numa对应的cpu核分配信息
+ cpu_idx_tbl = _get_cpu_info(list(numa_devices_tbl.keys()))
+
+ # 当前rank的npu id
+ cur_device = rank_id + device_id
+ # 获取npu对应的numa id
+ numa_id = device_numa_tbl[cur_device]
+
+ # 获取共享该numa的npu信息
+ shard_devices = numa_devices_tbl[numa_id]
+ # 按照npu id进行排序
+ shard_devices.sort()
+
+ # 获取该numa上所有的cpu id信息
+ all_cpus = cpu_idx_tbl[numa_id]
+
+ cpu_nums = len(all_cpus)
+ # 计算给该共享numa的npu分配的核的个数
+ CPU_BINDING_NUM = os.environ.get("CPU_BINDING_NUM", None)
+ if CPU_BINDING_NUM is None:
+ cpu_num_per_device = int(cpu_nums * ratio // len(shard_devices))
+ else:
+ cpu_num_per_device = int(CPU_BINDING_NUM)
+ if len(shard_devices) * cpu_num_per_device > cpu_nums:
+ raise Exception(
+ f"Cpu num in numa {numa_id} to assign {cpu_num_per_device} for every device is not enough, "
+ f"please decrease the value of CPU_BINDING_NUM!")
+
+ # 获取该npu的下标信息
+ idx = shard_devices.index(cur_device)
+ # 给该npu分配要绑定的cpu id
+ binding_cpus = [all_cpus[_] for _ in range(idx * cpu_num_per_device, (idx + 1) * cpu_num_per_device)]
+
+ # cpu bind
+ p = psutil.Process()
+ p.cpu_affinity(binding_cpus)
+ new_affinity = p.cpu_affinity()
+ print_rank_0("Bind cpu successful!!!")
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/__init__.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/adaptive_recompute.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/adaptive_recompute.py
new file mode 100644
index 0000000000000000000000000000000000000000..3450c824fb20600885fb7c2ad4f1c1c5d9b61efc
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/adaptive_recompute.py
@@ -0,0 +1,782 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+import sys
+from copy import deepcopy
+from functools import wraps
+from collections.abc import Iterable
+
+import acl
+import numpy as np
+import torch
+import torch.nn
+import torch_npu
+
+from megatron.training import print_rank_0
+from megatron.core.num_microbatches_calculator import get_num_microbatches
+from megatron.training import get_args
+from megatron.core import parallel_state
+
+from mindspeed.core.memory.adaptive_recomputing.adaptive_recompute_apply import get_recompute_hook
+from mindspeed.core.memory.adaptive_recomputing.adaptive_recompute_apply import get_swap_hook
+from mindspeed.core.memory.adaptive_recomputing.adaptive_recompute_apply import register_recursive_apply as apply_adaptive_recompute
+from mindspeed.core.memory.adaptive_recomputing.adaptive_recompute_apply import register_recursive_apply_prefetch as apply_prefetch_strategy
+from mindspeed.core.memory.adaptive_recomputing.adaptive_recompute_solver import get_graph_solver, GraphSolver
+from mindspeed.core.memory.adaptive_recomputing.swap_manager import SwapManager, get_tensor_mem_size
+
+DTYPE_NBYTES_MAP = {"bf16": 2, "fp16": 2, "fp32": 4}
+
+
+class AdaptiveRecomputePolicy:
+ adaptive_recomputing_policy = None
+
+ def __init__(self):
+ # total swap out size after OOM
+ self.record_swap_out_size = 0
+ # module context copy
+ self.context_copy = None
+ # unit for device memory size(MB)
+ self.unit_mb = 1024 * 1024
+ # find target device memory for policy
+ self.is_find_target_device_memory = False
+ # swap size for this step OOM
+ self.swap_size = 0
+
+ # policy
+ self.cur_recompute_policy = []
+ self.oom_recompute_policy_list = []
+ self.normal_recompute_policy_list = []
+
+ # device memory dichotomy for solve graph
+ self.device_memory_dichotomy_left = 0
+ self.device_memory_dichotomy_right = 0
+ self.cur_device_memory = -1
+ self.stop_dichotomy_value = 1
+
+ # device memory free default is maxsize
+ self.default_device_memory = sys.maxsize
+ all_args = get_args()
+ if all_args.adaptive_recompute_device_size >= 0:
+ self.default_device_memory = all_args.adaptive_recompute_device_size
+ self.hccl_memory = 0
+
+ self.remove_swap_manager_hook_step = 0
+ torch_npu.npu.init()
+ self.last_num_alloc_retries = torch.npu.memory_stats()["num_alloc_retries"]
+ self.change_num_alloc_retries_times = 0
+ self.first_non_oom_device_memory = 0
+ self.check_non_oom_times = 0
+
+ # swap_attention
+ self.interval = 0
+ self.threshold_prefetch = 0
+ self.num_prefetch = 0
+ self.num_layers = 0
+
+ @staticmethod
+ def tensor_all_reduce(num_list, op):
+ shard_tensor = torch.tensor(num_list, device=torch.npu.current_device())
+ if parallel_state.get_tensor_model_parallel_world_size() > 1:
+ torch.distributed.all_reduce(
+ shard_tensor,
+ op=op,
+ group=parallel_state.get_tensor_model_parallel_group(), )
+ if parallel_state.get_data_parallel_world_size() > 1:
+ torch.distributed.all_reduce(
+ shard_tensor,
+ op=op,
+ group=parallel_state.get_data_parallel_group(), )
+ result = shard_tensor.cpu().numpy().tolist()
+ del shard_tensor
+ return result
+
+ @staticmethod
+ def is_policy_in_list(policy, policy_list):
+ for p in policy_list:
+ if np.all(p == policy):
+ return True
+ return False
+
+
+ def is_stable_policy(self, profiling_step):
+ all_args = get_args()
+ # not activate swap function or remove swap manager hook
+ if not all_args.adaptive_recompute_device_swap or (profiling_step > self.remove_swap_manager_hook_step != 0):
+ return True
+
+ total_swap_out_size = SwapManager().total_swap_out_size
+ self.swap_size = (total_swap_out_size - self.record_swap_out_size) // self.unit_mb
+ self.check_num_alloc_retries()
+ num_list = [
+ int(total_swap_out_size), int(self.hccl_memory), int(self.swap_size),
+ int(self.is_find_target_device_memory), int(self.change_num_alloc_retries_times)
+ ]
+ size_tensor = self.tensor_all_reduce(num_list, torch.distributed.ReduceOp.MAX)
+ total_swap_out_size = size_tensor[0]
+ self.hccl_memory = size_tensor[1]
+ self.swap_size = size_tensor[2]
+ self.is_find_target_device_memory = bool(size_tensor[3])
+ self.change_num_alloc_retries_times = size_tensor[4]
+ SwapManager().total_swap_out_size = total_swap_out_size
+
+ if self.swap_size <= 0 and self.is_find_target_device_memory:
+ return True
+ self.record_swap_out_size = total_swap_out_size
+ return False
+
+ def get_default_device_memory(self, max_device_memory):
+ self.default_device_memory = min(self.default_device_memory, max_device_memory)
+ size_tensor = self.tensor_all_reduce([int(self.default_device_memory)], torch.distributed.ReduceOp.MIN)
+ self.default_device_memory = size_tensor[0]
+
+ def check_cur_recompute_policy(self):
+ if len(self.cur_recompute_policy) == 0:
+ return
+ is_exist_oom = self.is_policy_in_list(self.cur_recompute_policy, self.oom_recompute_policy_list)
+ is_exist_normal = self.is_policy_in_list(self.cur_recompute_policy, self.normal_recompute_policy_list)
+ if self.swap_size > 0:
+ if not is_exist_oom:
+ self.oom_recompute_policy_list.append(deepcopy(self.cur_recompute_policy))
+ if is_exist_normal:
+ self.normal_recompute_policy_list.remove(self.cur_recompute_policy)
+ return
+ if is_exist_oom or self.change_num_alloc_retries_times != 0:
+ return
+ if not is_exist_normal:
+ self.normal_recompute_policy_list.append(deepcopy(self.cur_recompute_policy))
+
+ def dichotomy_best(self):
+ # last policy is instability
+ if self.is_find_target_device_memory:
+ self.device_memory_dichotomy_left = self.first_non_oom_device_memory
+ self.device_memory_dichotomy_right = self.cur_device_memory
+ self.is_find_target_device_memory = False
+ if self.cur_device_memory == -1:
+ return self.default_device_memory
+
+ # OOM
+ if self.swap_size > 0:
+ self.check_non_oom_times = 0
+ self.change_num_alloc_retries_times = 0
+ self.device_memory_dichotomy_right = self.cur_device_memory
+ if self.first_non_oom_device_memory >= self.cur_device_memory:
+ self.first_non_oom_device_memory = 0
+ if self.device_memory_dichotomy_right <= self.device_memory_dichotomy_left:
+ self.device_memory_dichotomy_left = 0
+ return (self.device_memory_dichotomy_left + self.device_memory_dichotomy_right) // 2
+
+ # check non oom policy
+ if self.change_num_alloc_retries_times != 0 and self.check_non_oom_times == 0:
+ print_rank_0(f"current policy may be an unstable one, try to check it once again, "
+ f"policy device memory: {self.cur_device_memory}")
+ self.check_non_oom_times += 1
+ self.change_num_alloc_retries_times = 0
+ return self.cur_device_memory
+
+ self.check_non_oom_times = 0
+ self.change_num_alloc_retries_times = 0
+ self.device_memory_dichotomy_left = self.cur_device_memory
+ if self.first_non_oom_device_memory == 0:
+ self.first_non_oom_device_memory = self.cur_device_memory
+ if self.device_memory_dichotomy_right - self.device_memory_dichotomy_left <= self.stop_dichotomy_value:
+ self.is_find_target_device_memory = True
+ return self.device_memory_dichotomy_left
+
+ return (self.device_memory_dichotomy_left + self.device_memory_dichotomy_right) // 2
+
+ def solve_recompute_policy(self, profiling_step):
+ is_known_policy = True
+ self.remove_swap_manager_hook_step = profiling_step + 1
+ swap_size = self.swap_size
+ recompute_policy_list = None
+ while is_known_policy:
+ torch.npu.synchronize()
+ self.cur_device_memory = self.dichotomy_best()
+ if self.check_non_oom_times == 0:
+ recompute_policy_list = get_graph_solver().get_policy(self.cur_device_memory)
+ np_result = np.array(recompute_policy_list)
+ self.cur_recompute_policy = np.array([r * r[0] for r in np_result]).sum(axis=0).tolist()
+ if self.is_find_target_device_memory:
+ self.remove_swap_manager_hook_step = profiling_step + 10
+ print_rank_0(
+ f"success to find the target value of the current round of search: {self.cur_device_memory}")
+ break
+ # OOM policy
+ if self.is_policy_in_list(self.cur_recompute_policy, self.oom_recompute_policy_list):
+ self.swap_size = max(self.swap_size, 1)
+ continue
+ # no OOM policy
+ if self.is_policy_in_list(self.cur_recompute_policy, self.normal_recompute_policy_list):
+ self.swap_size = 0
+ continue
+ is_known_policy = False
+ if recompute_policy_list is None:
+ print_rank_0(f"{get_graph_solver().final_policy_info}")
+ return None
+ get_graph_solver().print_list_to_policy(recompute_policy_list)
+ print_rank_0(
+ f"max available memory: {self.context_copy['max_device_memory']}, previous policy swap size: {swap_size}, "
+ f"next policy device memory: {self.cur_device_memory}")
+ print_rank_0(f"{get_graph_solver().without_recompute_info}\n{get_graph_solver().all_recompute_info}\n"
+ f"{get_graph_solver().selective_recompute_info}\n{get_graph_solver().final_policy_info}")
+ return self.set_tag_to_context(recompute_policy_list)
+
+ def set_tag_to_context(self, recompute_policy_list):
+ context = deepcopy(self.context_copy)
+ solver = GraphSolver()
+ solver.layer_full_recompute_combination = get_graph_solver().layer_full_recompute_combination
+ solver.layer_without_recompute_combination = get_graph_solver().layer_without_recompute_combination
+ solver.layer_recompute_one_combination = get_graph_solver().layer_recompute_one_combination
+ solver.layers_combination = get_graph_solver().layers_combination
+ solver.get_layers_module(context, "")
+ solver.get_no_recompute_layer()
+ solver.apply_policy_to_model(recompute_policy_list)
+ return context
+
+ def check_num_alloc_retries(self):
+ num_alloc_retries = torch.npu.memory_stats()["num_alloc_retries"]
+ if num_alloc_retries == self.last_num_alloc_retries:
+ return
+ retries_times = num_alloc_retries - self.last_num_alloc_retries
+ self.last_num_alloc_retries = num_alloc_retries
+ if self.swap_size == 0 and (retries_times > 1 or self.check_non_oom_times != 0):
+ self.swap_size = 1
+ if self.swap_size > 0:
+ return
+
+ self.change_num_alloc_retries_times += 1
+ if self.change_num_alloc_retries_times > 1:
+ print_rank_0(f"[^?^?^] this is a unstable policy, try select another one.")
+ self.swap_size = 1
+
+ def granular_module_allocation(self, vpp_size, recompute_num_layers, cur_pp_noop_layers):
+ swap_list = []
+ recompute_list = []
+ args = get_args()
+ cur_pp_rank = parallel_state.get_pipeline_model_parallel_rank()
+ pp_size = args.pipeline_model_parallel_size or 1
+ vpp_layer = args.num_layers_per_virtual_pipeline_stage
+ if self.num_prefetch <= vpp_size:
+ swap_list = [['0'] if i < self.num_prefetch else [''] for i in range(vpp_size)]
+ else:
+ for chunk in range(vpp_size):
+ chunk_swap_layer = ['0']
+ for layer_id in range(vpp_size, self.num_prefetch):
+ if layer_id % vpp_size == chunk:
+ chunk_swap_layer.append(f'{layer_id // vpp_size}')
+ swap_list.append(chunk_swap_layer)
+
+ if recompute_num_layers <= vpp_size:
+ recompute_list = [['0'] if i < recompute_num_layers else [''] for i in range(vpp_size)]
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True) and args.reduce_recompute_for_last_chunk:
+ recompute_list[-1] = ['']
+ else:
+ for chunk in range(vpp_size):
+ chunk_recompute_layer = ['0']
+ for layer_id in range(vpp_size, recompute_num_layers):
+ if layer_id % vpp_size == chunk:
+ chunk_recompute_layer.append(f'{layer_id // vpp_size}')
+ recompute_list.append(chunk_recompute_layer)
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True) and args.reduce_recompute_for_last_chunk:
+ if recompute_list[-1][-1] == str(args.num_layers_per_virtual_pipeline_stage - 1):
+ recompute_list[-1].pop()
+ if len(recompute_list[-1]) == 0:
+ recompute_list[-1].append('')
+ for vpp in range(vpp_size):
+ vpp_layers = swap_list[vpp]
+ for i in range(len(vpp_layers)):
+ layer_id = vpp * vpp_layer * pp_size + i + vpp_layer * cur_pp_rank
+ if layer_id in cur_pp_noop_layers:
+ swap_list[vpp][i] = ''
+ if len(recompute_list[vpp]) >= i + 1:
+ recompute_list[vpp][i] = ''
+
+ prefetch_list = swap_list
+ interval = 0
+ prefetch_recompute_group = [swap_list, prefetch_list, recompute_list]
+ return [prefetch_recompute_group, interval, self.num_prefetch, cur_pp_noop_layers]
+
+ def get_cur_stage_noop_layers(self, noop_layers, cur_pp_rank):
+ all_args = get_args()
+ cur_pp_noop_layers = []
+ pp_size = all_args.pipeline_model_parallel_size or 1
+ layers_per_pp = all_args.num_layers // pp_size
+ vpp_layer = all_args.num_layers_per_virtual_pipeline_stage or layers_per_pp
+ vpp_layers = vpp_layer * pp_size
+ for i in noop_layers:
+ pp_id = (i % vpp_layers) // vpp_layer
+ if pp_id == cur_pp_rank:
+ cur_pp_noop_layers.append(i)
+ return cur_pp_noop_layers
+
+ def solve_prefetch_policy(self):
+ all_args = get_args()
+ noop_layers = list(all_args.noop_layers) if isinstance(all_args.noop_layers, set) else []
+ cur_pp_rank = parallel_state.get_pipeline_model_parallel_rank()
+ cur_pp_noop_layers = self.get_cur_stage_noop_layers(noop_layers, cur_pp_rank)
+ recompute_num_layers = all_args.recompute_num_layers or 0
+ pp_size = all_args.pipeline_model_parallel_size or 1
+ vpp_size = all_args.virtual_pipeline_model_parallel_size or 1
+ per_pp_layers = all_args.num_layers // pp_size
+ per_vpp_layers = all_args.num_layers_per_virtual_pipeline_stage or per_pp_layers
+ if not all_args.enable_recompute_layers_per_pp_rank:
+ if recompute_num_layers >= per_vpp_layers:
+ recompute_num_layers = per_pp_layers
+ else:
+ recompute_num_layers *= vpp_size
+ else:
+ if recompute_num_layers >= per_pp_layers:
+ recompute_num_layers = per_pp_layers
+ if all_args.recompute_method == 'block':
+ self.num_prefetch = recompute_num_layers
+ elif all_args.recompute_method == 'uniform':
+ recompute_num_layers = per_pp_layers
+ self.num_prefetch = recompute_num_layers
+ else:
+ self.num_prefetch = per_pp_layers
+ self.interval = 0
+ if vpp_size > 1:
+ return self.granular_module_allocation(vpp_size, recompute_num_layers, cur_pp_noop_layers)
+ else:
+ swap_list, recompute_list = [], []
+ for i in range(self.num_prefetch):
+ if i + cur_pp_rank * per_pp_layers not in cur_pp_noop_layers:
+ swap_list.append(str(i))
+ else:
+ swap_list.append('')
+ for i in range(recompute_num_layers):
+ if i + cur_pp_rank * per_pp_layers not in cur_pp_noop_layers:
+ recompute_list.append(str(i))
+ else:
+ recompute_list.append('')
+
+ prefetch_list = swap_list
+ prefetch_recompute_group = [[swap_list], [prefetch_list], [recompute_list]]
+ return [prefetch_recompute_group, 0, len(prefetch_list), cur_pp_noop_layers]
+
+
+def get_adaptive_recomputing_policy():
+ if AdaptiveRecomputePolicy.adaptive_recomputing_policy is None:
+ AdaptiveRecomputePolicy.adaptive_recomputing_policy = AdaptiveRecomputePolicy()
+ return AdaptiveRecomputePolicy.adaptive_recomputing_policy
+
+
+class AdaptiveRecompute:
+ adaptive_recomputing = None
+
+ def __init__(self):
+ # layer profiling info
+ self.context = {
+ 'module': []
+ }
+ #record allowed recomputing module
+ self.allowed_recomputing_module = []
+ # profiling prefix
+ self.profiling_prefix = ""
+ # save origin modules
+ self.checkpointed_modules = []
+ # save modules hook, remove it after apply policy
+ self.modules_hooks = []
+ # current profiling step
+ self.profiling_step = 0
+ # step for stop profiling, default is 10
+ self.stop_profiling_step = 10
+ # skip step for profiling
+ self.skip_profiling_step = 3
+ # step for solve graph by adaptive recompute, after step for stop profiling
+ self.solve_graph_at_step = 11
+ # unit for device memory size(MB)
+ self.unit_mb = 1024 * 1024
+ # pp or vpp
+ self.num_warmup_micro_batches = 1
+ # store all module event
+ self.event_list = []
+
+ @staticmethod
+ def get_memory_status():
+ free, all_memory, _ = acl.rt.get_mem_info(1)
+ memory_info = {
+ "free": free,
+ "all_memory": all_memory,
+ "used_memory": torch.npu.memory_allocated(),
+ "reserved_memory": torch.npu.memory_reserved(),
+ "max_memory_allocated": torch.npu.max_memory_allocated()
+ }
+
+ return memory_info
+
+ def get_num_warmup_micro_batches(self, num_model_chunks):
+ if parallel_state.get_pipeline_model_parallel_world_size() <= 1:
+ return
+ num_microbatches = get_num_microbatches()
+ pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
+ pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
+ total_num_micro_batches = num_microbatches * num_model_chunks
+ if num_model_chunks == 1:
+ num_warmup_micro_batches = pipeline_parallel_size - pipeline_parallel_rank - 1
+ else:
+ num_warmup_micro_batches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
+ num_warmup_micro_batches += (num_model_chunks - 1) * pipeline_parallel_size
+ num_warmup_micro_batches += 1
+ if num_model_chunks >= 1:
+ self.num_warmup_micro_batches = min(num_warmup_micro_batches, total_num_micro_batches) / num_model_chunks
+
+ def pre_hook_func(self, state, prefix, name, *args, **kargs):
+ if self.profiling_step < self.skip_profiling_step:
+ return
+ state['memory'] = 0
+ state['input'] = self._cal_input_output_size(args)
+ if self.profiling_step == self.stop_profiling_step:
+ state['memory'] = torch.npu.memory_allocated() - state['input'] * self.unit_mb
+ # The memory and time information is obtained separately. The average time is calculated when the step in
+ # [skip_profiling_step, stop_profiling_step). The memory information is obtained only for the last time.
+ if self.profiling_step < self.stop_profiling_step:
+ start_event = torch.npu.Event(enable_timing=True)
+ self.event_list.append([start_event])
+ start_event.record()
+
+ def post_hook_func(self, state, prefix, name, args, output):
+ if self.profiling_step < self.skip_profiling_step:
+ return
+ if self.profiling_step < self.stop_profiling_step:
+ end_event = torch.npu.Event(enable_timing=True)
+ end_event.record()
+ # add end_event to corresponding position of list
+ for item in reversed(self.event_list):
+ if len(item) == 1:
+ item.append(end_event)
+ break
+ if self.profiling_step == self.stop_profiling_step:
+ output_memory = self._cal_input_output_size(output)
+ state['memory'] = (torch.npu.memory_allocated() - state['memory']) // self.unit_mb
+ state['input'] += output_memory
+
+ def forward_pre_hook(self, prefix, name, ctx):
+ def hook(module, *args, **kargs):
+ if 'module' in self.context:
+ self.context['module'].append(ctx)
+ self.pre_hook_func(ctx, prefix, name, *args, **kargs)
+
+ return hook
+
+ def forward_post_hook(self, prefix, name, ctx):
+ def hook(module, args, output):
+ self.post_hook_func(ctx, prefix, name, args, output)
+ if 'module' in self.context:
+ self.context['module'].pop()
+
+ return hook
+
+ def construct_context_recursive(self, prefix_name, model, ctx, have_allowed_recomputing):
+ # 1.construct context
+ next_have_allowed_recomputing = have_allowed_recomputing
+ for name, module in model.named_children():
+ if 'layers' not in ctx:
+ ctx['layers'] = []
+
+ current_ctx = {'name': name, 'prefix_name': prefix_name}
+ if 'layers' in ctx:
+ ctx['layers'].append(current_ctx)
+
+ next_name = prefix_name + "." + name if prefix_name != "" else name
+
+ # 2.tag allowed_recomputing module
+ if have_allowed_recomputing:
+ for allowed_recomputing_module in self.allowed_recomputing_module:
+ if isinstance(module, allowed_recomputing_module):
+ current_ctx['allowed_recomputing'] = True
+ if isinstance(model, torch.nn.ModuleList):
+ ctx['is_module_list'] = True
+ ctx['is_recomputing_layer'] = True
+ else:
+ current_ctx['is_recomputing_layer'] = True
+ next_have_allowed_recomputing = False
+ self.construct_context_recursive(next_name, module, current_ctx, next_have_allowed_recomputing)
+
+ def register_recursive_hook(self, model, ctx, profiling_prefix, first_chunk=False, layer_index=0, prefetch=False):
+ index = layer_index
+ for module in model.children():
+ if 'layers' not in ctx:
+ continue
+ current_ctx = ctx['layers'][index]
+ if prefetch:
+ if 'is_module_list' in ctx and 'allowed_recomputing' in current_ctx:
+ # transformer layer
+ module.no_checkpoint_forward = module.forward
+ module.forward = get_recompute_hook().hook_checkpoint_forward(module.forward)
+ self.checkpointed_modules.append(module)
+ else:
+ # only has allowed_recomputing Tag can set recomputing hook
+ recompute_layer_condition = index != 0 or index == 0 and not first_chunk
+ if 'is_module_list' in ctx and 'allowed_recomputing' in current_ctx and recompute_layer_condition:
+ # transformer layer
+ module.no_checkpoint_forward = module.forward
+ module.forward = get_recompute_hook().hook_checkpoint_forward(module.forward)
+ self.checkpointed_modules.append(module)
+ prefix_name = current_ctx['prefix_name']
+ name = current_ctx['name']
+
+ # profiling entire module
+ if "module" == prefix_name or 'module0' == prefix_name:
+ pre_hook = module.register_forward_pre_hook(self.forward_pre_hook(prefix_name, name, current_ctx))
+ post_hook = module.register_forward_hook(self.forward_post_hook(prefix_name, name, current_ctx))
+ self.modules_hooks.append(pre_hook)
+ self.modules_hooks.append(post_hook)
+
+ # profiling transformer Layers
+ if isinstance(module, torch.nn.ModuleList) and 'is_recomputing_layer' in current_ctx and first_chunk:
+ pre_hook = model.register_forward_pre_hook(self.forward_pre_hook(ctx['prefix_name'], ctx['name'], ctx))
+ post_hook = model.register_forward_hook(self.forward_post_hook(ctx['prefix_name'], ctx['name'], ctx))
+ self.modules_hooks.append(pre_hook)
+ self.modules_hooks.append(post_hook)
+ elif 'is_recomputing_layer' in current_ctx and first_chunk:
+ profiling_prefix = prefix_name + "." + name
+ pre_hook = module.register_forward_pre_hook(self.forward_pre_hook(prefix_name, name, current_ctx))
+ post_hook = module.register_forward_hook(self.forward_post_hook(prefix_name, name, current_ctx))
+ self.modules_hooks.append(pre_hook)
+ self.modules_hooks.append(post_hook)
+
+ # only has allowed_recomputing Tag and its submodule can set profiling hook
+ if 'allowed_recomputing' in current_ctx and index == 0 and first_chunk:
+ profiling_prefix = prefix_name + "." + name
+ pre_hook = module.register_forward_pre_hook(self.forward_pre_hook(prefix_name, name, current_ctx))
+ post_hook = module.register_forward_hook(self.forward_post_hook(prefix_name, name, current_ctx))
+ self.modules_hooks.append(pre_hook)
+ self.modules_hooks.append(post_hook)
+ elif profiling_prefix and prefix_name.startswith(profiling_prefix) and first_chunk:
+ pre_hook = module.register_forward_pre_hook(self.forward_pre_hook(prefix_name, name, current_ctx))
+ post_hook = module.register_forward_hook(self.forward_post_hook(prefix_name, name, current_ctx))
+ self.modules_hooks.append(pre_hook)
+ self.modules_hooks.append(post_hook)
+ self.register_recursive_hook(module, current_ctx, profiling_prefix, first_chunk, prefetch=prefetch)
+ index += 1
+
+ def reset_modules(self):
+ for m in self.checkpointed_modules:
+ m.forward = m.no_checkpoint_forward
+ self.checkpointed_modules.clear()
+ get_recompute_hook().reset_recompute_modules()
+ get_swap_hook().reset_swap_manager_modules()
+ SwapManager().reset_swap_manager_tensors()
+ if (get_adaptive_recomputing_policy().check_non_oom_times == 0
+ and not get_adaptive_recomputing_policy().is_find_target_device_memory):
+ torch_npu.npu.empty_cache()
+
+ def reset_all_hook_args(self):
+ all_args = get_args()
+ step = get_adaptive_recomputing_policy().remove_swap_manager_hook_step
+ if not all_args.adaptive_recompute_device_swap:
+ for hook_handle in self.modules_hooks:
+ hook_handle.remove()
+ self.modules_hooks.clear()
+ SwapManager().reset_swap_manager_tensors()
+ get_swap_hook().reset_swap_manager_modules()
+ return
+ if self.profiling_step >= self.solve_graph_at_step:
+ for hook_handle in self.modules_hooks:
+ hook_handle.remove()
+ self.modules_hooks.clear()
+ if not get_adaptive_recomputing_policy().is_find_target_device_memory or self.profiling_step > step + 1:
+ return
+ if self.profiling_step == step + 1:
+ title = (f"===== finish to check policy, search policy memory size is: "
+ f"{get_adaptive_recomputing_policy().cur_device_memory} =====")
+ print_rank_0(f"{title}\n{get_graph_solver().final_policy_info}\n{'=' * len(title)}")
+ if self.profiling_step == step:
+ get_swap_hook().reset_swap_manager_modules()
+ if get_adaptive_recomputing_policy().is_find_target_device_memory:
+ SwapManager().reset_swap_manager_tensors()
+
+ def prefetch_hook(self, models):
+ self.reset_modules()
+ all_args = get_args()
+ pp = all_args.pipeline_model_parallel_size
+ vpp = all_args.virtual_pipeline_model_parallel_size if all_args.virtual_pipeline_model_parallel_size else 1
+ print_rank_0("ADAPTIVE-PREFETCH: Start applying policy to the model")
+ config = {
+ "pre_layer_full_name": "",
+ "pre_layer_ctx": {},
+ "cur_layer_name": "module",
+ }
+ prefetch_recompute_group, interval, num_prefetch, swap_noop_layers = get_adaptive_recomputing_policy().solve_prefetch_policy()
+ print(f"[DEBUG] swap_list: {prefetch_recompute_group[0]},"
+ f" prefetch_list: {prefetch_recompute_group[1]},"
+ f" recompute_list: {prefetch_recompute_group[2]}")
+ for i in prefetch_recompute_group[0]:
+ if not any(filter(None, i)):
+ vpp -= 1
+ prefetch_args = [prefetch_recompute_group[0], vpp, interval, num_prefetch]
+ apply_prefetch_strategy(config, models, self.context, prefetch_recompute_group, prefetch_args)
+
+
+ def step_hook(self, models):
+ torch.npu.synchronize()
+ while self.event_list:
+ record_time(self.context, self.event_list)
+ self.reset_all_hook_args()
+ if self.profiling_step < self.solve_graph_at_step:
+ return
+
+ if get_adaptive_recomputing_policy().context_copy is None:
+ get_adaptive_recomputing_policy().context_copy = deepcopy(self.context)
+ try:
+ get_adaptive_recomputing_policy().get_default_device_memory(self.context["max_device_memory"])
+ except KeyError:
+ print_rank_0("[ERROR] Some of these keys don't exist.")
+ get_graph_solver().build_solver_info(self.context, self.num_warmup_micro_batches, len(models))
+
+ get_adaptive_recomputing_policy().check_cur_recompute_policy()
+ print_rank_0("==================== ADAPTIVE-RECOMPUTE Report ====================")
+ context = get_adaptive_recomputing_policy().solve_recompute_policy(self.profiling_step)
+ print_rank_0("==================== ADAPTIVE-RECOMPUTE Report End ====================")
+ if context is not None:
+ self.context = context
+ self.reset_modules()
+ print_rank_0("ADAPTIVE-RECOMPUTE: Start applying policy to the model")
+ config = {
+ "pre_layer_ctx": {},
+ "cur_layer_name": "module",
+ }
+ apply_adaptive_recompute(config, models, self.context)
+ print_rank_0("ADAPTIVE-RECOMPUTE: Finish applying policy to the model")
+ get_swap_hook().reset_tensor_layer_info()
+
+ def hook_step_func(self, step_func, models):
+ def custom_step_func(*args, **kargs):
+ result = step_func(*args, **kargs)
+ if (self.profiling_step > self.solve_graph_at_step and \
+ get_adaptive_recomputing_policy().is_stable_policy(self.profiling_step)):
+ return result
+ memory_info = self.get_memory_status()
+ try:
+ hccl_memory = (memory_info["all_memory"] - memory_info["free"] - memory_info[
+ "reserved_memory"]) // self.unit_mb
+ get_adaptive_recomputing_policy().hccl_memory = max(hccl_memory, get_adaptive_recomputing_policy().hccl_memory)
+ self.context['used_mem'] = memory_info["used_memory"] // self.unit_mb
+ self.context['max_device_memory'] = memory_info["all_memory"] // self.unit_mb
+ except KeyError:
+ print_rank_0("[ERROR] Some of these keys don't exist.")
+ self.profiling_step += 1
+ self.step_hook(models)
+ return result
+
+ return custom_step_func
+
+ def set_profiling_step(self, step):
+ self.stop_profiling_step = step
+ self.solve_graph_at_step = step + 1
+
+ def add_allowed_recomputing_module(self, module):
+ if module not in self.allowed_recomputing_module:
+ self.allowed_recomputing_module.append(module)
+
+ def _cal_input_output_size(self, args):
+ size = 0
+ if isinstance(args, torch.Tensor):
+ size += get_tensor_mem_size(args)
+ return size // self.unit_mb
+ for arg in args:
+ if isinstance(arg, torch.Tensor):
+ size += get_tensor_mem_size(arg)
+ elif isinstance(arg, Iterable):
+ for t in arg:
+ if isinstance(t, torch.Tensor):
+ size += get_tensor_mem_size(t)
+ elif t is None:
+ pass
+ else:
+ print_rank_0(f"[WARNING]: unknown input/output type {str(type(t))}")
+ elif arg is None:
+ pass
+ else:
+ print_rank_0(f"[WARNING]: unknown input/output type {str(type(t))}")
+ return size // self.unit_mb
+
+
+def get_adaptive_recomputing():
+ if AdaptiveRecompute.adaptive_recomputing is None:
+ AdaptiveRecompute.adaptive_recomputing = AdaptiveRecompute()
+ return AdaptiveRecompute.adaptive_recomputing
+
+
+def record_time(context, remaining_event_list):
+ if "memory" in context:
+ cur_level_event_list = remaining_event_list.pop(0)
+ start_event = cur_level_event_list[0]
+ end_event = cur_level_event_list[1]
+ total_time = start_event.elapsed_time(end_event)
+ if 'pre_total_time' in context:
+ context['forward_cnt'] += 1
+ context['time'] = total_time
+ context['pre_total_time'] += total_time
+ try:
+ context['time'] = context['pre_total_time'] / context['forward_cnt']
+ except ZeroDivisionError:
+ context['time'] = 0
+ else:
+ context['forward_cnt'] = 1
+ context['time'] = total_time
+ context['pre_total_time'] = total_time
+ if "layers" not in context:
+ return
+ for sub_layer_context in context["layers"]:
+ record_time(sub_layer_context, remaining_event_list)
+
+
+def is_activate_adaptive_recompute():
+ all_args = get_args()
+ profiling_step = 0
+ if all_args.adaptive_recompute_device_size < 0 and not all_args.adaptive_recompute_device_swap and not all_args.swap_attention:
+ print_rank_0("[ERROR] failed to activate adaptive selective recompute train, please add param: "
+ "\"adaptive-recompute-device-swap\", or set param: \"adaptive-recompute-device-size\".")
+ return False, profiling_step
+ max_profiling_step = all_args.train_iters // 10
+ profiling_step = all_args.adaptive_recompute_profiling_step
+ if profiling_step > all_args.train_iters and not all_args.swap_attention:
+ raise AssertionError('\"adaptive-recompute-profiling-step\" cannot be greater than train_iters')
+ if profiling_step < 5 or profiling_step > max_profiling_step:
+ print_rank_0(f"[WARNING] consider set \"adaptive-recompute-profiling-step\" value >=5"
+ f"and <={max_profiling_step}, or remove it.")
+ if profiling_step <= 0:
+ print_rank_0("[WARNING] \"adaptive-recompute-profiling-step\" value can not <=0, will use default value 10.")
+ profiling_step = 10
+ print_rank_0(
+ "success to activate adaptive recompute train: adaptive-recompute-device-swap={}, adaptive-recompute-device-size={}, "
+ "adaptive-recompute-profiling-step={}".format(all_args.adaptive_recompute_device_swap,
+ all_args.adaptive_recompute_device_size, profiling_step))
+ return True, profiling_step
+
+
+def setup_model_and_optimizer_wrapper(setup_model_and_optimizer):
+ @wraps(setup_model_and_optimizer)
+ def wrapper(*args, **kargs):
+ models, optimizer, opt_param_scheduler = setup_model_and_optimizer(*args, **kargs)
+ activated, profile_step = is_activate_adaptive_recompute()
+ if not activated:
+ return models, optimizer, opt_param_scheduler
+ recomputing = get_adaptive_recomputing()
+ recomputing.set_profiling_step(profile_step)
+ recomputing.get_num_warmup_micro_batches(len(models))
+ args = get_args()
+ if not args.swap_attention:
+ optimizer.step = recomputing.hook_step_func(optimizer.step, models)
+ if isinstance(models, list):
+ for index, model in enumerate(models):
+ recomputing.construct_context_recursive("module" + str(index), model, recomputing.context, True)
+ if not args.swap_attention:
+ recomputing.register_recursive_hook(model, recomputing.context, recomputing.profiling_prefix,
+ index == 0, index, prefetch=args.swap_attention)
+ else:
+ recomputing.construct_context_recursive("module", models, recomputing.context, True)
+ if not args.swap_attention:
+ recomputing.register_recursive_hook(models, recomputing.context, recomputing.profiling_prefix, \
+ True, prefetch=args.swap_attention)
+ if args.swap_attention:
+ recomputing.prefetch_hook(models)
+ print_rank_0("ADAPTIVE-RECOMPUTE: successfully hooking module")
+ return models, optimizer, opt_param_scheduler
+
+ return wrapper
+
+
+def allowed_recomputing_module_wrapper(allowed_recomputing_module):
+ recomputing = get_adaptive_recomputing()
+ recomputing.add_allowed_recomputing_module(allowed_recomputing_module)
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/adaptive_recompute_apply.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/adaptive_recompute_apply.py
new file mode 100644
index 0000000000000000000000000000000000000000..bacb5c17d4f65087ea2e1cd4d351ecbb934dfc8d
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/adaptive_recompute_apply.py
@@ -0,0 +1,200 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+import torch
+from megatron.core import tensor_parallel
+from megatron.training import print_rank_0
+from megatron.training import get_args
+from mindspeed.core.memory.adaptive_recomputing.swap_manager import SwapManager
+from mindspeed.core.memory.adaptive_recomputing.prefetch import prefetch_tensor, prefetch_register_post_backward_hook, prefetch_register_pre_forward_hook, get_swap_prefetch, get_layer_id
+
+
+class RecomputeHook:
+ recompute_hook = None
+
+ def __init__(self):
+ self.recompute_modules = []
+
+ @staticmethod
+ def hook_checkpoint_forward(forward_func):
+ def custom_forward(*args, **kargs):
+ def inside_forward(*args):
+ return forward_func(*args, **kargs)
+
+ return tensor_parallel.checkpoint(inside_forward, None, *args)
+
+ return custom_forward
+
+ def reset_recompute_modules(self):
+ for m in self.recompute_modules:
+ m.forward = m.no_checkpoint_adaptive_recompute_forward
+ self.recompute_modules.clear()
+
+
+def get_recompute_hook():
+ if RecomputeHook.recompute_hook is None:
+ RecomputeHook.recompute_hook = RecomputeHook()
+ return RecomputeHook.recompute_hook
+
+
+class SwapManagerHook:
+ swap_hook = None
+
+ def __init__(self):
+ self.tensor_layer_name_prefix = ""
+ self.pre_tensor_layer_name_prefix = ""
+ self.swap_manager_modules = []
+
+ @staticmethod
+ def unpack_hook(data):
+ return SwapManager().unwrap_tensor(data)
+
+ def pack_hook(self, origin_tensor):
+ pre_tensor_is_allowed_swap = False
+ # enter diff layer, make other layer tensor status to can be swapped
+ if self.tensor_layer_name_prefix != self.pre_tensor_layer_name_prefix:
+ pre_tensor_is_allowed_swap = True
+ self.pre_tensor_layer_name_prefix = self.tensor_layer_name_prefix
+ return SwapManager().wrap_tensor(origin_tensor, pre_tensor_is_allowed_swap)
+
+ def hook_swap_manager_forward(self, forward_func, layer_name_prefix):
+ def custom_forward(*args, **kargs):
+ self.tensor_layer_name_prefix = layer_name_prefix
+ with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):
+ return forward_func(*args, **kargs)
+
+ return custom_forward
+
+ def reset_tensor_layer_info(self):
+ self.tensor_layer_name_prefix = ""
+ self.pre_tensor_layer_name_prefix = ""
+
+ def reset_swap_manager_modules(self):
+ for m in self.swap_manager_modules:
+ m.forward = m.no_checkpoint_swap_forward
+ self.swap_manager_modules.clear()
+
+
+def get_swap_hook():
+ if SwapManagerHook.swap_hook is None:
+ SwapManagerHook.swap_hook = SwapManagerHook()
+ return SwapManagerHook.swap_hook
+
+
+def register_recursive_apply(config, models, ctx):
+ pre_layer_ctx = config["pre_layer_ctx"]
+ cur_layer_name = config["cur_layer_name"]
+ if cur_layer_name == "module" and isinstance(models, list):
+ idx = 0
+ for model in models:
+ register_recursive_apply(config, model, get_list_layers_context(ctx, idx))
+ idx += 1
+ return
+
+ if 'recompute' in ctx and ctx['recompute']:
+ models.no_checkpoint_adaptive_recompute_forward = models.forward
+ models.forward = get_recompute_hook().hook_checkpoint_forward(models.forward)
+ get_recompute_hook().recompute_modules.append(models)
+ return
+
+ if 'allowed_recomputing' in pre_layer_ctx:
+ models.no_checkpoint_swap_forward = models.forward
+ models.forward = get_swap_hook().hook_swap_manager_forward(models.forward, ctx["prefix_name"])
+ get_swap_hook().swap_manager_modules.append(models)
+ return
+
+ idx = 0
+ for name, module in models.named_children():
+ config = {
+ "pre_layer_ctx": ctx,
+ "cur_layer_name": name,
+ }
+ register_recursive_apply(config, module, ctx['layers'][idx])
+ idx += 1
+
+
+def is_hook_layer(ctx, hook_list):
+ if "name" in ctx and ctx["name"] in hook_list and "expert" not in ctx['prefix_name']:
+ return True
+ return False
+
+
+def is_recompute_layer(ctx, prefetch_list):
+ if "name" in ctx and "mlp" == ctx["name"] and get_layer_id(ctx["prefix_name"]) in prefetch_list:
+ return True
+ return False
+
+
+def register_recursive_apply_prefetch(config, models, ctx, prefetch_recompute_group, prefetch_args):
+ args = get_args()
+ prefetch_list, hook_list, recompute_list = prefetch_recompute_group
+ if not isinstance(prefetch_list[0], list):
+ prefetch_layer = prefetch_list
+ hook_layer = hook_list
+ recompute_layer = recompute_list
+
+ pre_layer_full_name = config["pre_layer_full_name"]
+ pre_layer_ctx = config["pre_layer_ctx"]
+ cur_layer_name = config["cur_layer_name"]
+ if cur_layer_name == "module" and isinstance(models, list):
+ idx = 0
+ for model in models:
+ prefetch_layer = prefetch_list[idx] if isinstance(prefetch_list[0], list) else prefetch_list
+ hook_layer = hook_list[idx] if isinstance(hook_list[0], list) else hook_list
+ recompute_layer = recompute_list[idx] if isinstance(recompute_list[0], list) else recompute_list
+ print_rank_0(f'prefetch_layer: {prefetch_layer}---{hook_layer}')
+ if any(filter(None, prefetch_layer)):
+ prefetch_recompute_group = [prefetch_layer, hook_layer, recompute_layer]
+ register_recursive_apply_prefetch(config, model, get_list_layers_context(ctx, idx),
+ prefetch_recompute_group, prefetch_args)
+ idx += 1
+ return
+
+ if is_hook_layer(ctx, hook_list):
+ print_rank_0(f"prefetch forward and backward hook success: {pre_layer_full_name + '.' + cur_layer_name}")
+ prefetch_register_post_backward_hook(models, pre_layer_full_name + '.' + cur_layer_name, prefetch_args)
+ prefetch_register_pre_forward_hook(models, pre_layer_full_name + '.' + cur_layer_name, prefetch_args)
+ if hook_list == prefetch_list and prefetch_list != ['']:
+ if "name" in ctx and ctx["name"] in args.swap_modules and \
+ get_layer_id(ctx["prefix_name"]) in prefetch_list:
+ print_rank_0(f"prefetch swap hook success: {pre_layer_full_name + '.' + cur_layer_name}")
+ models.no_checkpoint_adaptive_recompute_forward = models.forward
+ models.forward = get_swap_prefetch(prefetch_args).hook_swap_manager_forward(models.forward,
+ pre_layer_full_name +
+ '.' + cur_layer_name)
+ get_recompute_hook().recompute_modules.append(models)
+ return
+ elif is_recompute_layer(ctx, recompute_list):
+ print_rank_0(f"prefetch recompute hook success: {pre_layer_full_name + '.' + cur_layer_name}")
+ models.no_checkpoint_adaptive_recompute_forward = models.forward
+ models.forward = get_recompute_hook().hook_checkpoint_forward(models.forward)
+ get_recompute_hook().recompute_modules.append(models)
+ return
+ else:
+ if is_hook_layer(ctx, prefetch_list):
+ print_rank_0(f"prefetch tensor hook success: {pre_layer_full_name + '.' + cur_layer_name}")
+ models.no_checkpoint_adaptive_recompute_forward = models.forward
+ models.forward = get_swap_prefetch(prefetch_args).hook_swap_manager_forward(models.forward,
+ pre_layer_full_name +
+ '.' + cur_layer_name)
+ get_recompute_hook().recompute_modules.append(models)
+ return
+ pre_layer_full_name += "." + cur_layer_name if pre_layer_full_name != "" else cur_layer_name
+ idx = 0
+ for name, module in models.named_children():
+ config = {
+ "pre_layer_full_name": pre_layer_full_name,
+ "pre_layer_ctx": ctx,
+ "cur_layer_name": name,
+ }
+ prefetch_recompute_group = [prefetch_layer, hook_layer, recompute_layer]
+ register_recursive_apply_prefetch(config, module, ctx['layers'][idx], prefetch_recompute_group, prefetch_args)
+ idx += 1
+
+
+def get_list_layers_context(ctx, idx):
+ current_ctx = {}
+ for k, v in ctx.items():
+ if k == "layers":
+ current_ctx[k] = [v[idx]]
+ continue
+ current_ctx[k] = v
+ return current_ctx
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/adaptive_recompute_solver.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/adaptive_recompute_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac6578dda6ffed494d7ec4d80faa1e8882ecced1
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/adaptive_recompute_solver.py
@@ -0,0 +1,574 @@
+import sys
+from copy import deepcopy
+
+import networkx as nx
+import torch
+import numpy as np
+
+from megatron.training import print_rank_0
+from megatron.core import parallel_state
+
+
+class GraphSolver:
+ graph_solver = None
+
+ def __init__(self):
+ self.total_recompute_cost = 0
+ self.total_forward_cost = 0
+ self.num_layers_module = []
+ self.layers_num = 0
+ self.transformer_module_memory = 0
+ self.recompute_policy = {}
+ self.layers_combination = []
+ self.layer_full_recompute_combination = None
+ self.layer_without_recompute_combination = None
+ self.layer_recompute_one_combination = None
+ self.module_layers = {}
+ self.node_split_flag = ", "
+ self.without_recompute_info = ""
+ self.all_recompute_info = ""
+ self.selective_recompute_info = ""
+ self.final_policy_info = ""
+ self.static_memory = 0
+ self.pp = 1
+ self.module_chunk = 1
+ self.chp_input = 0
+ self.chp_time = 0
+ self.full_activation = 0
+ self.first_layer_module = None
+ self.mp = 1
+ self.dp = 1
+
+ @staticmethod
+ def get_dg(module_layers):
+ dg = nx.DiGraph()
+ dg.add_nodes_from([
+ (i, {"name": module_layers[i]['name'],
+ "mem": module_layers[i]['memory'],
+ "input": module_layers[i]['input'],
+ "compute": module_layers[i]['time'],
+ "recompute": False,
+ "status": "no_status"})
+ for i in range(len(module_layers))
+ ])
+ dg.add_edges_from([
+ (i, i + 1) for i in range(len(module_layers) - 1)
+ ])
+ return dg
+
+ def broadcast_in_mp_dp(self, tensor, src):
+ if self.mp > 1 and parallel_state.get_tensor_model_parallel_src_rank() == src:
+ torch.distributed.broadcast(tensor,
+ src=parallel_state.get_tensor_model_parallel_src_rank(),
+ group=parallel_state.get_tensor_model_parallel_group())
+ if self.dp > 1:
+ torch.distributed.broadcast(tensor, src=parallel_state.get_data_parallel_src_rank(),
+ group=parallel_state.get_data_parallel_group())
+
+ def get_no_recompute_layer(self):
+ self.first_layer_module = self.num_layers_module[0]['layers'][0]
+ layer_module = self.first_layer_module['layers']
+ module_layers = []
+ if len(layer_module) == 0:
+ return module_layers
+ parent_layers = []
+ for layer in layer_module:
+ if "memory" not in layer:
+ continue
+ module_layers.append(layer)
+ parent_layers.append(layer)
+ if "layers" not in layer:
+ continue
+ parent_name = layer['name']
+ sub_layer_name = []
+ for sub_layer in layer['layers']:
+ if "memory" not in sub_layer:
+ continue
+ # rename sub_layer name, like 'self_attention.core_attention'
+ sub_layer['name'] = "{}.{}".format(parent_name, sub_layer['name'])
+ module_layers.append(sub_layer)
+ sub_layer_name.append(sub_layer)
+ self.module_layers.update({parent_name: sub_layer_name})
+ self.module_layers.update({"parent_layers": parent_layers})
+ self.module_layers.update({"module_layers": module_layers})
+ return
+
+ # remove full select node, like 'input_layernorm', 'self_attention', 'post_attention_layernorm' and 'mlp' in list
+ def remove_full_selective_node(self, recompute_nodes):
+ if len(recompute_nodes) == 0:
+ return recompute_nodes
+
+ layers_recompute_count = 0
+ try:
+ for layer in self.module_layers["parent_layers"]:
+ name = layer['name']
+ if name in recompute_nodes:
+ layers_recompute_count += 1
+ if layers_recompute_count == len(self.module_layers["parent_layers"]):
+ recompute_nodes.clear()
+ break
+ if name not in self.module_layers.keys():
+ continue
+ sub_layers_recompute_count = 0
+ for sub_layer in self.module_layers[name]:
+ if sub_layer['name'] in recompute_nodes:
+ sub_layers_recompute_count += 1
+ if sub_layers_recompute_count == len(self.module_layers[name]):
+ recompute_nodes.clear()
+ break
+ except KeyError:
+ print_rank_0("[ERROR] Some of these keys don't exist.")
+ return recompute_nodes
+
+ def get_recompute_op(self, graph):
+ recompute_nodes = []
+ p_node = []
+ for node in graph.nodes:
+ if not graph.nodes[node]['recompute']:
+ continue
+ name = graph.nodes[node]['name']
+ recompute_nodes.append(name)
+ spd = name.split(".")
+ if len(spd) == 2 and spd[0] not in p_node:
+ p_node.append(spd[0])
+ # remove parent and sub in list together, like 'self_attention' and 'self_attention.core_attention' in list
+ for n in p_node:
+ if n in recompute_nodes:
+ recompute_nodes.clear()
+ break
+ return self.remove_full_selective_node(recompute_nodes)
+
+ def broadcast_recompute_policy(self, recompute_policy_list):
+ try:
+ self.mp = parallel_state.get_tensor_model_parallel_world_size()
+ self.dp = parallel_state.get_data_parallel_world_size()
+ except:
+ print_rank_0("WARNING: mp, dp is not defined")
+ global_rank = torch.distributed.get_rank()
+ src = (global_rank // (self.mp * self.dp)) * self.dp * self.mp
+
+ policy_shape = np.array(recompute_policy_list).shape
+ policy_len_tensor = torch.tensor(policy_shape, device=torch.npu.current_device())
+ self.broadcast_in_mp_dp(policy_len_tensor, src)
+ policy_len = tuple(policy_len_tensor.cpu().numpy().tolist())
+ if global_rank == src:
+ recompute_policy_tensor = torch.tensor(recompute_policy_list, dtype=torch.int8,
+ device=torch.npu.current_device())
+ else:
+ recompute_policy_tensor = torch.empty(policy_len, dtype=torch.int8,
+ device=torch.npu.current_device())
+
+ self.broadcast_in_mp_dp(recompute_policy_tensor, src)
+ result = recompute_policy_tensor.cpu().numpy().tolist()
+ del recompute_policy_tensor
+ return result
+
+ def set_recompute_info_to_module(self, module, recompute_nodes_info):
+ for sub_module in module:
+ name = sub_module["name"]
+ if name not in recompute_nodes_info.keys():
+ continue
+ info = recompute_nodes_info[name]
+ if isinstance(info, bool):
+ sub_module["recompute"] = info
+ continue
+ if "child_module" in info.keys():
+ self.set_recompute_info_to_module(sub_module["layers"], info["child_module"])
+ continue
+ if name in info.keys():
+ sub_module["recompute"] = info[name]
+
+ def covert_recompute_node_idx_to_name(self, recompute_nodes):
+ result = {}
+ try:
+ module_layers = self.module_layers["module_layers"]
+ except KeyError as e:
+ print_rank_0("[ERROR] The key \"module_layers\" doesn't exist.")
+ raise e
+ for i, node in enumerate(recompute_nodes):
+ if node != self.layer_recompute_one_combination.broadcast_value:
+ continue
+ name = module_layers[i]["name"]
+ parent_name = name
+ sub_name = ""
+ if "." in name:
+ parent_name, sub_name = name.split(".")
+ if parent_name not in result.keys():
+ result[parent_name] = {}
+ if sub_name == "":
+ result[parent_name].update({name: True})
+ continue
+ if "child_module" not in result[parent_name].keys():
+ result[parent_name]["child_module"] = {}
+ result[parent_name]["child_module"].update({name: True, sub_name: True})
+ return result
+
+ def set_to_module(self, module, recompute_nodes, idx):
+ if len(recompute_nodes) == 0:
+ module["recompute"] = True
+ return
+ recompute_nodes_info = self.covert_recompute_node_idx_to_name(recompute_nodes)
+ if len(recompute_nodes_info) == 0:
+ return
+ self.set_recompute_info_to_module(module["layers"], recompute_nodes_info)
+
+ def apply_policy_to_model(self, recompute_policy_list):
+ full_layers = []
+ for layer in self.num_layers_module:
+ if 'is_module_list' in layer:
+ full_layers.extend(layer["layers"])
+ else:
+ full_layers.append(layer)
+ if len(recompute_policy_list) == 0:
+ return
+ idx = 0
+ if (recompute_policy_list[-1][2] == self.layer_full_recompute_combination.broadcast_value
+ or recompute_policy_list[0][2] == self.layer_without_recompute_combination.broadcast_value):
+ recompute_policy_list = list(reversed(recompute_policy_list))
+ for policy in recompute_policy_list:
+ n = policy[0]
+ combination_idx = policy[1]
+ recompute_nodes = []
+ if policy[2] == self.layer_without_recompute_combination.broadcast_value:
+ status = self.layer_without_recompute_combination.broadcast_value
+ try:
+ recompute_nodes = [status for _ in range(len(self.module_layers["module_layers"]))]
+ except KeyError:
+ print_rank_0("[ERROR] The key \"module_layers\" doesn't exist.")
+ if policy[2] == self.layer_recompute_one_combination.broadcast_value:
+ recompute_nodes = policy[3:]
+ for i in range(idx, idx + n):
+ self.set_to_module(full_layers[i], recompute_nodes, combination_idx)
+ idx += n
+
+ # minimize the number of memory, results in all recompute
+ def calculate_cost_mem(self, g: nx.DiGraph, idx):
+ subtotal_cost = 0
+ subtotal_compute_cost = 0
+ memory_cost = (g.nodes[idx]['mem'] if not g.nodes[idx]['recompute'] else g.nodes[idx]['input'])
+ compute_cost = (g.nodes[idx]['compute'] if g.nodes[idx]['recompute'] else 0)
+
+ successors = g.successors(idx)
+ for successor in successors:
+ a, b = self.calculate_cost_mem(g, successor)
+ subtotal_cost += a
+ subtotal_compute_cost += b
+
+ return subtotal_cost + memory_cost, subtotal_compute_cost + compute_cost
+
+ def cal_non_transformer_memory(self, model_context, num_model_chunks):
+ # total memory used
+ model_memory = 0
+ for layer in model_context['layers']:
+ model_memory += layer['memory']
+ break
+ non_size = (model_memory - self.transformer_module_memory) * num_model_chunks
+ return non_size
+
+ def reset_cost(self, g: nx.DiGraph, idx, reset_node_name):
+ node_name = g.nodes[idx]['name']
+ if node_name in reset_node_name:
+ g.nodes[idx]['mem'] = 0
+ g.nodes[idx]['input'] = 0
+ g.nodes[idx]['compute'] = 0
+ successors = g.successors(idx)
+ for successor in successors:
+ self.reset_cost(g, successor, reset_node_name)
+ return
+
+ # remove dg redundant nodes, like: self_attention and self_attention.core_attention, remove one
+ def reset_redundant_nodes(self, dg, recompute_nodes):
+ dg_copy = deepcopy(dg)
+ reset_node_name = []
+ try:
+ for parent_layer in self.module_layers["parent_layers"]:
+ parent_name = parent_layer['name']
+ if parent_name not in self.module_layers.keys():
+ continue
+ sub_reset_node_name = []
+ for sub_layer in self.module_layers[parent_name]:
+ sub_reset_node_name.append(sub_layer['name'])
+ if sub_layer['name'] in recompute_nodes:
+ reset_node_name.append(parent_name)
+ sub_reset_node_name.clear()
+ break
+ if len(sub_reset_node_name) != 0:
+ reset_node_name.extend(sub_reset_node_name)
+ except KeyError:
+ print_rank_0("[ERROR] The key \"parent_layers\" doesn't exist.")
+ self.reset_cost(dg_copy, 0, reset_node_name)
+ return dg_copy
+
+ def layers_combination_init(self, g, idx):
+ if idx == 0:
+ self.layer_full_recompute_combination = LayerCombination({
+ "name": "full_recompute",
+ "num": self.layers_num,
+ "memory": self.chp_input,
+ "cost": self.chp_time,
+ "broadcast_value": 0,
+ "policy_name": "n_full"
+ })
+ self.layers_combination.append(self.layer_full_recompute_combination)
+ self.layer_without_recompute_combination = LayerCombination({
+ "name": "without_recompute",
+ "num": self.layers_num,
+ "memory": self.full_activation,
+ "cost": 0,
+ "broadcast_value": 2,
+ "policy_name": "n_without"
+ })
+ self.layers_combination.append(self.layer_without_recompute_combination)
+ try:
+ if idx >= len(self.module_layers["module_layers"]):
+ recompute_nodes = self.get_recompute_op(g)
+ if len(recompute_nodes) == 0:
+ return
+ dg = self.reset_redundant_nodes(g, recompute_nodes)
+ stash_mem_per_layer, recompute_cost = self.calculate_cost_mem(dg, 0)
+ self.layer_recompute_one_combination = LayerCombination({
+ "name": self.node_split_flag.join(recompute_nodes),
+ "num": self.layers_num,
+ "memory": stash_mem_per_layer,
+ "cost": recompute_cost,
+ "broadcast_value": 1,
+ "policy_name": "n_selective"
+ })
+ self.layers_combination.append(self.layer_recompute_one_combination)
+ return
+ except KeyError:
+ print_rank_0("[ERROR] The key \"module_layers\" doesn't exist.")
+ if g.nodes[idx]['mem'] >= g.nodes[idx]['input']:
+ g.nodes[idx]['recompute'] = True
+ self.layers_combination_init(g, idx + 1)
+ g.nodes[idx]['recompute'] = False
+ self.layers_combination_init(g, idx + 1)
+
+ def get_max_goods_value(self, idx, ans, device_memory):
+ i, j, k = idx[0], idx[1], idx[2]
+ pre_step_ans = ans[i - 1][j - k]
+ if k == 0:
+ return pre_step_ans
+
+ goods_value = ans[i][j]
+ memory = pre_step_ans.memory + k * self.layers_combination[i].memory
+ cost = pre_step_ans.cost + k * self.layers_combination[i].cost
+ if pre_step_ans.cost == float('inf'):
+ cost = k * self.layers_combination[i].cost
+ try:
+ device_memory = max(device_memory - self.static_memory, 0) / self.pp
+ except ZeroDivisionError:
+ device_memory = max(device_memory - self.static_memory, 0)
+ print_rank_0("[ERROR] pipeline model parallel world size is 0. ")
+
+ if device_memory >= memory and cost <= goods_value.cost:
+ goods_value.memory = memory
+ goods_value.cost = cost
+ goods_value.layer_names.clear()
+ if len(pre_step_ans.layer_names) > 0:
+ goods_value.layer_names.extend(pre_step_ans.layer_names)
+ goods_value.layer_names.extend(self.layers_combination[i].name for _ in range(k))
+
+ return goods_value
+
+ def print_recompute_policy(self, memory, cost):
+ fmt_str = "With selective recompute:\n"
+ for k, v in self.recompute_policy.items():
+ if k == self.layer_full_recompute_combination.name:
+ policy_name = self.layer_full_recompute_combination.policy_name
+ elif k == self.layer_without_recompute_combination.name:
+ policy_name = self.layer_without_recompute_combination.policy_name
+ else:
+ policy_name = self.layer_recompute_one_combination.policy_name
+ fmt_str += "recomputeNodes=[{}], ".format(k)
+ fmt_str += "{} {}; ".format(v, policy_name)
+ all_recompute_cost = self.layers_num * self.layer_full_recompute_combination.cost
+ try:
+ performance = (all_recompute_cost - cost) / (all_recompute_cost * 4)
+ except ZeroDivisionError:
+ performance = 0
+ print_rank_0("[ERROR] all recompute cost is 0. ")
+ fmt_str = fmt_str.strip().rstrip(";")
+ fmt_str += "\ntotal mem cost: {:.1f} GiB + {:.1f} GiB, speed up compared with all recompute {:.2%}".format(
+ self.static_memory / 1024, memory * self.pp / 1024, performance)
+ self.selective_recompute_info = fmt_str
+
+ def get_all_layer_policy(self, combination_num, layer_num, ans):
+ layer_nodes = [self.layer_full_recompute_combination.name for _ in range(layer_num)]
+ memory = layer_num * self.layer_full_recompute_combination.memory
+ cost = layer_num * self.layer_full_recompute_combination.cost
+ for i in range(layer_num, 0, -1):
+ size = layer_num - len(ans[combination_num][i].layer_names)
+ if size != layer_num:
+ l_nodes = []
+ l_nodes.extend(ans[combination_num][i].layer_names)
+ # if the policies of all layers are not found, the remaining layers ues all recompute policy.
+ l_nodes.extend(self.layer_full_recompute_combination.name for _ in range(size))
+ l_memory = ans[combination_num][i].memory + size * self.layer_full_recompute_combination.memory
+ l_cost = ans[combination_num][i].cost + size * self.layer_full_recompute_combination.cost
+ if l_cost < cost:
+ cost = l_cost
+ memory = l_memory
+ layer_nodes.clear()
+ layer_nodes.extend(l_nodes)
+
+ for nodes in layer_nodes:
+ if nodes not in self.recompute_policy.keys():
+ self.recompute_policy.update({nodes: 1})
+ continue
+ self.recompute_policy.update({nodes: self.recompute_policy[nodes] + 1})
+
+ self.print_recompute_policy(memory, cost)
+
+ def knapsack_best(self, device_memory):
+ combination_num = len(self.layers_combination) - 1
+ if self.layers_combination[0] is not None:
+ combination_num = len(self.layers_combination)
+ # make combination index id begin for 1.
+ self.layers_combination.insert(0, None)
+ # init ans
+ ans = [[GoodsValue() for _ in range(self.layers_num + 1)] for _ in range(combination_num + 1)]
+ # find max goods value
+ for i in range(1, combination_num + 1):
+ for j in range(self.layers_num + 1):
+ k = 0
+ while k <= self.layers_combination[i].num and k <= j:
+ ans[i][j] = self.get_max_goods_value([i, j, k], ans, device_memory)
+ k += 1
+ self.get_all_layer_policy(combination_num, self.layers_num, ans)
+
+ def get_combination_idx(self, nodes_name):
+ for i in range(len(self.layers_combination)):
+ if self.layers_combination[i] is None:
+ continue
+ if nodes_name == self.layers_combination[i].name:
+ return i
+ return -1
+
+ def analyse_policy_to_list(self):
+ recompute_policy_list = []
+ module_layers = []
+ try:
+ module_layers = self.module_layers["module_layers"]
+ except KeyError:
+ print_rank_0("[ERROR] The key \"module_layers\" doesn't exist.")
+ module_layers_num = len(module_layers)
+ for nodes_name, v in self.recompute_policy.items():
+ idx = self.get_combination_idx(nodes_name)
+ nodes_count = [v, idx]
+ if nodes_name == self.layer_without_recompute_combination.name:
+ broadcast_value = self.layer_without_recompute_combination.broadcast_value
+ nodes_count.extend(broadcast_value for _ in range(module_layers_num + 1))
+ elif nodes_name == self.layer_full_recompute_combination.name:
+ broadcast_value = self.layer_full_recompute_combination.broadcast_value
+ nodes_count.extend(broadcast_value for _ in range(module_layers_num + 1))
+ else:
+ nodes_count.append(self.layer_recompute_one_combination.broadcast_value)
+ recompute_nodes = nodes_name.split(self.node_split_flag)
+ for layer in module_layers:
+ if layer["name"] in recompute_nodes:
+ nodes_count.append(self.layer_recompute_one_combination.broadcast_value)
+ continue
+ nodes_count.append(self.layer_without_recompute_combination.broadcast_value)
+ recompute_policy_list.append(nodes_count)
+ return recompute_policy_list
+
+ def print_list_to_policy(self, recompute_policy_list):
+ try:
+ module_layers = self.module_layers["module_layers"]
+ except KeyError as e:
+ print_rank_0("[ERROR] The key \"module_layers\" doesn't exist.")
+ raise e
+ module_layers_num = len(module_layers)
+ if len(recompute_policy_list) == 0:
+ return
+ fmt_str = ">> final selective strategy <<\n"
+ for policy in recompute_policy_list:
+ n = policy[0]
+ if policy[2] == self.layer_without_recompute_combination.broadcast_value:
+ policy_name = self.layer_without_recompute_combination.policy_name
+ elif policy[2] == self.layer_full_recompute_combination.broadcast_value:
+ policy_name = self.layer_full_recompute_combination.policy_name
+ else:
+ policy_name = self.layer_recompute_one_combination.policy_name
+ policy = policy[3:]
+ nodes = []
+ for i in range(module_layers_num):
+ if policy[i] == self.layer_recompute_one_combination.broadcast_value:
+ nodes.append(module_layers[i]["name"])
+ fmt_str += "recomputeNodes=[{}], ".format(self.node_split_flag.join(nodes))
+ fmt_str += "{} {}\n".format(n, policy_name)
+ self.final_policy_info = fmt_str.rstrip("\n")
+
+ def get_layers_module(self, model, parent_ctx):
+ if 'is_recomputing_layer' in model:
+ if 'is_module_list' in model and 'memory' in parent_ctx:
+ self.transformer_module_memory += parent_ctx['memory']
+ elif 'is_module_list' not in model and 'memory' in model:
+ self.transformer_module_memory += model['memory']
+ self.num_layers_module.append(model)
+ if "layers" in model:
+ self.layers_num += len(model["layers"])
+ return
+ if "layers" not in model:
+ return
+ for sub_model in model["layers"]:
+ self.get_layers_module(sub_model, model)
+
+ def build_solver_info(self, model_context, pp, num_model_chunks):
+ self.pp = max(self.pp, pp)
+ self.get_layers_module(model_context, "")
+ self.total_recompute_cost = sys.maxsize
+ # first layer is not recompute
+ self.get_no_recompute_layer()
+ self.chp_input = self.first_layer_module['input']
+ self.chp_time = self.first_layer_module['time']
+ self.full_activation = self.first_layer_module['memory']
+ self.module_chunk = len(model_context['layers'])
+ self.total_forward_cost = self.chp_time * self.layers_num
+ self.static_memory = model_context['used_mem'] + self.cal_non_transformer_memory(model_context,
+ num_model_chunks)
+
+ parent_dg = self.get_dg(self.module_layers["parent_layers"])
+ stash_mem_per_layer, _ = self.calculate_cost_mem(parent_dg, 0)
+ stash_mem_total = stash_mem_per_layer * self.layers_num
+ self.without_recompute_info = (
+ f"Without recompute: total mem cost: {self.static_memory / 1024:.1f} GiB + "
+ f"{stash_mem_total / 1024:.1f} GiB, total recompute 0, speed up over all recompute 25%")
+ stash_mem_total = self.chp_input * self.layers_num
+ self.all_recompute_info = (f"With all recompute: total mem cost: {self.static_memory / 1024:.1f} GiB + "
+ f"{stash_mem_total / 1024:.1f} GiB, total recompute all")
+
+ dg = self.get_dg(self.module_layers["module_layers"])
+ self.layers_combination_init(dg, 0)
+
+ def get_policy(self, device_memory):
+ self.recompute_policy.clear()
+ self.knapsack_best(device_memory)
+ recompute_policy_list = self.analyse_policy_to_list()
+ if torch.distributed.is_initialized():
+ recompute_policy_list = self.broadcast_recompute_policy(recompute_policy_list)
+ return recompute_policy_list
+
+
+def get_graph_solver():
+ if GraphSolver.graph_solver is None:
+ GraphSolver.graph_solver = GraphSolver()
+ return GraphSolver.graph_solver
+
+
+class LayerCombination:
+ def __init__(self, config):
+ self.name = config["name"]
+ self.num = config["num"]
+ self.memory = config["memory"]
+ self.cost = config["cost"]
+ self.broadcast_value = config["broadcast_value"]
+ self.policy_name = config["policy_name"]
+
+
+class GoodsValue:
+ def __init__(self):
+ self.layer_names = []
+ self.memory = 0
+ self.cost = float('inf')
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/pluggable_allocator_adpator.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/pluggable_allocator_adpator.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ff3565060d03577ff5578eb20f6e6e7c34a4d87
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/pluggable_allocator_adpator.py
@@ -0,0 +1,24 @@
+import ctypes
+
+import torch_npu
+
+from mindspeed.op_builder import AdaptiveRecomputingPluggableAllocatorBuilder
+
+
+def change_allocator():
+ adaptive_recompute_module = AdaptiveRecomputingPluggableAllocatorBuilder().load()
+ adaptive_recompute_module_path = adaptive_recompute_module.__file__
+
+ new_alloc = torch_npu.npu.memory.NPUPluggableAllocator(adaptive_recompute_module_path, 'my_malloc', 'my_free')
+ torch_npu.npu.memory.change_current_allocator(new_alloc)
+
+ myallocator = ctypes.CDLL(adaptive_recompute_module_path)
+ init_fn = ctypes.cast(getattr(myallocator, "my_init"), ctypes.c_void_p).value
+ empty_fn = ctypes.cast(getattr(myallocator, "my_empty_cache"), ctypes.c_void_p).value
+ get_device_stats_fn = ctypes.cast(getattr(myallocator, "my_get_device_stats"), ctypes.c_void_p).value
+ reset_peak_stats_fn = ctypes.cast(getattr(myallocator, "my_reset_peak_stats"), ctypes.c_void_p).value
+
+ new_alloc.allocator().set_init_fn(init_fn)
+ new_alloc.allocator().set_reset_fn(empty_fn)
+ new_alloc.allocator().set_get_device_stats_fn(get_device_stats_fn)
+ new_alloc.allocator().set_reset_peak_status_fn(reset_peak_stats_fn)
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/prefetch.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/prefetch.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bbc1bd925f0bfd35400e80d358266df8db8c414
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/prefetch.py
@@ -0,0 +1,328 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+import re
+import torch
+import torch_npu
+from megatron.training import get_args
+
+
+def get_layer_id(name):
+ if name:
+ matches = re.findall(r'\.(\d+)\.?', str(name))
+ if matches:
+ return matches[0]
+ return -1
+ return -1
+
+
+class SwapTensor:
+ def __init__(self, tensor, layer_name):
+ self.tensor = tensor
+ self.size = tensor.size()
+ self.storage_size = tensor.storage().size()
+ self.tensor_cpu = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True, device='cpu')
+
+ self.d2h_event = None
+ self.h2d_event = torch.npu.Event()
+
+ self.stat = "device"
+ self.layer_name = layer_name
+
+ self.prefetch_data_ptr = tensor.data_ptr()
+ self.storage_data_ptr = tensor.storage().data_ptr()
+ self.layer_id = None
+ self.first_tensor = False
+ self.last_tensor = False
+ self.is_slice_tensor = tensor.storage().size() != tensor.numel()
+ self.stream = None
+ self.layer_index = 0
+
+ # device to host
+ def launch_d2h(self, stream):
+ if self.stat != "device":
+ return
+ forward_event = torch.npu.Event()
+ forward_event.record()
+ with torch.no_grad():
+ with torch_npu.npu.stream(stream):
+ stream.wait_event(forward_event)
+ if self.is_slice_tensor:
+ self.tensor_cpu.copy_(self.tensor, non_blocking=True)
+ else:
+ self.tensor_cpu.storage().copy_(self.tensor.storage(), non_blocking=True)
+ self.stat = "d2h"
+
+ # synchronize d2h and resize 0
+ def wait_d2h_finished(self, stream, need_wait=False):
+ if self.stat != "d2h":
+ return
+ if need_wait:
+ torch.npu.current_stream().wait_stream(stream)
+ torch.npu.default_stream().wait_stream(stream)
+ self.tensor.storage().resize_(0)
+ self.stat = "host"
+
+ # resize storage_size and host to device
+ def launch_h2d(self, stream, flag):
+ if self.stat != "host":
+ return
+ backward_event = torch.npu.Event()
+ backward_event.record()
+ if flag:
+ self.tensor.storage().resize_(self.storage_size)
+ with torch.no_grad():
+ with torch_npu.npu.stream(stream):
+ stream.wait_event(backward_event)
+ if self.is_slice_tensor:
+ self.tensor.copy_(self.tensor_cpu, non_blocking=True)
+ else:
+ self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=True)
+ self.h2d_event.record()
+ self.stat = "h2d"
+
+ # synchronize h2d
+ def wait_h2d_finished(self, stream, need_wait=False):
+ if self.stat != "h2d":
+ return
+ if need_wait:
+ torch.npu.current_stream().wait_stream(stream)
+ torch.npu.default_stream().wait_stream(stream)
+ self.stat = "device"
+
+
+class SwapPrefetch:
+ swap_prefetch = None
+
+ def __init__(self, prefetch_args):
+ swap_list, vpp, interval, num_prefetch = prefetch_args
+ all_args = get_args()
+ self.prefetch_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
+ self.pp = all_args.pipeline_model_parallel_size
+ self.vpp = min(vpp, num_prefetch)
+ self.first_layer_id = 0
+ if isinstance(all_args.noop_layers, set):
+ for layer_id in swap_list[0]:
+ if layer_id != '':
+ self.first_layer_id = int(layer_id)
+ break
+
+ self.swap_tensors = []
+ self.layer_name = ""
+
+ self.data_ptr = {}
+ self.prefetch_list = []
+ self.prefetch_data_ptr_list = []
+ self.cur_micro_num = 0
+ self.remove_num = 0
+ self.forward_flag = False
+ self.interval = interval
+ self.slice_tensor_storage_ptr = {}
+ self.slice_tensor_storage_ptr_list = []
+ self.eval_end_flag = False
+
+ @staticmethod
+ def no_swap_tensor(ori_tensor):
+ if ori_tensor.numel() * ori_tensor.element_size() * 2 < 1024 * 1024:
+ return True
+ if ori_tensor.grad_fn is None:
+ return True
+ if ori_tensor.storage().size() == 0:
+ return True
+ if ori_tensor.storage().size() != ori_tensor.numel():
+ return True
+ if ori_tensor._base is not None and ori_tensor._base.dim() >= 5:
+ return True
+
+ return False
+
+ def pack_hook(self, ori_tensor):
+ args = get_args()
+ if args.eval_interval:
+ if args.curr_iteration % args.eval_interval != 0:
+ self.eval_end_flag = False
+ if args.curr_iteration and args.curr_iteration % args.eval_interval == 0 and not self.eval_end_flag:
+ self.prefetch_data_ptr_list = []
+ self.prefetch_list = []
+ self.slice_tensor_storage_ptr_list = []
+ self.eval_end_flag = True
+
+ if self.no_swap_tensor(ori_tensor):
+ return ori_tensor
+ swap_tensor = SwapTensor(ori_tensor, self.layer_name)
+ if not self.swap_tensors:
+ swap_tensor.first_tensor = True
+ # Records the slice tensor status.
+ if ori_tensor.storage().size() != ori_tensor.numel():
+ swap_tensor.is_slice_tensor = True
+ if ori_tensor.storage().data_ptr() not in self.slice_tensor_storage_ptr:
+ if self.swap_tensors and self.swap_tensors[0].layer_id != 0:
+ self.slice_tensor_storage_ptr[ori_tensor.storage().data_ptr()] = \
+ [f'{len(self.prefetch_list) - 1}_{len(self.swap_tensors)}']
+ else:
+ self.slice_tensor_storage_ptr[ori_tensor.storage().data_ptr()] = \
+ [f'{len(self.prefetch_list)}_{len(self.swap_tensors)}']
+ else:
+ if self.swap_tensors and self.swap_tensors[0].layer_id != 0:
+ self.slice_tensor_storage_ptr[ori_tensor.storage().data_ptr()].append(
+ f'{len(self.prefetch_list) - 1}_{len(self.swap_tensors)}')
+ else:
+ self.slice_tensor_storage_ptr[ori_tensor.storage().data_ptr()].append(
+ f'{len(self.prefetch_list)}_{len(self.swap_tensors)}')
+
+ # Records the same data_ptr tensor status.
+ if ori_tensor.storage().data_ptr() in self.data_ptr:
+ self.swap_tensors[self.data_ptr[ori_tensor.storage().data_ptr()]].stat = 'h2d'
+ swap_tensor.stat = 'd2h'
+ swap_tensor.tensor_cpu = self.swap_tensors[self.data_ptr[ori_tensor.storage().data_ptr()]].tensor_cpu
+ self.data_ptr[ori_tensor.storage().data_ptr()] = len(self.swap_tensors)
+ else:
+ self.data_ptr[ori_tensor.storage().data_ptr()] = len(self.swap_tensors)
+
+ swap_tensor.launch_d2h(self.prefetch_stream)
+ swap_tensor.stream = self.prefetch_stream
+ swap_tensor.layer_id = int(get_layer_id(swap_tensor.layer_name))
+ self.swap_tensors.append(swap_tensor)
+ self.forward_flag = True
+ return swap_tensor
+
+ def unpack_hook(self, swap_tensor):
+ if isinstance(swap_tensor, torch.Tensor):
+ return swap_tensor
+ swap_tensor.wait_h2d_finished(self.prefetch_stream, swap_tensor.last_tensor)
+ self.prefetch_list[self.cur_micro_num][swap_tensor.layer_index].remove(swap_tensor)
+ # Remove prefetch completed list
+ if len(self.prefetch_list[self.cur_micro_num][swap_tensor.layer_index]) == 0:
+ self.prefetch_list[self.cur_micro_num].remove(
+ self.prefetch_list[self.cur_micro_num][swap_tensor.layer_index])
+ self.prefetch_data_ptr_list[self.cur_micro_num].remove(
+ self.prefetch_data_ptr_list[self.cur_micro_num][swap_tensor.layer_index])
+ self.slice_tensor_storage_ptr_list[self.cur_micro_num].remove(
+ self.slice_tensor_storage_ptr_list[self.cur_micro_num][swap_tensor.layer_index])
+ if len(self.prefetch_list[self.cur_micro_num]) == 0:
+ self.prefetch_list.remove(self.prefetch_list[self.cur_micro_num])
+ self.prefetch_data_ptr_list.remove(self.prefetch_data_ptr_list[self.cur_micro_num])
+ self.slice_tensor_storage_ptr_list.remove(self.slice_tensor_storage_ptr_list[self.cur_micro_num])
+ self.remove_num += 1
+ if self.remove_num // self.pp == self.vpp:
+ self.remove_num = 0
+ self.forward_flag = False
+ return swap_tensor.tensor
+
+ def hook_swap_manager_forward(self, forward_func, layer_name):
+ def custom_forward(*args, **kargs):
+ self.layer_name = layer_name
+ with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):
+ return forward_func(*args, **kargs)
+
+ return custom_forward
+
+ def update_slice_tensor_stat(self, swap_tensor):
+ if swap_tensor.is_slice_tensor and swap_tensor.storage_data_ptr in self.slice_tensor_storage_ptr:
+ _, index = self.slice_tensor_storage_ptr[swap_tensor.storage_data_ptr][0].split('_')
+ if swap_tensor != self.swap_tensors[int(index)]:
+ swap_tensor.stat = 'host'
+ return False
+ return True
+
+ def sync_d2h(self, module_name):
+ if not self.swap_tensors:
+ return
+ if self.swap_tensors[0].layer_id <= self.first_layer_id:
+ self.first_layer_id = self.swap_tensors[0].layer_id
+ elif self.prefetch_list and self.swap_tensors[0].layer_id <= self.prefetch_list[-1][-1][-1].layer_id:
+ self.first_layer_id = self.swap_tensors[0].layer_id
+ first_resize_tensor = False
+ for swap_tensor in self.swap_tensors:
+ if self.swap_tensors[0].layer_id > self.first_layer_id and self.prefetch_list:
+ swap_tensor.layer_index = len(self.prefetch_list[-1])
+ if swap_tensor.layer_id == int(get_layer_id(module_name)) \
+ and swap_tensor.stat == "d2h":
+ if not self.update_slice_tensor_stat(swap_tensor):
+ continue
+ if not first_resize_tensor:
+ swap_tensor.first_tensor = True
+ first_resize_tensor = True
+ # During synchronization, let the first tensor wait for d2h
+ swap_tensor.wait_d2h_finished(swap_tensor.stream, swap_tensor.first_tensor)
+ self.swap_tensors[-1].last_tensor = True
+ if self.swap_tensors[-1].stat == 'host':
+ if self.swap_tensors[0].layer_id > self.first_layer_id and self.prefetch_list:
+ self.prefetch_list[-1].append(self.swap_tensors)
+ self.prefetch_data_ptr_list[-1].append(self.data_ptr)
+ self.slice_tensor_storage_ptr_list[-1].append(self.slice_tensor_storage_ptr)
+ else:
+ self.prefetch_list.append([self.swap_tensors])
+ self.prefetch_data_ptr_list.append([self.data_ptr])
+ self.slice_tensor_storage_ptr_list.append([self.slice_tensor_storage_ptr])
+ self.swap_tensors = []
+ self.data_ptr = {}
+ self.slice_tensor_storage_ptr = {}
+ if self.vpp == 1:
+ self.cur_micro_num = 0
+ else:
+ if not self.remove_num and len(self.prefetch_list) > self.pp:
+ self.cur_micro_num = self.pp * (self.vpp - 1)
+ elif self.remove_num and self.remove_num % self.pp == 0:
+ self.cur_micro_num = self.pp * (self.vpp - 1 - self.remove_num // self.pp)
+
+ def h2d_special_tensor(self, swap_tensor):
+ if swap_tensor.is_slice_tensor:
+ if swap_tensor.storage_data_ptr in self.slice_tensor_storage_ptr_list[self.cur_micro_num][swap_tensor.layer_index]:
+ _, index = self.slice_tensor_storage_ptr_list[self.cur_micro_num][swap_tensor.layer_index][swap_tensor.storage_data_ptr][
+ 0].split('_')
+ if swap_tensor == self.prefetch_list[self.cur_micro_num][swap_tensor.layer_index][int(index)]:
+ swap_tensor.launch_h2d(self.prefetch_stream, True)
+ del self.slice_tensor_storage_ptr_list[self.cur_micro_num][swap_tensor.layer_index][swap_tensor.storage_data_ptr]
+ else:
+ swap_tensor.launch_h2d(self.prefetch_stream, False)
+ else:
+ swap_tensor.launch_h2d(self.prefetch_stream, True)
+
+ def h2d(self, module_name):
+ if not self.prefetch_list:
+ return
+ if self.vpp != 1 and not self.forward_flag:
+ self.cur_micro_num = self.pp * (self.vpp - 1 - self.remove_num // self.pp)
+ for swap_tensor_list in self.prefetch_list[self.cur_micro_num]:
+ for swap_tensor in reversed(swap_tensor_list):
+ if swap_tensor.layer_id + self.interval == int(get_layer_id(module_name)) \
+ and swap_tensor.stat == "host" \
+ and swap_tensor.storage_data_ptr in self.prefetch_data_ptr_list[self.cur_micro_num][swap_tensor.layer_index]:
+ del self.prefetch_data_ptr_list[self.cur_micro_num][swap_tensor.layer_index][swap_tensor.storage_data_ptr]
+ # For slice tensors, only the first tensor is resized. Other h2d the tensor size
+ self.h2d_special_tensor(swap_tensor)
+
+
+def get_swap_prefetch(prefetch_args):
+ if SwapPrefetch.swap_prefetch is None:
+ SwapPrefetch.swap_prefetch = SwapPrefetch(prefetch_args)
+
+ return SwapPrefetch.swap_prefetch
+
+
+def pre_forward_hook_func(module_name, prefetch_args):
+ def custom_func(module, *args, **kargs):
+ get_swap_prefetch(prefetch_args).sync_d2h(module_name)
+
+ return custom_func
+
+
+def post_backward_hook_func(module_name, prefetch_args):
+ def custom_func(module, *args, **kargs):
+ get_swap_prefetch(prefetch_args).h2d(module_name)
+
+ return custom_func
+
+
+# manage activation tensor
+def prefetch_tensor(module, name, prefetch_args):
+ get_swap_prefetch(prefetch_args).hook_swap_manager_forward(module.forward, name)
+
+
+# register prefetch before backward, prefetch h2d
+def prefetch_register_post_backward_hook(module, name, prefetch_args):
+ module.register_backward_hook(post_backward_hook_func(name, prefetch_args))
+
+
+# register prefetch after forward, sync d2h
+def prefetch_register_pre_forward_hook(module, name, prefetch_args):
+ module.register_forward_hook(pre_forward_hook_func(name, prefetch_args))
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/swap_manager.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/swap_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..b46067dbd76517d2f12a885e76161f44785b65ce
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/swap_manager.py
@@ -0,0 +1,226 @@
+import os
+import time
+
+from megatron.training import print_rank_0
+from mindspeed.core.memory.adaptive_recomputing.swappable_tensor import SwappableTensor
+
+
+class SwapManagerMeta(type):
+ swap_manager_instance = {}
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls.swap_manager_instance:
+ instance = super().__call__(*args, **kwargs)
+ cls.swap_manager_instance[cls] = instance
+ return cls.swap_manager_instance[cls]
+
+
+class SwapManager(metaclass=SwapManagerMeta):
+ def __init__(self):
+ self.host_tensors = {}
+ self.device_tensors = {}
+ self.total_swap_out_size = 0
+
+ @staticmethod
+ def is_allowed_wrap_tensor(tensor):
+ if isinstance(tensor, SwappableTensor):
+ return False
+ # min wrap tensor size, default is 1024B
+ config = os.getenv('MIN_SWAP_TENSOR_SIZE')
+ min_swap_tensor_size = 1024
+ if config is not None:
+ try:
+ min_swap_tensor_size = max(min_swap_tensor_size, int(config))
+ except ValueError:
+ print_rank_0('WARNING: MIN_SWAP_TENSOR_SIZE value error, fallback to default value 1024')
+ if get_tensor_mem_size(tensor) < min_swap_tensor_size:
+ return False
+ # leaf node tensor
+ if tensor.grad_fn is None:
+ return False
+
+ return True
+
+ def change_manager_tensor_status_to_allowed_swap(self):
+ for k in self.device_tensors.keys():
+ self.device_tensors[k].is_allowed_swap = True
+
+ def wrap_tensor(self, tensor, pre_tensor_is_allowed_swap=False):
+ """
+ Wrap the original tensor.
+ The tensor will be stored in the wrapped tensor. The original tensor may will be swap out to host cpu to release
+ device memory when the swapping function is called
+ :param pre_tensor_is_allowed_swap: pre tensor is allowed swap to CPU
+ :param tensor: torch tensor which is needed to wrap
+ :return: wrapped tensor
+ """
+ if pre_tensor_is_allowed_swap:
+ self.change_manager_tensor_status_to_allowed_swap()
+ if not self.is_allowed_wrap_tensor(tensor):
+ return tensor
+ wrapped_tensor = SwappableTensor(tensor)
+ if tensor.storage().size() != tensor.numel():
+ wrapped_tensor.is_slice_tensor = True
+ key = time.time()
+ wrapped_tensor.set_tensor(key, tensor)
+ self.device_tensors[key] = wrapped_tensor
+ return wrapped_tensor
+
+ def is_exist_tensor_allowed_swap(self):
+ for tensor in self.device_tensors.values():
+ if tensor.is_allowed_swap:
+ return True
+ return False
+
+ def is_exist_tensor_contiguous(self):
+ for tensor in self.device_tensors.values():
+ if tensor.get_tensor().is_contiguous() and tensor.is_allowed_swap:
+ return True
+ return False
+
+ def move_shard_tensor_to_host(self, bro_key, bro_tensor):
+ move_count = 0
+ device_tensors_keys = list(self.device_tensors.keys())
+ for key in device_tensors_keys:
+ tensor = self.device_tensors[key]
+ if tensor.inner_tensor_data_ptr == bro_tensor.inner_tensor_data_ptr:
+ self.device_tensors.pop(key)
+ tensor.set_tensor_location("cpu")
+ tensor.inner_tensor_bro_keys.append(bro_key)
+ bro_tensor.inner_tensor_bro_keys.append(key)
+ self.host_tensors[key] = tensor
+ move_count += 1
+ self.host_tensors[bro_key] = bro_tensor
+
+ return move_count
+
+ def is_last_slice_shard_tensor_to_host(self, bro_key, bro_tensor):
+ device_tensors_keys = list(self.device_tensors.keys())
+ for key in device_tensors_keys:
+ tensor = self.device_tensors[key]
+ if key != bro_key and tensor.get_slice_tensor() and tensor.storage_data_ptr == bro_tensor.storage_data_ptr:
+ return False
+ return True
+
+ def swap_out_by_size(self, size):
+ """
+ swap some tensors to host memory
+ :param size: total size which is requested to release memory
+ :return: true or false
+ """
+ print_rank_0("Need tensor size is : %d" % (size))
+ if not self.device_tensors or not self.is_exist_tensor_allowed_swap():
+ return False
+ swap_size = 0
+ swap_tensor_num = 0
+ only_swap_contiguous_tensor = self.is_exist_tensor_contiguous()
+ if only_swap_contiguous_tensor:
+ cur_swap_size, cur_swap_tensor_num = self.traverse_swap_device_tensors(size, swap_size, False)
+ else:
+ cur_swap_size, cur_swap_tensor_num = self.traverse_swap_device_tensors(size, swap_size, True)
+ swap_size += cur_swap_size
+ swap_tensor_num += cur_swap_tensor_num
+ if swap_size != 0:
+ print_rank_0("swap tensor to CPU, tensor num: %d, release NPU memory size: %s (%d)" % (
+ swap_tensor_num, hum_convert(swap_size), swap_size))
+ print_rank_0("tensor nums wrap manager for [device: %d, CPU: %d]" % (
+ len(self.device_tensors), len(self.host_tensors)))
+ self.total_swap_out_size += swap_size
+ return True
+
+ def traverse_swap_device_tensors(self, size, swap_size, is_swap_not_contiguous):
+ cur_swap_size = 0
+ cur_swap_tensor_num = 0
+ device_tensors_keys = list(self.device_tensors.keys())
+ # swap device memory size multiple
+ config = os.getenv('SWAP_SIZE_MULTIPLE')
+ swap_size_multiple = 1
+ if config is not None:
+ try:
+ swap_size_multiple = max(1, int(config))
+ except ValueError:
+ print_rank_0('WARNING: SWAP_SIZE_MULTIPLE value error, fallback to default value 1')
+ for key in device_tensors_keys:
+ if swap_size + cur_swap_size >= size * swap_size_multiple:
+ break
+ if key not in self.device_tensors.keys():
+ continue
+ tensor = self.device_tensors[key]
+ if not is_swap_not_contiguous and not tensor.get_tensor().is_contiguous():
+ continue
+ if tensor.is_allowed_swap:
+ tensor_size = 0
+ if tensor.get_slice_tensor():
+ is_last_slice_tensor = self.is_last_slice_shard_tensor_to_host(key, tensor)
+ if is_last_slice_tensor:
+ tensor_size = tensor.get_tensor_origin_storage()
+ tensor.trans_to_cpu()
+ else:
+ tensor.slice_tensor_trans_to_cpu()
+ else:
+ tensor_size = tensor.get_tensor().numel() * tensor.get_tensor().element_size()
+ tensor.trans_to_cpu()
+ cur_swap_size += tensor_size
+ self.device_tensors.pop(key)
+ self.host_tensors[key] = tensor
+ move_count = self.move_shard_tensor_to_host(key, tensor)
+ cur_swap_tensor_num += 1 + move_count
+ return cur_swap_size, cur_swap_tensor_num
+
+ def unwrap_tensor(self, tensor):
+ """
+ Unwrap the tensor.
+ If tensor is not on the device, the tensor will be swapped in to make sure that tensor is on device to compute.
+ return the torch tensor to compute in torch graph
+ :param tensor: wrapped tensor
+ :return: origin tensor
+ """
+ if not isinstance(tensor, SwappableTensor):
+ return tensor
+
+ if tensor.id_key in self.host_tensors.keys():
+ self.host_tensors.pop(tensor.id_key)
+ if tensor.get_tensor().storage().size() == 0:
+ self.move_shard_tensor_to_device(tensor)
+ else:
+ tensor.trans_to_device(False)
+ else:
+ self.device_tensors.pop(tensor.id_key)
+
+ return tensor.get_tensor()
+
+ def move_shard_tensor_to_device(self, tensor):
+ cap_tensor = tensor
+ if tensor.inner_tensor_cpu_data is None:
+ cap_key = tensor.inner_tensor_bro_keys[0]
+ try:
+ cap_tensor = self.host_tensors[cap_key]
+ except KeyError:
+ print_rank_0("[ERROR] The key doesn't exist.")
+ cap_tensor.trans_to_device(True)
+ if cap_tensor.id_key != tensor.id_key:
+ cap_tensor.inner_tensor_bro_keys.remove(tensor.id_key)
+ self.host_tensors.pop(cap_tensor.id_key)
+ self.device_tensors[cap_tensor.id_key] = cap_tensor
+ for key in cap_tensor.inner_tensor_bro_keys:
+ bro_tensor = self.host_tensors.pop(key)
+ bro_tensor.set_tensor_location("device")
+ self.device_tensors[key] = bro_tensor
+
+ def reset_swap_manager_tensors(self):
+ self.device_tensors.clear()
+ self.host_tensors.clear()
+
+
+def hum_convert(value):
+ units = ["B", "KB", "MB", "GB", "TB", "PB"]
+ origin_value = value
+ for unit in units:
+ if (value / 1024.0) < 1:
+ return "%.2f%s" % (value, unit)
+ value = value / 1024.0
+ return "%.2f%s" % (origin_value, units[0])
+
+
+def get_tensor_mem_size(tensor):
+ return tensor.numel() * tensor.element_size()
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/swappable_tensor.py b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/swappable_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..82d9184e602a464e02093c94ff0bf5b4deba8aff
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/adaptive_recomputing/swappable_tensor.py
@@ -0,0 +1,88 @@
+import torch
+
+
+class SwappableTensor(torch.Tensor):
+
+ @classmethod
+ def __new__(cls, tensor, *args, **kwargs):
+ # construct a fake tensor to unique tensors
+ data = torch.Tensor([id(tensor)])
+ return torch.Tensor._make_subclass(cls, data, False)
+
+ def __init__(self, tensor):
+ self.id_key = None
+ self.inner_tensor = None
+ self.inner_tensor_bro_keys = []
+ self.inner_tensor_cpu_data = None
+ self.storage_data_ptr = None
+ self.inner_tensor_data_ptr = None
+ self.inner_tensor_origin_storage_size = 0
+ self.inner_tensor_origin_storage_ele_size = 0
+ self.is_allowed_swap = False
+ self._device = None
+ self._location = None
+ self.is_slice_tensor = tensor.storage().size() != tensor.numel()
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+ return super().__torch_function__(func, types, args, kwargs)
+
+ def set_tensor(self, id_key, tensor):
+ self.id_key = id_key
+ self.inner_tensor = tensor
+ self.inner_tensor_data_ptr = tensor.data_ptr()
+ self.storage_data_ptr = tensor.storage().data_ptr()
+ self.inner_tensor_origin_storage_size = tensor.storage().size()
+ self.inner_tensor_origin_storage_ele_size = tensor.storage().element_size()
+ self._location = "device"
+ self._device = tensor.device
+
+ def get_tensor(self):
+ return self.inner_tensor
+
+ def set_tensor_location(self, location):
+ self._location = location
+
+ def trans_to_cpu(self):
+ """
+ translate tensor to host memory
+ :return:
+ """
+ with torch.no_grad():
+ self.inner_tensor_cpu_data = self.inner_tensor.cpu()
+ self.inner_tensor.storage().resize_(0)
+ self._location = "cpu"
+
+ def slice_tensor_trans_to_cpu(self):
+ """
+ translate slice tensor to host memory
+ """
+ with torch.no_grad():
+ self.inner_tensor_cpu_data = self.inner_tensor.cpu()
+ self._location = "cpu"
+
+ def trans_to_device(self, resize_flag):
+ """
+ translate tensor to device
+ :return:
+ """
+ with torch.no_grad():
+ if resize_flag:
+ self.inner_tensor.storage().resize_(self.inner_tensor_origin_storage_size)
+ if self.is_slice_tensor:
+ self.inner_tensor.copy_(self.inner_tensor_cpu_data)
+ else:
+ self.inner_tensor.storage().copy_(self.inner_tensor_cpu_data.storage())
+ self._location = "device"
+
+ def get_location(self):
+ return self._location
+
+ def get_slice_tensor(self):
+ return self.is_slice_tensor
+
+ def get_tensor_origin_storage(self):
+ return self.inner_tensor_origin_storage_size * self.inner_tensor_origin_storage_ele_size
+
diff --git a/model/train/yoco_moe/mindspeed/core/memory/auto_pipeline/autopipeline.py b/model/train/yoco_moe/mindspeed/core/memory/auto_pipeline/autopipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac7ff2813564e574f61af737b81b86fcac4d16c5
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/auto_pipeline/autopipeline.py
@@ -0,0 +1,371 @@
+import time
+
+_TRAIN_START_TIME = time.time()
+import json
+import os.path
+import gc
+import copy
+from functools import wraps
+import torch
+import torch.nn
+import torch_npu
+from megatron.training import print_rank_0
+from megatron.training.arguments import parse_args
+from megatron.core import parallel_state
+from megatron.core.parallel_state import get_embedding_group
+from megatron.training import get_args
+from megatron.training import get_timers
+from megatron.training import training
+from megatron.training.training import print_datetime
+from megatron.core.pipeline_parallel import p2p_communication
+from megatron.core import mpu, tensor_parallel
+from megatron.training.initialize import initialize_megatron, set_jit_fusion_options
+
+
+class AutoPipeline:
+ auto_pipeline = None
+
+ def __init__(self, args):
+ self.args = copy.deepcopy(args)
+ self.context = {
+ 'module': []
+ }
+ self.modules_hooks = []
+ self.profiling_step = 0
+ self.stop_profiling_step = 5
+ self.unit_mb = 1024 * 1024
+
+ @staticmethod
+ def get_memory_status():
+ used_memory = torch.npu.memory_allocated()
+ reserved_memory = torch.npu.memory_reserved()
+ return used_memory, reserved_memory
+
+ def _cal_tensor_size(self, tensor):
+ try:
+ return tensor.numel() * tensor.element_size() / self.unit_mb
+ except ZeroDivisionError:
+ return 0
+
+ def pre_hook_func(self, state, sync: bool, *args, **kargs):
+ if sync:
+ torch.npu.synchronize()
+ used_memory, _ = self.get_memory_status()
+ torch.npu.reset_max_memory_allocated()
+ state['memory'] = used_memory
+ torch.npu.synchronize()
+ state['time'] = time.time()
+ size = 0
+ for arg in args:
+ if isinstance(arg, torch.Tensor):
+ size += self._cal_tensor_size(arg)
+ elif isinstance(arg, tuple) or isinstance(arg, list):
+ for t in arg:
+ if isinstance(t, torch.Tensor):
+ size += self._cal_tensor_size(t)
+ state['input'] = size
+
+ def post_hook_func(self, state, sync: bool, *args, **kargs):
+ if sync:
+ torch.npu.synchronize()
+ used_memory, _ = self.get_memory_status()
+ max_mem = torch.npu.max_memory_allocated()
+ state['peak_memory'] = max_mem - state['memory']
+ state['memory'] = (used_memory - state['memory']) // self.unit_mb
+ if 'pre_total_time' in state:
+ state['forward_cnt'] += 1
+ state['time'] = (time.time() - state['time']) * 1000
+ state['pre_total_time'] += state['time']
+ try:
+ state['time'] = state['pre_total_time'] / state['forward_cnt']
+ except ZeroDivisionError:
+ state['time'] = 0
+ else:
+ state['forward_cnt'] = 0
+ state['time'] = (time.time() - state['time']) * 1000
+ state['pre_total_time'] = 0
+
+ def forward_pre_hook(self, name, parent_ctx, ctx):
+ if self.profiling_step < self.stop_profiling_step:
+ ctx['name'] = name
+ if 'layers' in parent_ctx:
+ parent_ctx['layers'].append(ctx)
+
+ def hook(module, *args, **kargs):
+ if self.profiling_step < self.stop_profiling_step:
+ if 'module' in self.context:
+ self.context['module'].append(ctx)
+ self.pre_hook_func(ctx, True, *args, **kargs)
+
+ return hook
+
+ def forward_post_hook(self, ctx):
+ def hook(module, *args, **kargs):
+ if self.profiling_step < self.stop_profiling_step:
+ self.post_hook_func(ctx, True, *args)
+ if 'module' in self.context:
+ self.context['module'].pop()
+
+ return hook
+
+ def register_recursive_hook(self, prefix_name, model, ctx):
+ for name, module in model.named_children():
+ if 'layers' not in ctx:
+ ctx['layers'] = []
+ current_ctx = {}
+
+ next_name = prefix_name + "." + name if prefix_name != "" else name
+ pre_hook = module.register_forward_pre_hook(self.forward_pre_hook(name, ctx, current_ctx))
+ post_hook = module.register_forward_hook(self.forward_post_hook(current_ctx))
+ self.modules_hooks.append(pre_hook)
+ self.modules_hooks.append(post_hook)
+ self.register_recursive_hook(next_name, module, current_ctx)
+
+ def step_hook(self, model):
+ self.profiling_step += 1
+
+ def hook_step_func(self, step_func, models):
+ def custom_step_func(*args, **kargs):
+ result = step_func(*args, **kargs)
+ if self.profiling_step < self.stop_profiling_step:
+ used_memory, reserved_memory = self.get_memory_status()
+ self.context['used_mem'] = used_memory // self.unit_mb
+ if isinstance(models, list):
+ for model in models:
+ self.step_hook(model)
+ else:
+ self.step_hook(models)
+ return result
+
+ return custom_step_func
+
+ def get_comm_time(self, config, sync: bool):
+ if torch.distributed.get_rank() == 0:
+ if sync:
+ torch.npu.synchronize()
+ input_tensor = torch.ones(self.args.seq_length, self.args.micro_batch_size, self.args.hidden_size)
+ start_time = time.time()
+ p2p_communication.send_backward(input_tensor, config)
+ comm_time = (time.time() - start_time) * 1000
+ self.context['comm_time'] = comm_time
+ else:
+ self.context['comm_time'] = 0.028
+
+ def get_modules_params_by_stages(self, init_memory, sync: bool):
+
+ if self.args.pipeline_model_parallel_size == 2:
+ self.context['first_stage_embed'] = self.args.padded_vocab_size * self.args.hidden_size
+ self.context['last_stage_embed'] = self.args.padded_vocab_size * self.args.hidden_size
+ attention_block = 3 * self.args.hidden_size * self.args.num_attention_heads * (
+ self.args.hidden_size / self.args.num_attention_heads) + self.args.hidden_size * self.args.hidden_size + self.args.hidden_size + self.args.hidden_size
+ ffn_block = 3 * self.args.ffn_hidden_size * self.args.hidden_size + self.args.hidden_size + self.args.hidden_size
+ per_trans_layer_param = attention_block + ffn_block
+ per_trans_layer_param /= self.args.tensor_model_parallel_size
+ self.context['per_trans_layer_param'] = per_trans_layer_param
+
+ else:
+ first_stage_param = 0
+ per_trans_layer_param = 0
+ last_stage_param = 0
+ if sync:
+ torch.npu.synchronize()
+ first_stage_rank = 0
+ last_stage_rank = torch.distributed.get_world_size() - 1
+ layer_stage_rank = self.args.tensor_model_parallel_size
+
+ first_stage_param = self.broadcast_param_in_ranks(first_stage_rank, first_stage_param, init_memory)
+ last_stage_param = self.broadcast_param_in_ranks(last_stage_rank, last_stage_param, init_memory)
+ per_trans_layer_param = self.broadcast_param_in_ranks(layer_stage_rank, per_trans_layer_param, init_memory)
+
+ self.context['first_stage_embed'] = first_stage_param - per_trans_layer_param
+ self.context['last_stage_embed'] = last_stage_param - per_trans_layer_param
+ self.context['per_trans_layer_param'] = per_trans_layer_param
+
+ def broadcast_param_in_ranks(self, src_rank, param, init_memory):
+ if torch.distributed.get_rank() == src_rank:
+ param = torch.npu.max_memory_allocated() / self.unit_mb - init_memory
+ tmp_param = torch.cuda.IntTensor([param])
+ torch.distributed.broadcast(tmp_param, src=src_rank)
+ param = tmp_param.item()
+ return param
+
+ def update_args_for_profiling(self):
+ args = get_args()
+ if args.num_layers_per_virtual_pipeline_stage is None:
+ args.num_layers = self.args.pipeline_model_parallel_size
+ args.encoder_num_layers = self.args.pipeline_model_parallel_size
+ args.train_iters = self.stop_profiling_step
+ args.save = False
+ args.log_interval = 10
+
+ def restore_args_for_training(self):
+ args = get_args()
+ if args.num_layers_per_virtual_pipeline_stage is None:
+ args.num_layers = self.args.num_layers
+ args.encoder_num_layers = self.args.num_layers
+ args.train_iters = self.args.train_iters
+ args.optimizer = self.args.optimizer
+ args.save = self.args.save
+ args.log_interval = self.args.log_interval
+
+
+def check_equal_model_configs(args, parsed_contents):
+ model_index = 0
+ for model_instance in parsed_contents:
+ if args.hidden_size == model_instance["model_configs"]["hidden_size"] \
+ and args.ffn_hidden_size == model_instance["model_configs"]["ffn_hidden_size"] \
+ and args.seq_length == model_instance["model_configs"]["seq_length"] \
+ and args.num_attention_heads == model_instance["model_configs"]["num_attention_heads"]:
+ return model_index
+ else:
+ model_index += 1
+ return -1
+
+
+def check_equal_parallel_configs(args, parsed_content):
+ for parallel_instance in parsed_content["autopipeline_policy"]:
+ if args.num_layers == parallel_instance["num_layers"] \
+ and args.pipeline_model_parallel_size == parallel_instance["pipeline_model_parallel_size"] \
+ and args.tensor_model_parallel_size == parallel_instance["tensor_model_parallel_size"] \
+ and args.save_memory_ratio == parallel_instance["ratio"]:
+ return parallel_instance["num_layer_list"], parallel_instance["recompute_module_list"], parallel_instance[
+ "recompute_type"]
+ return None, None, None
+
+
+def check_skip_profiling(args, config_file):
+ if os.path.exists(config_file):
+ with open(config_file) as config_json:
+ config_contents = config_json.read()
+ parsed_contents = json.loads(config_contents)
+ index = check_equal_model_configs(args, parsed_contents)
+ if index != -1:
+ num_layer_list, recompute_module_list, recompute_type = check_equal_parallel_configs(args,
+ parsed_contents[index])
+ if num_layer_list:
+ return True, [(num_layer_list, recompute_module_list, (0, [0]), recompute_type)]
+ return False, None
+
+
+def set_recompute_mode(models):
+ for model in models:
+ for name, module in model.named_modules():
+ if str.isdigit(name) and name != "0":
+ module.forward = hook_checkpoint_forward(module.forward)
+
+
+def hook_checkpoint_forward(forward_func):
+ def custom_forward(*args, **kargs):
+ def inside_forward(*args):
+ return forward_func(*args, **kargs)
+
+ return tensor_parallel.checkpoint(inside_forward, None, *args)
+
+ return custom_forward
+
+
+def get_auto_pipeline(args):
+ if AutoPipeline.auto_pipeline is None:
+ AutoPipeline.auto_pipeline = AutoPipeline(args)
+ return AutoPipeline.auto_pipeline
+
+
+def initialize_cfg_from_args_wrapper(initialize_cfg_from_args):
+ @wraps(initialize_cfg_from_args)
+ def wrapper(*args, **kwargs):
+ from mindspeed.core import training as mc_training
+ argument = get_args()
+ disable_mc2 = argument.automated_pipeline and not mc_training.policy
+ if not disable_mc2:
+ initialize_cfg_from_args(*args, **kwargs)
+ return wrapper
+
+
+def autopipeline_profiling(model_provider, model_type, forward_step_func, train_valid_test_dataset_provider,
+ process_non_loss_data_func, args):
+ is_skip, policy = check_skip_profiling(args, config_file="autopipeline_config.json")
+ if not is_skip:
+ initialize_megatron(extra_args_provider=None,
+ args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
+ set_jit_fusion_options()
+ global _TRAIN_START_TIME
+ start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
+ torch.distributed.all_reduce(start_time_tensor,
+ op=torch.distributed.ReduceOp.MIN)
+ _TRAIN_START_TIME = start_time_tensor.item()
+ print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
+ time.time() - _TRAIN_START_TIME))
+ print_datetime('after megatron is initialized')
+ args = get_args()
+ pipelining = get_auto_pipeline(args)
+ pipelining.update_args_for_profiling()
+ init_memory = torch.npu.max_memory_allocated() / pipelining.unit_mb
+ models, optimizer, lr_scheduler = training.setup_model_and_optimizer(model_provider, model_type)
+ optimizer.step = pipelining.hook_step_func(optimizer.step, models)
+ config = training.get_model_config(models[0])
+
+ if args.virtual_pipeline_model_parallel_size is not None:
+ train_data_iterator = []
+ valid_data_iterator = []
+ for i in range(len(models)):
+ mpu.set_virtual_pipeline_model_parallel_rank(i)
+ iterators = training.build_train_valid_test_data_iterators(
+ train_valid_test_dataset_provider)
+ train_data_iterator.append(iterators[0])
+ valid_data_iterator.append(iterators[1])
+ else:
+ train_data_iterator, valid_data_iterator, _ = training.build_train_valid_test_data_iterators(
+ train_valid_test_dataset_provider)
+ if isinstance(models, list):
+ for model in models:
+ pipelining.register_recursive_hook("module", model, pipelining.context)
+ else:
+ pipelining.register_recursive_hook("module", models, pipelining.context)
+ pipelining.get_modules_params_by_stages(init_memory, sync=True)
+ set_recompute_mode(models)
+ checkpointing_context = {}
+ training.train(forward_step_func, models, optimizer, lr_scheduler, train_data_iterator, valid_data_iterator,
+ process_non_loss_data_func, config, checkpointing_context)
+ pipelining.get_comm_time(config, sync=True)
+
+ timers = get_timers()
+ if timers('interval-time'):
+ timers('interval-time').stop(barrier=True)
+
+ for hook_handle in pipelining.modules_hooks:
+ hook_handle.remove()
+ pipelining.modules_hooks.clear()
+ pipelining.restore_args_for_training()
+
+ for key, value in optimizer.optimizer.state.items():
+ key.detach()
+ key.grad = None
+ key.storage().resize_(0)
+ if "momentum_buffer" in value:
+ value["momentum_buffer"].detach()
+ value["momentum_buffer"].grad = None
+ value["momentum_buffer"].storage().resize_(0)
+ for ofg in optimizer.param_groups:
+ if "params" in ofg:
+ for og in ofg["params"]:
+ og.detach()
+ og.grad = None
+ og.storage().resize_(0)
+ for md in models:
+ for param in md.parameters():
+ param.detach()
+ param.grad = None
+ param.storage().resize_(0)
+ for param_tensor in md.state_dict():
+ if md.state_dict()[param_tensor] is not None:
+ md.state_dict()[param_tensor].detach()
+ md.state_dict()[param_tensor].grad = None
+ md.state_dict()[param_tensor].storage().resize_(0)
+
+ gc.collect()
+ torch_npu.npu.empty_cache()
+ time.sleep(5)
+ return pipelining.context, policy
+ else:
+ print_rank_0("[INFO] Found existed automated pipeline policy, apply it directly.")
+ return None, policy
diff --git a/model/train/yoco_moe/mindspeed/core/memory/auto_pipeline/autopipeline_apply.py b/model/train/yoco_moe/mindspeed/core/memory/auto_pipeline/autopipeline_apply.py
new file mode 100644
index 0000000000000000000000000000000000000000..f16004dfd68aded956df6994ecd2fdfebbfeb35a
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/auto_pipeline/autopipeline_apply.py
@@ -0,0 +1,53 @@
+import torch
+from megatron.training import print_rank_0
+from megatron.training import get_args
+from megatron.core import utils, parallel_state, tensor_parallel
+
+
+def apply_autopipeline(models):
+ if isinstance(models, list):
+ for model in models:
+ apply_recompute_modules(model)
+ else:
+ apply_recompute_modules(models)
+
+
+def apply_recompute_modules(model):
+ args = get_args()
+ for pp_rankid, recomp_value in enumerate(args.recompute_module_list):
+ if pp_rankid == parallel_state.get_pipeline_model_parallel_rank():
+ if recomp_value > 0:
+ set_recompute_modules(model, recomp_value, args.recompute_type)
+
+
+def set_recompute_modules(model, recomp_value, module_type):
+ recomp_pool = []
+ recomp_name = "module.module.language_model.encoder.layers."
+ for i in range(0, recomp_value):
+ tmp_recomp_name = recomp_name
+ tmp_recomp_name += str(i)
+ # mlp recompute type
+ if module_type == 0:
+ tmp_recomp_name += ".mlp"
+ recomp_pool.append(tmp_recomp_name)
+ # attention recompute type
+ if module_type == 1:
+ tmp_recomp_name += ".self_attention"
+ recomp_pool.append(tmp_recomp_name)
+ # layer recompute type
+ if module_type == 2:
+ recomp_pool.append(tmp_recomp_name)
+
+ for name, module in model.named_modules():
+ if name in recomp_pool:
+ module.forward = hook_checkpoint_forward(module.forward)
+
+
+def hook_checkpoint_forward(forward_func):
+ def custom_forward(*args, **kargs):
+ def inside_forward(*args):
+ return forward_func(*args, **kargs)
+
+ return tensor_parallel.checkpoint(inside_forward, None, *args)
+
+ return custom_forward
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/memory/auto_pipeline/autopipeline_solver.py b/model/train/yoco_moe/mindspeed/core/memory/auto_pipeline/autopipeline_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..55a0073ac0167e4e8aef68bc0ed8951586e71855
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/auto_pipeline/autopipeline_solver.py
@@ -0,0 +1,501 @@
+import os
+import json
+import statistics
+import math
+import time
+import multiprocessing
+from functools import wraps
+import torch
+import megatron.training.global_vars
+from megatron.training import get_args
+from megatron.training import print_rank_0
+from .autopipeline import check_equal_model_configs
+import mindspeed.model.transformer as mindspeed_transformer
+import megatron.core.parallel_state as megatron_parallel_state
+import mindspeed.core.parallel_state as mindspeed_parallel_state
+
+
+class AutoPipelineSolver():
+ def __init__(self, context):
+ self.context = context
+ self.MB_SIZE = 1024 * 1024
+ # model configurations
+ args = get_args()
+ self.num_layers = args.num_layers
+ self.vocab_size = args.padded_vocab_size
+ self.hidden_size = args.hidden_size
+ self.ffn_hidden_size = args.ffn_hidden_size
+ self.micro_batch_size = args.micro_batch_size
+ self.global_batch_size = args.global_batch_size
+ self.seq_length = args.seq_length
+ self.num_attention_heads = args.num_attention_heads
+ self.pipeline_model_parallel_size = args.pipeline_model_parallel_size
+ self.tensor_model_parallel_size = args.tensor_model_parallel_size
+
+ self.first_stage_embed = 0
+ self.last_stage_embed = 0
+ self.per_trans_layer_param = 0
+ self.embed_activation = 0
+
+ self.forward_time = 0
+ self.forward_activation = 0
+ self.mlp_forward_time = 0
+ self.comm_time = 0
+ self.forward_mlp_activation = 0
+ self.attention_forward_time = 0
+ self.forward_attention_activation = 0
+ self.layer_forward_time = 0
+ self.forward_layer_activation = 0
+ self.parse_profile()
+
+ # hyper params settings
+ self.ratio = args.save_memory_ratio if args.save_memory_ratio == 1.0 else 1 - args.save_memory_ratio
+ self.min_layer, self.max_layer = self.get_min_max_layer()
+ self.target_memory = self.set_target_memory()
+
+ # auto pipeline search result
+ self.ans = []
+ self.backup = []
+ self.backup_min_mem = 0
+ # auto pipeline policy
+ self.policy = []
+ self.optimal_sch = []
+ self.minn = []
+
+
+ def find_target_profile(self, module, target, profile_type):
+ context = self.context
+ while module in context:
+ for sub_context in context[module]:
+ if sub_context["name"] == target:
+ return sub_context[profile_type]
+ else:
+ context = sub_context
+ return 0
+
+
+ def get_min_max_layer(self):
+ layer_avg = round(self.num_layers / self.pipeline_model_parallel_size)
+ if 1 <= layer_avg <= 4:
+ layer_range = 0
+ elif 5 <= layer_avg < 8:
+ layer_range = 1
+ else:
+ layer_range = 2
+ return layer_avg - layer_range, layer_avg + layer_range
+
+
+ def parse_profile(self):
+ self.first_stage_embed = self.context["first_stage_embed"] * self.MB_SIZE
+ self.last_stage_embed = self.context["last_stage_embed"] * self.MB_SIZE
+ self.per_trans_layer_param = self.context["per_trans_layer_param"] * self.MB_SIZE
+ self.embed_activation = self.find_target_profile("layers", "embedding", "memory") * self.MB_SIZE
+
+ self.forward_time = self.find_target_profile("layers", "module", "time")
+ self.mlp_forward_time = self.find_target_profile("layers", "mlp", "time")
+ self.attention_forward_time = self.find_target_profile("layers", "self_attention", "time")
+ self.layer_forward_time = self.find_target_profile("layers", "0", "time")
+ self.comm_time = self.context["comm_time"]
+ self.forward_activation = self.find_target_profile("layers", "module", "memory") * self.MB_SIZE
+ self.forward_mlp_activation = self.find_target_profile("layers", "mlp", "memory") * self.MB_SIZE
+ self.forward_attention_activation = self.find_target_profile("layers", "self_attention", "memory") * self.MB_SIZE
+ self.forward_layer_activation = self.find_target_profile("layers", "0", "memory") * self.MB_SIZE
+
+
+ def naive_search(self, module_type, answer_queue):
+
+ def dfs_build_layers(prefix_n_layers, cur_layers_sum):
+
+ if len(prefix_n_layers) > self.pipeline_model_parallel_size:
+ return
+ if cur_layers_sum > self.num_layers:
+ return
+ if 2 <= len(prefix_n_layers) < self.pipeline_model_parallel_size:
+ if prefix_n_layers[-1] < prefix_n_layers[-2]:
+ return
+
+ if len(prefix_n_layers) == self.pipeline_model_parallel_size and cur_layers_sum == self.num_layers:
+ status, prefix_recomp_modules, mem_set = self.get_recompute_modules(prefix_n_layers, self.pipeline_model_parallel_size, module_type)
+ if status:
+ answer_queue.append((prefix_n_layers, prefix_recomp_modules, mem_set, module_type))
+ if len(answer_queue) == 0 and len(self.ans) == 0:
+ if len(self.backup) == 0:
+ self.backup.append((prefix_n_layers, prefix_recomp_modules, mem_set, module_type))
+ else:
+ temp_min_mem = min(mem_set[1])
+ if temp_min_mem < self.backup_min_mem:
+ self.backup_min_mem = temp_min_mem
+ self.backup[0] = (prefix_n_layers, prefix_recomp_modules, mem_set, module_type)
+ return
+
+ for cur_n_layer in range(self.max_layer, self.min_layer - 1, -1):
+ dfs_build_layers(prefix_n_layers + [cur_n_layer], cur_layers_sum + cur_n_layer)
+
+ for prefix_n_layer in range(self.max_layer, self.min_layer - 1, -1):
+ dfs_build_layers([prefix_n_layer], prefix_n_layer)
+
+ return answer_queue
+
+
+ def main_search(self):
+ mlp_answer_queue, attn_answer_queue, layer_answer_queue = [], [], []
+ mlp_answer_queue = self.naive_search(0, mlp_answer_queue)
+ self.ans = mlp_answer_queue
+ attn_answer_queue = self.naive_search(1, attn_answer_queue)
+ self.ans += attn_answer_queue
+ layer_answer_queue = self.naive_search(2, layer_answer_queue)
+ self.ans += layer_answer_queue
+
+ return self.ans
+
+
+ def cal_module_param(self, module_type):
+
+ per_layer_activation_param = self.forward_activation
+ per_recompute_module_param = 0
+ if module_type == 0:
+ # mlp activation param
+ per_recompute_module_param = self.forward_mlp_activation
+ if module_type == 1:
+ # attn param
+ per_recompute_module_param = self.forward_attention_activation
+ if module_type == 2:
+ # layer param
+ per_recompute_module_param = 2 * self.seq_length * self.micro_batch_size * self.hidden_size
+
+ return per_layer_activation_param, per_recompute_module_param
+
+
+ def cal_model_mem(self, per_layer_activation_param, per_recompute_module_param, n_layer, n_recompute_module, parallel_num, \
+ stage_num):
+ if stage_num == 0:
+ stage_max_optimizer_mem = (self.first_stage_embed + self.per_trans_layer_param * n_layer) + self.embed_activation
+ model_mem = self.first_stage_embed + self.per_trans_layer_param * n_layer \
+ + stage_max_optimizer_mem \
+ + per_layer_activation_param * n_layer * parallel_num
+ elif stage_num == self.pipeline_model_parallel_size - 1:
+ stage_max_optimizer_mem = (self.last_stage_embed + self.per_trans_layer_param * n_layer) + self.embed_activation
+ model_mem = self.last_stage_embed + self.per_trans_layer_param * n_layer \
+ + stage_max_optimizer_mem \
+ + per_layer_activation_param * n_layer * parallel_num
+ else:
+ stage_max_optimizer_mem = self.per_trans_layer_param * n_layer
+ model_mem = self.per_trans_layer_param * n_layer \
+ + stage_max_optimizer_mem \
+ + per_layer_activation_param * n_layer * parallel_num
+ return model_mem
+
+
+ def set_target_memory(self):
+ per_layer_activation_param, per_recompute_module_param = self.cal_module_param(0)
+ stage_num = 0
+ default_n_layers_mems = []
+ while stage_num < self.pipeline_model_parallel_size:
+ default_layer_mem = self.cal_model_mem(per_layer_activation_param, per_recompute_module_param,
+ self.num_layers/self.pipeline_model_parallel_size, 0,
+ self.pipeline_model_parallel_size - stage_num, stage_num)
+ default_n_layers_mems.append(default_layer_mem)
+ stage_num += 1
+
+ target_memory = sum(default_n_layers_mems)/len(default_n_layers_mems)
+ if self.ratio < 1.0:
+ target_memory = max(default_n_layers_mems)
+ return target_memory
+
+
+ def get_recompute_modules(self, n_layers, num_pp_stage, module_type):
+ per_layer_activation_param, per_recompute_module_param = self.cal_module_param(module_type)
+ init_recompute_modules = []
+ new_n_layers_mems = []
+ stage_num = 0
+ status = True
+
+ while stage_num < len(n_layers):
+ init_layer_mem = self.cal_model_mem(per_layer_activation_param, per_recompute_module_param,\
+ n_layers[stage_num], 0,
+ num_pp_stage - stage_num, stage_num)
+ if init_layer_mem <= self.target_memory * self.ratio:
+ n_recompute_module = 0
+ init_recompute_modules.append(n_recompute_module)
+ else:
+ if (per_recompute_module_param * (num_pp_stage - stage_num) / self.MB_SIZE) == 0:
+ n_recompute_module = 0
+ else:
+ n_recompute_module = math.ceil((init_layer_mem / self.MB_SIZE - self.target_memory * self.ratio / self.MB_SIZE) / (per_recompute_module_param * (num_pp_stage - stage_num) / self.MB_SIZE))
+ if n_recompute_module > n_layers[stage_num]:
+ status = False
+ n_recompute_module = n_layers[stage_num]
+ init_recompute_modules.append(n_recompute_module)
+ else:
+ init_recompute_modules.append(n_recompute_module)
+
+ init_layer_mem = self.cal_model_mem(per_layer_activation_param, per_recompute_module_param,
+ n_layers[stage_num], n_recompute_module,
+ num_pp_stage - stage_num, stage_num)
+ init_layer_mem -= per_recompute_module_param*n_recompute_module
+ init_layer_mem /= self.MB_SIZE
+ new_n_layers_mems.append(init_layer_mem)
+ stage_num += 1
+
+ return status, init_recompute_modules, (self.target_memory/self.MB_SIZE, new_n_layers_mems)
+
+
+ def dp(self, examples):
+ # lookup duration via parallel params
+ (Fwd, Bwd, ComFwd, ComBwd) = self.forward_time, self.forward_time * 1.3, self.comm_time, self.comm_time
+
+ RecompFwd = 0
+ module_type = examples[3]
+ if module_type == 0:
+ RecompFwd = self.mlp_forward_time
+ elif module_type == 1:
+ RecompFwd = self.attention_forward_time
+ elif module_type == 2:
+ RecompFwd = self.layer_forward_time
+
+ # to remember that n_layers can be divided by num_pp_stage
+ n_layers = [0] + examples[0]
+ n_recompute_layers = [0] + examples[1]
+ num_pp_stage = self.pipeline_model_parallel_size
+
+ # number of micro-batch-size is 256
+ mbs = [self.micro_batch_size for _ in range(self.global_batch_size)]
+ num_microbatch = len(mbs)
+ mbs = [0] + mbs
+
+ SF = [[0 for i in range(num_microbatch + 1)] for _ in range(num_pp_stage + 1)] # start of forward 图中蓝色的左边
+ EF = [[0 for i in range(num_microbatch + 1)] for _ in range(num_pp_stage + 1)] # end of forward 图中蓝色的右边
+
+ SB = [[0 for i in range(num_microbatch + 1)] for _ in range(num_pp_stage + 1)] # start of backward 图中绿色的左边
+ EB = [[0 for i in range(num_microbatch + 1)] for _ in range(num_pp_stage + 1)] # end of backward 图中绿色的右边
+
+ warmup = [num_pp_stage - p - 1 for p in range(num_pp_stage)]
+ remaining = [num_microbatch - warmup[p] for p in range(num_pp_stage)]
+
+ # for dp, p and m start with 1
+ # warmup: only forward processing, add activations
+ for p in range(1, num_pp_stage + 1):
+ for m in range(1, num_pp_stage - p + 1):
+ SF[p][m] = max(EF[p][m - 1], EF[p - 1][m] + ComFwd)
+ EF[p][m] = SF[p][m] + Fwd * n_layers[p]
+
+ # 1f1b
+ for num_1f1b in range(1, num_microbatch + 1):
+
+ # # fwd of 1f1b
+ for p in range(1, num_pp_stage + 1):
+ if remaining[p - 1] < num_1f1b:
+ # this means it have to work for cool down phase
+ continue
+
+ m = warmup[p - 1] + num_1f1b
+ if p == 1:
+ EF[0][m] = max(EF[1]) - ComFwd
+
+ SF[p][m] = max(EB[p][m + p - num_pp_stage - 1], EF[p - 1][m] + ComFwd)
+ EF[p][m] = SF[p][m] + Fwd * n_layers[p]
+
+ # bwd of 1f1b
+ for p in range(num_pp_stage, 0, -1):
+ m = num_1f1b
+ if remaining[p - 1] < num_1f1b:
+ # this means it have to work for cool down phase
+ continue
+ if p == num_pp_stage:
+ SB[p][m] = EF[p][m + num_pp_stage - p]
+ else:
+ SB[p][m] = max(EF[p][m + num_pp_stage - p], EB[p + 1][m] + ComBwd)
+
+ EB[p][m] = SB[p][m] + Bwd * n_layers[p] + RecompFwd * n_recompute_layers[p]
+
+ # cooldown
+ for p in range(num_pp_stage, 0, -1):
+ m = num_1f1b
+ if remaining[p - 1] >= num_1f1b:
+ continue
+ SB[p][m] = max(EB[p][m - 1], EB[p + 1][m] + ComBwd)
+ EB[p][m] = SB[p][m] + Bwd * n_layers[p] + RecompFwd * n_recompute_layers[p]
+
+ itertime = max([max(EB[p]) for p in range(num_pp_stage)])
+ self.policy.append((itertime, examples))
+ return
+
+
+ def find_top_optimal_schedule(self):
+ self.main_search()
+ for examples in self.ans:
+ self.dp(examples)
+
+ if len(self.policy) > 0:
+ min_itertime = self.policy[0][0]
+ self.minn.append(min_itertime)
+ self.optimal_sch.append(self.policy[0][1])
+ for idx, res in enumerate(self.policy):
+ if res[0] < min_itertime:
+ min_itertime = res[0]
+ self.minn[0] = min_itertime
+ self.optimal_sch[0] = res[1]
+ else:
+ print_rank_0("[INFO] [Autopipeline Policy Time Searching Stage] No strategy is satisfied. We will apply the minimum memory strategy instead.")
+ self.minn.append(0)
+ self.optimal_sch.append(self.backup[0])
+
+ return self.optimal_sch, self.minn
+
+
+def broadcast_policy_in_ranks(src_rank, policy=None):
+ args = get_args()
+ num_layer_list = args.pipeline_model_parallel_size * [0]
+ recompute_module_list = args.pipeline_model_parallel_size * [0]
+ recompute_type = 0
+ if torch.distributed.get_rank() == 0:
+ num_layer_list = policy[0][0]
+ recompute_module_list = policy[0][1]
+ recompute_type = policy[0][3]
+
+ tmp_layer_list = torch.cuda.IntTensor(num_layer_list)
+ torch.distributed.broadcast(tmp_layer_list, src=src_rank)
+ args.num_layer_list = tmp_layer_list.tolist()
+
+ tmp_recompute_module_list = torch.cuda.IntTensor(recompute_module_list)
+ torch.distributed.broadcast(tmp_recompute_module_list, src=src_rank)
+ args.recompute_module_list = tmp_recompute_module_list.tolist()
+
+ tmp_recompute_type = torch.cuda.IntTensor([recompute_type])
+ torch.distributed.broadcast(tmp_recompute_type, src=src_rank)
+ args.recompute_type = tmp_recompute_type.item()
+
+
+def destroy_global_vars():
+ megatron.training.global_vars._GLOBAL_ARGS = None
+ megatron.training.global_vars._GLOBAL_RETRO_ARGS = None
+ megatron.training.global_vars._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
+ megatron.training.global_vars._GLOBAL_TOKENIZER = None
+ megatron.training.global_vars._GLOBAL_TENSORBOARD_WRITER = None
+ megatron.training.global_vars._GLOBAL_WANDB_WRITER = None
+ megatron.training.global_vars._GLOBAL_ADLR_AUTORESUME = None
+ megatron.training.global_vars._GLOBAL_TIMERS = None
+ megatron.training.global_vars._GLOBAL_SIGNAL_HANDLER = None
+ megatron_parallel_state._EXPERT_PARALLEL_GROUP = None
+ mindspeed_transformer._GLOBAL_ATTN_MASK = None
+
+
+def destroy_global_parallel_group():
+ global_parallel_group = [
+ megatron_parallel_state._MODEL_PARALLEL_GROUP,
+ megatron_parallel_state._TENSOR_MODEL_PARALLEL_GROUP,
+ megatron_parallel_state._PIPELINE_MODEL_PARALLEL_GROUP,
+ mindspeed_parallel_state._PIPELINE_MODEL_PARALLEL_GROUP_FOR_NEW_STREAM,
+ megatron_parallel_state._DATA_PARALLEL_GROUP,
+ megatron_parallel_state._DATA_PARALLEL_GROUP_WITH_CP,
+ megatron_parallel_state._CONTEXT_PARALLEL_GROUP,
+ megatron_parallel_state._EMBEDDING_GROUP,
+ megatron_parallel_state._POSITION_EMBEDDING_GROUP,
+ megatron_parallel_state._TENSOR_AND_DATA_PARALLEL_GROUP,
+ megatron_parallel_state._TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP,
+ megatron_parallel_state._EXPERT_MODEL_PARALLEL_GROUP,
+ megatron_parallel_state._TENSOR_AND_EXPERT_PARALLEL_GROUP,
+ megatron_parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP
+ ]
+ for gid in range(len(global_parallel_group)):
+ if global_parallel_group[gid]:
+ torch.distributed.destroy_process_group(global_parallel_group[gid])
+ torch.distributed.barrier()
+
+ megatron_parallel_state._MODEL_PARALLEL_GROUP = None
+ megatron_parallel_state._TENSOR_MODEL_PARALLEL_GROUP = None
+ megatron_parallel_state._PIPELINE_MODEL_PARALLEL_GROUP = None
+ mindspeed_parallel_state._PIPELINE_MODEL_PARALLEL_GROUP_FOR_NEW_STREAM = None
+ megatron_parallel_state._DATA_PARALLEL_GROUP = None
+ megatron_parallel_state._DATA_PARALLEL_GROUP_WITH_CP = None
+ megatron_parallel_state._CONTEXT_PARALLEL_GROUP = None
+ megatron_parallel_state._CONTEXT_PARALLEL_GLOBAL_RANKS = None
+ megatron_parallel_state._EMBEDDING_GROUP = None
+ megatron_parallel_state._POSITION_EMBEDDING_GROUP = None
+ megatron_parallel_state._TENSOR_AND_DATA_PARALLEL_GROUP = None
+ megatron_parallel_state._TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
+ megatron_parallel_state._EXPERT_MODEL_PARALLEL_GROUP = None
+ megatron_parallel_state._TENSOR_AND_EXPERT_PARALLEL_GROUP = None
+ megatron_parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP = None
+ megatron_parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
+ megatron_parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+ megatron_parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
+ megatron_parallel_state._MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+ megatron_parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK = None
+ megatron_parallel_state._MPU_PIPELINE_MODEL_PARALLEL_RANK = None
+ megatron_parallel_state._GLOBAL_MEMORY_BUFFER = None
+ megatron_parallel_state._MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None
+ megatron_parallel_state._MPU_EXPERT_MODEL_PARALLEL_RANK = None
+
+
+def destroy_model_parallel_profiling_wrapper(destroy_model_parallel):
+ @wraps(destroy_model_parallel)
+ def wrapper(*args, **kwargs):
+ argument = get_args()
+ enable_profiling_destroy = (argument.automated_pipeline and not argument.num_layer_list) \
+ or (argument.automated_pipeline_perf and not argument.optimized_mbs_list)
+ if enable_profiling_destroy:
+ destroy_global_parallel_group()
+ else:
+ destroy_model_parallel(*args, **kwargs)
+ return wrapper
+
+
+def get_profiling_data(policy, args):
+ instance = {"model_configs": {
+ "vocab_size": args.padded_vocab_size,
+ "hidden_size": args.hidden_size,
+ "ffn_hidden_size": args.ffn_hidden_size,
+ "seq_length": args.seq_length,
+ "num_attention_heads": args.num_attention_heads
+ }, "autopipeline_policy": [{
+ "num_layers": args.num_layers,
+ "pipeline_model_parallel_size": args.pipeline_model_parallel_size,
+ "tensor_model_parallel_size": args.tensor_model_parallel_size,
+ "ratio": args.save_memory_ratio,
+ "num_layer_list": policy[0][0],
+ "recompute_module_list": policy[0][1],
+ "recompute_type": policy[0][3]
+ }]}
+ return instance
+
+
+def save_profiling_data(policy, config_file):
+
+ args = get_args()
+ instance = get_profiling_data(policy, args)
+ if os.path.exists(config_file):
+ with open(config_file, "r") as config_json:
+ config_contents = config_json.read()
+ parsed_contents = json.loads(config_contents)
+ index = check_equal_model_configs(args, parsed_contents)
+ if index != -1:
+ if "autopipeline_policy" in parsed_contents[index]:
+ parsed_contents[index]["autopipeline_policy"].append(instance["autopipeline_policy"][0])
+ else:
+ parsed_contents.append(instance)
+ with open(config_file, "w") as f:
+ json.dump(parsed_contents, f, ensure_ascii=False)
+ os.chmod(config_file, 0o644)
+ else:
+ with open(config_file, "w") as f:
+ json.dump([instance], f, ensure_ascii=False)
+ os.chmod(config_file, 0o644)
+
+
+def solve_autopipeline(context):
+ if torch.distributed.get_rank() == 0:
+ solver = AutoPipelineSolver(context)
+ solver.find_top_optimal_schedule()
+ torch.distributed.barrier()
+ if torch.distributed.get_rank() == 0:
+ print_rank_0("=============================================")
+ print_rank_0("[INFO] Applied Automated Pipeline Policy: ")
+ print_rank_0("network layer setting: " + str(solver.optimal_sch[0][0]))
+ print_rank_0("recompute policy setting: " + str(solver.optimal_sch[0][1]))
+ print_rank_0("recompute type: " + str(solver.optimal_sch[0][3]))
+ print_rank_0("=============================================")
+ save_profiling_data(solver.optimal_sch, config_file="autopipeline_config.json")
+ return solver.optimal_sch
+ else:
+ return [([0], [0], (0, [0]), 0)]
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/memory/common.py b/model/train/yoco_moe/mindspeed/core/memory/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..576713052f2e2a1af87a02f1b2aef78c7f380476
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/common.py
@@ -0,0 +1,11 @@
+from megatron.training import get_args
+from mindspeed.core.memory.adaptive_memory.adaptive_memory_swap_manager import SwapManager as AdaptiveMemorySwapManager
+from mindspeed.core.memory.adaptive_recomputing.swap_manager import SwapManager as AdaptiveRecomputingSwapManager
+
+
+def swap_out_by_size(size):
+ args = get_args()
+ if args.adaptive_memory_optimization:
+ return AdaptiveMemorySwapManager().swap_out_by_size(size)
+ else:
+ return AdaptiveRecomputingSwapManager().swap_out_by_size(size)
diff --git a/model/train/yoco_moe/mindspeed/core/memory/memory_fragmentation/malloc_recorder.py b/model/train/yoco_moe/mindspeed/core/memory/memory_fragmentation/malloc_recorder.py
new file mode 100644
index 0000000000000000000000000000000000000000..93d4030f9b775a3a78a8072daff2744304bbe485
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/memory_fragmentation/malloc_recorder.py
@@ -0,0 +1,18 @@
+from functools import wraps
+from mindspeed.core.memory.memory_fragmentation.pluggable_allocator_adpator import load_memory_fragmentation_module
+
+
+class MallocRecorder(object):
+ def __init__(self):
+ load_memory_fragmentation_module().malloc_recorder_start()
+
+ def __del__(self):
+ load_memory_fragmentation_module().malloc_recorder_end()
+
+
+def malloc_recorder_wrapper(train_step):
+ @wraps(train_step)
+ def rec_wrap(*args, **kwargs):
+ recorder = MallocRecorder()
+ return train_step(*args, **kwargs)
+ return rec_wrap
diff --git a/model/train/yoco_moe/mindspeed/core/memory/memory_fragmentation/memory_recorder.py b/model/train/yoco_moe/mindspeed/core/memory/memory_fragmentation/memory_recorder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce4e8f697a3b79a58c1453eec5b0d066459a7c5d
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/memory_fragmentation/memory_recorder.py
@@ -0,0 +1,45 @@
+from functools import wraps
+
+from megatron.legacy.model.transformer import ParallelTransformer
+from mindspeed.core.memory.memory_fragmentation.pluggable_allocator_adpator import load_memory_fragmentation_module
+
+
+class MemoryRecorder(object):
+ def __init__(self):
+ load_memory_fragmentation_module().memory_recorder_start()
+
+ def __del__(self):
+ load_memory_fragmentation_module().memory_recorder_end()
+
+ def register_recursive_hook(self, prefix_name, model):
+ for name, module in model.named_children():
+ if isinstance(module, ParallelTransformer):
+ module.no_checkpoint_forward = module.forward
+ module.forward = wrapper(module.forward)
+
+ next_name = prefix_name + "." + name if prefix_name != "" else name
+ self.register_recursive_hook(next_name, module)
+
+
+def memory_recorder_wrapper(setup_model_and_optimizer):
+ @wraps(setup_model_and_optimizer)
+ def get_model_hook_func(*args, **kwargs):
+ load_memory_fragmentation_module().precise_match_start()
+ models, optimizer, lr_scheduler = setup_model_and_optimizer(*args, **kwargs)
+ load_memory_fragmentation_module().precise_match_end()
+ memory = MemoryRecorder()
+ if isinstance(models, list):
+ for model in models:
+ memory.register_recursive_hook("module", model)
+ else:
+ memory.register_recursive_hook("module", models)
+ return models, optimizer, lr_scheduler
+
+ return get_model_hook_func
+
+
+def wrapper(f):
+ def rec_wrap(*args, **kwargs):
+ recorder = MemoryRecorder()
+ return f(*args, **kwargs)
+ return rec_wrap
diff --git a/model/train/yoco_moe/mindspeed/core/memory/memory_fragmentation/optimizer_init_precise.py b/model/train/yoco_moe/mindspeed/core/memory/memory_fragmentation/optimizer_init_precise.py
new file mode 100644
index 0000000000000000000000000000000000000000..00dab6f450b675ba182dc0f1a60632f839fce6f6
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/memory_fragmentation/optimizer_init_precise.py
@@ -0,0 +1,22 @@
+import torch_npu
+from functools import wraps
+from mindspeed.core.memory.memory_fragmentation.pluggable_allocator_adpator import load_memory_fragmentation_module
+
+is_optimizer_init_end = False
+
+
+def optimizer_init_wrapper(step):
+ @wraps(step)
+ def rec_wrap(*args, **kwargs):
+ global is_optimizer_init_end
+ if not is_optimizer_init_end:
+ torch_npu.npu.empty_cache()
+ load_memory_fragmentation_module().precise_match_start()
+ optimizer_initialized, grad_norm, num_zeros_in_grad = step(*args, **kwargs)
+ if not is_optimizer_init_end:
+ load_memory_fragmentation_module().precise_match_end()
+ is_optimizer_init_end = optimizer_initialized
+
+ return optimizer_initialized, grad_norm, num_zeros_in_grad
+
+ return rec_wrap
diff --git a/model/train/yoco_moe/mindspeed/core/memory/memory_fragmentation/pluggable_allocator_adpator.py b/model/train/yoco_moe/mindspeed/core/memory/memory_fragmentation/pluggable_allocator_adpator.py
new file mode 100644
index 0000000000000000000000000000000000000000..90c6ed44185a03783e09108d6b22acbfb0142988
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/memory_fragmentation/pluggable_allocator_adpator.py
@@ -0,0 +1,34 @@
+import ctypes
+
+import torch_npu
+from mindspeed.op_builder import MemoryFragmentationBuilder
+
+
+class PluggableAllocatorAdaptor(object):
+ MEMORY_FRAGMENTATION_MODULE = None
+ def __init__(self):
+ pass
+
+def load_memory_fragmentation_module():
+ if PluggableAllocatorAdaptor.MEMORY_FRAGMENTATION_MODULE is None:
+ PluggableAllocatorAdaptor.MEMORY_FRAGMENTATION_MODULE = MemoryFragmentationBuilder().load()
+ return PluggableAllocatorAdaptor.MEMORY_FRAGMENTATION_MODULE
+
+def change_allocator():
+ memory_fragmentation_module_path = load_memory_fragmentation_module().__file__
+
+ new_alloc = torch_npu.npu.memory.NPUPluggableAllocator(memory_fragmentation_module_path, 'memory_fragmentation_malloc', 'memory_fragmentation_free')
+ torch_npu.npu.memory.change_current_allocator(new_alloc)
+
+ myallocator = ctypes.CDLL(memory_fragmentation_module_path)
+ init_fn = ctypes.cast(getattr(myallocator, "memory_fragmentation_init"), ctypes.c_void_p).value
+ empty_fn = ctypes.cast(getattr(myallocator, "memory_fragmentation_empty_cache"), ctypes.c_void_p).value
+ memory_fraction_fn = ctypes.cast(getattr(myallocator, "memory_fragmentation_memory_fraction"), ctypes.c_void_p).value
+ get_device_stats_fn = ctypes.cast(getattr(myallocator, "memory_fragmentation_get_device_stats"), ctypes.c_void_p).value
+ reset_peak_status_fn = ctypes.cast(getattr(myallocator, "my_reset_peak_stats"), ctypes.c_void_p).value
+
+ new_alloc.allocator().set_init_fn(init_fn)
+ new_alloc.allocator().set_reset_fn(empty_fn)
+ new_alloc.allocator().set_memory_fraction_fn(memory_fraction_fn)
+ new_alloc.allocator().set_get_device_stats_fn(get_device_stats_fn)
+ new_alloc.allocator().set_reset_peak_status_fn(reset_peak_status_fn)
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/memory/smart_swap/__init__.py b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/memory/smart_swap/hooks.py b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..64a1f4a413ef383106d1b4d6eaa7c3a1adcc2f19
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/hooks.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+
+from .swap_utils import print_with_rank, PrintLevel
+
+
+def get_module_name(module: torch.nn.Module):
+ return module.__module__ + "." + module.__class__.__name__
+
+
+class SwapHookRegister:
+ id = 0
+
+ def __init__(self):
+ self.id = SwapHookRegister.id
+ SwapHookRegister.id += 1
+
+ self.fwd_pre_hook_handle = None
+ self.fwd_post_hook_handle = None
+ self.bwd_pre_hook_handle = None
+ self.bwd_post_hook_handle = None
+ self.fwd_begin_module: torch.nn.Module = None
+ self.fwd_end_module: torch.nn.Module = None
+ self.bwd_begin_module: torch.nn.Module = None
+ self.bwd_end_module: torch.nn.Module = None
+ self.fwd_idx = 0
+ self.bwd_idx = 0
+ self.prehook_handles = []
+ self.posthook_handls = []
+
+ self.fwd_pre_hook_custom_func = None
+ self.fwd_post_hook_custom_func = None
+ self.bwd_pre_hook_custom_func = None
+ self.bwd_post_hook_custom_func = None
+
+ def __del__(self):
+ r"""if not need swap hook to module, del it."""
+
+ self.reset()
+
+ if self.fwd_pre_hook_handle:
+ self.fwd_pre_hook_handle.remove()
+ if self.fwd_post_hook_handle:
+ self.fwd_post_hook_handle.remove()
+ if self.bwd_pre_hook_handle:
+ self.bwd_pre_hook_handle.remove()
+ if self.bwd_post_hook_handle:
+ self.bwd_post_hook_handle.remove()
+
+ def reset(self):
+ self.fwd_begin_module = None
+ self.fwd_end_module = None
+ self.bwd_begin_module = None
+ self.bwd_end_module = None
+
+ self.fwd_idx = 0
+ self.bwd_idx = 0
+ for hdl in self.prehook_handles:
+ hdl.remove()
+ for hdl in self.posthook_handls:
+ hdl.remove()
+ self.prehook_handles.clear()
+ self.posthook_handls.clear()
+
+ def register_custom_func(
+ self, fwd_pre_hook_custom_func, fwd_post_hook_custom_func, bwd_pre_hook_custom_func, bwd_post_hook_custom_func
+ ):
+ r"""
+ custom_func(instance_id, fwd_or_bwd_idx)
+ """
+ self.fwd_pre_hook_custom_func = fwd_pre_hook_custom_func
+ self.fwd_post_hook_custom_func = fwd_post_hook_custom_func
+ self.bwd_pre_hook_custom_func = bwd_pre_hook_custom_func
+ self.bwd_post_hook_custom_func = bwd_post_hook_custom_func
+
+ def print_with_rank(self, message, print_level=PrintLevel.DEBUG):
+ print_with_rank(message, prefix="SwapHook", print_level=print_level)
+
+ def register_hook_to_grad_fn(self, input_tensor, position, is_bwd_pre):
+
+ def grad_fn_bwd_pre_hook(grad_outputs):
+ self.bwd_idx += 1
+ self.print_with_rank(f"grad_fn_bwd_pre_hook: bwd begin, id[{self.id}], bwd_idx[{self.bwd_idx}]")
+ # border
+ if self.bwd_pre_hook_custom_func:
+ self.bwd_pre_hook_custom_func(self.id, self.bwd_idx)
+ return grad_outputs
+
+ def grad_fn_bwd_post_hook(grad_inputs, _):
+ self.print_with_rank(f"grad_fn_bwd_post_hook: bwd end, id[{self.id}], bwd_idx[{self.bwd_idx}]")
+ # border
+ if self.bwd_post_hook_custom_func:
+ self.bwd_post_hook_custom_func(self.id, self.bwd_idx)
+ return grad_inputs
+
+ if is_bwd_pre:
+ self.print_with_rank(f"{position}, register grad_fn_bwd_pre_hook to grad_fn: {input_tensor.grad_fn}")
+ self.prehook_handles.append(input_tensor.grad_fn.register_prehook(grad_fn_bwd_pre_hook))
+ else:
+ self.print_with_rank(f"{position}, register grad_fn_bwd_post_hook to grad_fn: {input_tensor.grad_fn}")
+ self.posthook_handls.append(input_tensor.grad_fn.register_hook(grad_fn_bwd_post_hook))
+
+ def register_hook_to_bwd_end_module(self, module, inputs, position):
+ if not self.bwd_end_module or (self.bwd_end_module and module is self.bwd_end_module):
+ if isinstance(inputs, torch.Tensor):
+ inputs = (inputs,)
+ if isinstance(inputs, tuple):
+ for input_item in inputs:
+ if not isinstance(input_item, torch.Tensor):
+ continue
+ if (input_item.requires_grad and not input_item.is_leaf) and input_item.grad_fn:
+ if not self.bwd_end_module:
+ self.bwd_end_module = module
+ self.print_with_rank(f"{position}, set bwd_end_module: {get_module_name(module)}")
+
+ self.register_hook_to_grad_fn(input_item, position, is_bwd_pre=False)
+ break
+
+ def register_hook_to_bwd_begin_module(self, module, inputs, position):
+ if self.bwd_begin_module and module is self.bwd_begin_module:
+ if isinstance(inputs, torch.Tensor):
+ inputs = (inputs,)
+ if isinstance(inputs, tuple):
+ for input_item in inputs:
+ if not isinstance(input_item, torch.Tensor):
+ continue
+ if (input_item.requires_grad and not input_item.is_leaf) and input_item.grad_fn:
+
+ self.register_hook_to_grad_fn(input_item, position, is_bwd_pre=True)
+ break
+
+ def fwd_pre_hook(self, module, args):
+ self.print_with_rank(f"fwd_pre_hook, {get_module_name(module)}")
+
+ if not self.fwd_begin_module:
+ self.fwd_begin_module = module
+ self.fwd_end_module = module
+ self.bwd_begin_module = module
+ self.print_with_rank(
+ f"fwd_pre_hook: set fwd_begin_module, fwd_end_module and bwd_begin_module: {get_module_name(module)}"
+ )
+
+ if self.fwd_begin_module and module is self.fwd_begin_module:
+ self.fwd_idx += 1
+ self.print_with_rank(
+ f"fwd_pre_hook: fwd begin, id[{self.id}], fwd_idx[{self.fwd_idx}], {get_module_name(module)}"
+ )
+ # border
+ if self.fwd_pre_hook_custom_func:
+ self.fwd_pre_hook_custom_func(self.id, self.fwd_idx)
+
+ self.register_hook_to_bwd_end_module(module, args, "fwd_pre_hook")
+
+ return None
+
+ def fwd_post_hook(self, module, _, outputs):
+ self.print_with_rank(f"fwd_post_hook, {get_module_name(module)}")
+
+ if self.fwd_end_module and module is self.fwd_end_module:
+ self.print_with_rank(
+ f"fwd_post_hook: fwd end, id[{self.id}], fwd_idx[{self.fwd_idx}], {get_module_name(module)}"
+ )
+ # border
+ if self.fwd_post_hook_custom_func:
+ self.fwd_post_hook_custom_func(self.id, self.fwd_idx)
+
+ self.register_hook_to_bwd_begin_module(module, outputs, "fwd_post_hook")
+ self.register_hook_to_bwd_end_module(module, outputs, "fwd_post_hook")
+
+ return None
+
+ def register_hooks_to_modules_recursively(self, module, name=""):
+ self.print_with_rank(f"register_hooks_to_modules_recursively, {get_module_name(module)}")
+
+ for child_name, child in module.named_children():
+ self.register_hooks_to_modules_recursively(child, name + child_name)
+
+ def module_fwd_pre_hook(module, args):
+ return self.fwd_pre_hook(module, args)
+
+ def module_fwd_post_hook(module, args, outputs):
+ return self.fwd_post_hook(module, args, outputs)
+
+ self.fwd_pre_hook_handle = module.register_forward_pre_hook(module_fwd_pre_hook)
+ self.fwd_post_hook_handle = module.register_forward_hook(module_fwd_post_hook)
+
+
+def register_swap_hooks_to_modules(
+ module,
+ fwd_pre_hook_custom_func=None,
+ fwd_post_hook_custom_func=None,
+ bwd_pre_hook_custom_func=None,
+ bwd_post_hook_custom_func=None,
+):
+ r"""
+ usage:
+
+ # before training
+ models = [model_1, model_2, ...]
+ swap_hook_registers = []
+
+ def fwd_pre_hook_custom_func(swap_hook_register_id, fwd_idx):
+ ...
+
+ def fwd_post_hook_custom_func(swap_hook_register_id, fwd_idx):
+ ...
+
+ def bwd_pre_hook_custom_func(swap_hook_register_id, bwd_idx):
+ ...
+
+ def bwd_post_hook_custom_func(swap_hook_register_id, bwd_idx):
+ ...
+
+ for model in models:
+ import smart_swap
+ swap_hook_register = smart_swap.xxx.register_swap_hooks_to_modules(.
+ model,
+ fwd_pre_hook_custom_func, fwd_post_hook_custom_func
+ bwd_pre_hook_custom_func, bwd_post_hook_custom_func)
+
+ swap_hook_registers.append(swap_hook_register)
+
+ # when training
+ for step in range(train_steps):
+ for swap_hook_register in swap_hook_registers:
+ swap_hook_register.reset()
+
+ train_step(xxx)
+
+ # after training
+ for swap_hook_register in swap_hook_registers:
+ del swap_hook_register
+
+ """
+
+ swap_hook_register = SwapHookRegister()
+ swap_hook_register.register_hooks_to_modules_recursively(module)
+ swap_hook_register.register_custom_func(
+ fwd_pre_hook_custom_func, fwd_post_hook_custom_func, bwd_pre_hook_custom_func, bwd_post_hook_custom_func
+ )
+
+ return swap_hook_register
diff --git a/model/train/yoco_moe/mindspeed/core/memory/smart_swap/policy_generator.py b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/policy_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a5573e1c49682d8d7bef04cd59a083316d6f7a4
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/policy_generator.py
@@ -0,0 +1,320 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import os
+from typing import Dict, List
+
+import numpy as np
+
+from .swap_policy_config import swap_policy_config
+from .swap_utils import print_with_rank, PrintLevel, timer
+from .swap_cpp_adaptor import (
+ ProfilerDataOneStep,
+ SwapPolicyCandidate,
+ TensorInfoDetail,
+ UniqueSwapPtr,
+ MemoryReductionInfo,
+ MemoryPeakInfo,
+ SwapStage,
+ SwapStageType,
+ SwapTensorType,
+)
+from .swap_arranger import TensorArranger
+
+
+class PolicyGenerator:
+ def __init__(self, profiler_op_step: ProfilerDataOneStep):
+ self.size_coverage_weight = swap_policy_config.size_coverage_weight
+
+ self.profiler_op_step = profiler_op_step
+ self.tensor_info_dict: Dict[UniqueSwapPtr, TensorInfoDetail] = {}
+ self.policy_candidate_list: List[SwapPolicyCandidate] = []
+ self.intersect_candidates: List[SwapPolicyCandidate] = []
+ self.swap_list: List[SwapPolicyCandidate] = []
+ self.peak_list: List[MemoryReductionInfo] = []
+
+ self.candidate_selected: Dict[SwapPolicyCandidate, bool] = {}
+ self.memory_reduction_list = profiler_op_step.memory_reduction_list
+ # new data structure
+ self.mri_opid2idx = self.profiler_op_step.mri_opid2idx
+ self.memory_peaks = self.profiler_op_step.memory_peaks
+ self.swap_arranger = TensorArranger(
+ self.profiler_op_step,
+ os.path.join(swap_policy_config.output_root_path, f"Simulation_{swap_policy_config.rank}.html"),
+ swap_policy_config.duration_time,
+ )
+
+ def print_with_rank(self, message, print_level=PrintLevel.DEBUG):
+ print_with_rank(message, prefix="Policy", print_level=print_level)
+
+ def reduction_target_satisfied(self):
+ for memory_reduction in self.memory_reduction_list:
+ if not memory_reduction.cleared():
+ return False
+ self.print_with_rank("Successfully reach reduction target ...", print_level=PrintLevel.INFO)
+ return True
+
+ def get_covered_reductions(self, candidate_list=None):
+ if not self.memory_reduction_list:
+ return
+ flag = 0
+ if candidate_list is None:
+ flag = 1
+ candidate_list = self.policy_candidate_list
+ for memory_info in self.memory_reduction_list:
+ memory_info.intersect_candidate_list.clear()
+ for candidate in candidate_list:
+ candidate.num_covered_reductions = 0
+ swap_out_stage = self.profiler_op_step.layer_info.get_next_layer(candidate.swap_out_stage_actual)
+ swap_in_stage = candidate.swap_in_stage_actual
+ start_op_id = self.profiler_op_step.layer_info.layer_start_opid[swap_out_stage]
+ end_op_id = self.profiler_op_step.layer_info.layer_start_opid[swap_in_stage]
+ if start_op_id >= self.memory_reduction_list[-1].op_id or end_op_id <= self.memory_reduction_list[0].op_id:
+ candidate.start_mri_opid = -1
+ candidate.end_mri_opid = -1
+ candidate.num_covered_reductions = 0
+ else:
+ # 二分法查找
+ # find the mri with smallest opid that has opid >= start_op_id
+ start_mri_opid = self.get_closest_mri(start_op_id, cmp="ge")
+ # find the mri with largest opid that has opid < end_op_id
+ end_mri_opid = self.get_closest_mri(end_op_id, cmp="lt")
+ if end_mri_opid == end_op_id:
+ end_mri_opid = self.memory_reduction_list[self.mri_opid2idx[end_mri_opid] - 1].op_id
+ if start_mri_opid < start_op_id:
+ self.print_with_rank(
+ f"start_op_id={start_op_id}, end_op_id={end_op_id}, \
+ start_mri_opid={start_mri_opid}, end_mri_opid={end_mri_opid}",
+ print_level=PrintLevel.INFO,
+ )
+ if start_mri_opid < start_op_id:
+ raise ValueError("candidate.start_mri_opid should be >= than start_op_id")
+ if end_mri_opid > end_op_id:
+ self.print_with_rank(
+ f"start_op_id={start_op_id}, end_op_id={end_op_id}, \
+ start_mri_opid={start_mri_opid}, end_mri_opid={end_mri_opid}",
+ print_level=PrintLevel.INFO,
+ )
+ if end_mri_opid > end_op_id:
+ raise ValueError("candidate.end_mri_opid should be <= end_op_id")
+ # candidate增加属性:start_mri_opid, end_mri_opid, num_covered_reductions
+ if end_mri_opid < start_mri_opid:
+ candidate.start_mri_opid = -1
+ candidate.end_mri_opid = -1
+ candidate.num_covered_reductions = 0
+ else:
+ candidate.start_mri_opid = start_mri_opid
+ candidate.end_mri_opid = end_mri_opid
+ # 计算candidate能cover的mri的个数,通过mri_opid2idx的map算start_mri_opid和end_mri_opid之间的mri的个数
+ candidate.num_covered_reductions = (
+ self.mri_opid2idx[end_mri_opid] - self.mri_opid2idx[start_mri_opid] + 1
+ )
+ if flag:
+ if candidate.start_mri_opid != -1 and candidate.end_mri_opid != -1:
+ for mri_idx in range(self.mri_opid2idx[start_mri_opid], self.mri_opid2idx[end_mri_opid] + 1):
+ self.memory_reduction_list[mri_idx].intersect_candidate_list.append(candidate)
+
+ def get_closest_mri(self, target_opid, cmp="ge"):
+ """
+ Binary search for the opid closest to target_opid.
+ cmp:
+ 'ge': result opid greater than or equal to target_opid;
+ 'lt': result opid less than target_opid;
+ """
+ p1 = 0
+ p2 = len(self.memory_reduction_list) - 1
+ if cmp not in ["ge", "lt"]:
+ raise ValueError("For now only support cmp='ge' or cmp='lt' ")
+ while p1 < p2 - 1:
+ mid = (p1 + p2) // 2
+ mid_opid = self.memory_reduction_list[mid].op_id
+ if mid_opid == target_opid:
+ return mid_opid
+ elif mid_opid < target_opid:
+ p1 = mid
+ elif mid_opid > target_opid:
+ p2 = mid
+ if cmp == "ge":
+ if self.memory_reduction_list[p1].op_id >= target_opid:
+ return self.memory_reduction_list[p1].op_id
+ else:
+ return self.memory_reduction_list[p2].op_id
+ elif cmp == "lt":
+ if self.memory_reduction_list[p2].op_id < target_opid:
+ return self.memory_reduction_list[p2].op_id
+ else:
+ return self.memory_reduction_list[p1].op_id
+
+ def update_memory_reduction(self, candidate_list: List[SwapPolicyCandidate]):
+ self.get_covered_reductions(candidate_list)
+ for candidate in candidate_list:
+ if candidate.start_mri_opid != -1 and candidate.end_mri_opid != -1:
+ for mri_idx in range(
+ self.mri_opid2idx[candidate.start_mri_opid], self.mri_opid2idx[candidate.end_mri_opid] + 1
+ ):
+ mri = self.memory_reduction_list[mri_idx]
+ mri.update_memory_reduction_need(-candidate.tensor.info.size)
+
+ @timer
+ def select_candidate(self):
+ self.tensor_info_dict.clear()
+ for op in self.profiler_op_step.op_list:
+ for tensor in op.tensor_list:
+ tensor_info = self.tensor_info_dict.setdefault(tensor.ptr, TensorInfoDetail(tensor))
+ tensor_info.update_op(op)
+
+ for detail_tensor in self.tensor_info_dict.values():
+ detail_tensor.policy_candidate_list.clear()
+ if (
+ not detail_tensor.is_used_multiple_times()
+ or detail_tensor.info.tensor_type == SwapTensorType.SHARED_MEMORY
+ or detail_tensor.info.size < swap_policy_config.tensor_size_filter
+ ):
+ continue
+ if detail_tensor.info.tensor_type == SwapTensorType.OPTIM:
+ self.select_optim_tensor(detail_tensor)
+ elif detail_tensor.info.tensor_type in (SwapTensorType.MODEL, SwapTensorType.OTHERS):
+ self.select_model_tensor(detail_tensor)
+
+ self.policy_candidate_list = list(
+ set().union(*[i.policy_candidate_list for i in self.tensor_info_dict.values()])
+ )
+ self.candidate_selected = dict([(candidate, False) for candidate in self.policy_candidate_list])
+ self.get_covered_reductions()
+
+ def select_optim_tensor(self, detail_tensor: TensorInfoDetail):
+ first_op = detail_tensor.used_op_list[0]
+ if first_op.stage.stage_type != SwapStageType.OPTIM:
+ return
+ swap_out_stage = SwapStage(stage_type=SwapStageType.FWD, micro_batch_index=1, layer_index=1)
+ swap_in_stage = SwapStage(stage_type=SwapStageType.OPTIM, micro_batch_index=0, layer_index=0)
+ swap_policy_candidate = SwapPolicyCandidate(
+ detail_tensor, is_optimizer_or_weight=True, swap_out_stage=swap_out_stage, swap_in_stage=swap_in_stage
+ )
+ detail_tensor.policy_candidate_list.append(swap_policy_candidate)
+ return
+
+ # 找到FWD最后一次使用和BWD第一次使用
+ def select_model_tensor(self, detail_tensor: TensorInfoDetail):
+ if any(op.stage.stage_type == SwapStageType.OPTIM for op in detail_tensor.used_op_list):
+ return
+ fwd_last_op = None
+ bwd_first_op = None
+ for op in detail_tensor.used_op_list:
+ if op.stage.stage_type == SwapStageType.FWD and (fwd_last_op is None or fwd_last_op.op_id < op.op_id):
+ fwd_last_op = op
+ if op.stage.stage_type == SwapStageType.BWD and (bwd_first_op is None or bwd_first_op.op_id > op.op_id):
+ bwd_first_op = op
+ if fwd_last_op and bwd_first_op:
+ swap_policy_candidate = SwapPolicyCandidate(
+ detail_tensor, is_optimizer_or_weight=False, swap_out_op=fwd_last_op, swap_in_op=bwd_first_op
+ )
+ detail_tensor.policy_candidate_list.append(swap_policy_candidate)
+ return
+
+ def compute_score(self):
+ if not self.policy_candidate_list:
+ return
+ tensor_info_sizes = [i.tensor.info.size for i in self.policy_candidate_list]
+ max_size = max(tensor_info_sizes)
+ min_size = min(tensor_info_sizes)
+ max_size = max_size ** (1 / 3)
+ min_size = min_size ** (1 / 3)
+ size_range = max(0.001, max_size - min_size)
+
+ coverages = [i.num_covered_reductions for i in self.policy_candidate_list]
+ max_coverage = max(coverages)
+ min_coverage = min(coverages)
+ coverage_range = max(0.001, max_coverage - min_coverage)
+
+ for candidate in self.policy_candidate_list:
+ normalized_coverage = (candidate.num_covered_reductions - min_coverage) / coverage_range
+ normalized_size = (candidate.tensor.info.size ** (1 / 3) - min_size) / size_range
+ candidate.score = normalized_coverage + self.size_coverage_weight * normalized_size
+
+ def get_peak_list(self):
+ # Select the maximum mri value from the top mri of each MemoryPeakInfo (self.memory_peaks)
+ # so each iteration only one peak is selected.
+ self.peak_list.clear()
+
+ def get_max_for_each_mp(mp: MemoryPeakInfo):
+ """
+ 找到每个MemoryPeak区间内对应的MemoryReductionInfo当前的最大memory_reduction_need
+ """
+ if mp.mp_mri_start_opid == -1 or mp.mp_mri_end_opid == -1:
+ return None
+ start_idx = self.mri_opid2idx[mp.mp_mri_start_opid]
+ end_idx = self.mri_opid2idx[mp.mp_mri_end_opid] + 1
+ mri_list = self.memory_reduction_list[start_idx:end_idx]
+ mrn = [mri.memory_reduction_need for mri in mri_list]
+ max_idx = np.argmax(mrn)
+ self.print_with_rank(
+ f"current top mri in MemoryPeakInfo is {mri_list[max_idx]}", print_level=PrintLevel.INFO
+ )
+ return mri_list[max_idx]
+
+ mp_max = [(i, get_max_for_each_mp(mp)) for i, mp in enumerate(self.memory_peaks)]
+ for mp in mp_max:
+ self.print_with_rank(f"top mri from each MemoryPeakInfo {mp[1]}", print_level=PrintLevel.INFO)
+ mp_max_list = np.array([0 if not item[1] else item[1].memory_reduction_need for item in mp_max])
+ self.print_with_rank(f"top mri from each MemoryPeakInfo {[mp_max_list]}", print_level=PrintLevel.INFO)
+ selected_peak_idx = np.argmax(mp_max_list)
+ self.peak_list = [mp_max[selected_peak_idx][1]]
+
+ def get_intersect_candidates(self):
+ self.get_peak_list()
+ self.intersect_candidates.clear()
+ self.print_with_rank(f"len of peak list is {len(self.peak_list)}", print_level=PrintLevel.INFO)
+ peak = self.peak_list[0]
+ if not peak:
+ return
+ self.intersect_candidates = [
+ cand for cand in peak.intersect_candidate_list if not self.candidate_selected[cand]
+ ]
+ self.intersect_candidates.sort(key=lambda x: (-x.score, x.start_mri_opid))
+ self.print_with_rank(
+ f"len of self.intersect_candidates after {len(self.intersect_candidates)}", print_level=PrintLevel.INFO
+ )
+
+ def simulation_select(self):
+ reduction_need = self.peak_list[0].memory_reduction_need
+ selected_candidates = []
+ for cand in self.intersect_candidates:
+ if not self.swap_arranger.cause_delay(cand):
+ selected_candidates.append(cand)
+ reduction_need -= cand.tensor.info.size
+ if reduction_need <= 0:
+ return selected_candidates, False
+ if not selected_candidates:
+ return [self.intersect_candidates[0]], True
+ return selected_candidates, False
+
+ def simulation(self, use_custom_policy=False):
+ if use_custom_policy:
+ selected_candidates = self.policy_candidate_list
+ cause_delay = False
+ else:
+ selected_candidates, cause_delay = self.simulation_select()
+ self.print_with_rank(f"selected_candidates have {len(selected_candidates)} cands", print_level=PrintLevel.DEBUG)
+ self.swap_list.extend(selected_candidates)
+ self.swap_arranger.run(selected_candidates, self.swap_list, delay=cause_delay)
+ self.update_memory_reduction(selected_candidates)
+ for cand in selected_candidates:
+ self.candidate_selected[cand] = True
+
+ def get_sorted_swap_list(self):
+ """
+ Sort swap_list by: primary key: swap_out time; secondary key: tensor size reverse
+ """
+ swap_list_out_opid = [
+ (
+ candidate,
+ (
+ self.profiler_op_step.layer_info.layer_start_opid[candidate.swap_out_stage]
+ if candidate.is_optimizer_or_weight
+ else candidate.swap_out_op.op_id
+ ),
+ )
+ for candidate in self.swap_list
+ ]
+ swap_list_out_opid = sorted(swap_list_out_opid, key=lambda item: (item[1], -item[0].tensor.info.size))
+ swap_list = [candidate for (candidate, out_opid) in swap_list_out_opid]
+ return swap_list
diff --git a/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_adaptor.py b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_adaptor.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec50c4dea0f49d4da5a056b461d509d3858e2f3
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_adaptor.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import ctypes
+import torch_npu
+
+from mindspeed.op_builder import SmartSwapBuilder
+
+
+class SmartSwapAdaptor(object):
+ SMART_SWAP_MODULE = None
+
+ def __init__(self):
+ pass
+
+
+def load_smart_swap_module():
+ if SmartSwapAdaptor.SMART_SWAP_MODULE is None:
+ SmartSwapAdaptor.SMART_SWAP_MODULE = SmartSwapBuilder().load()
+ return SmartSwapAdaptor.SMART_SWAP_MODULE
+
+
+def change_allocator():
+ smart_swap_cpp = load_smart_swap_module()
+ smart_swap_module_path = smart_swap_cpp.__file__
+
+ new_alloc = torch_npu.npu.memory.NPUPluggableAllocator(smart_swap_module_path, "gmlake_malloc", "gmlake_free")
+ torch_npu.npu.memory.change_current_allocator(new_alloc)
+
+ myallocator = ctypes.CDLL(smart_swap_module_path)
+ init_fn = ctypes.cast(getattr(myallocator, "gmlake_init"), ctypes.c_void_p).value
+ empty_fn = ctypes.cast(getattr(myallocator, "gmlake_empty_cache"), ctypes.c_void_p).value
+ memory_fraction_fn = ctypes.cast(getattr(myallocator, "gmlake_memory_fraction"), ctypes.c_void_p).value
+ get_device_stats_fn = ctypes.cast(getattr(myallocator, "gmlake_get_device_stats"), ctypes.c_void_p).value
+ reset_peak_stats_fn = ctypes.cast(getattr(myallocator, "gmlake_reset_peak_stats"), ctypes.c_void_p).value
+ record_stream_fn = ctypes.cast(getattr(myallocator, "gmlake_record_stream"), ctypes.c_void_p).value
+ erase_stream_fn = ctypes.cast(getattr(myallocator, "gmlake_erase_stream"), ctypes.c_void_p).value
+
+ new_alloc.allocator().set_init_fn(init_fn)
+ new_alloc.allocator().set_reset_fn(empty_fn)
+ new_alloc.allocator().set_memory_fraction_fn(memory_fraction_fn)
+ new_alloc.allocator().set_get_device_stats_fn(get_device_stats_fn)
+ new_alloc.allocator().set_reset_peak_status_fn(reset_peak_stats_fn)
+ new_alloc.allocator().set_record_stream_fn(record_stream_fn)
+ new_alloc.allocator().set_erase_stream_fn(erase_stream_fn)
diff --git a/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_arranger.py b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_arranger.py
new file mode 100644
index 0000000000000000000000000000000000000000..64f6753f96cafe79471def1ce1db2fb325b959c6
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_arranger.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+from typing import List
+
+import numpy as np
+
+from .swap_policy_config import swap_policy_config
+from .swap_utils import print_with_rank, PrintLevel
+from .swap_cpp_adaptor import ProfilerDataOneStep, SwapPolicyCandidate, SwapStage
+
+
+class TensorArrangerBase:
+ def __init__(self, profiler_op_step: ProfilerDataOneStep, output_file_path, duration_time):
+ self.op_list = profiler_op_step.op_list
+ self.profiler_op_step = profiler_op_step
+ self.duration_time = duration_time
+ self.stage_data = []
+ self.stage_map = {}
+ self.stage_index_map = {}
+ self.stage_start_time = dict()
+ self.stage_end_time = dict()
+ self.set_data()
+ self.candidate_index = 0
+ self.output_file_path = output_file_path
+
+ self.D2H_bandwidth = swap_policy_config.D2H_bandwidth
+ self.H2D_bandwidth = swap_policy_config.H2D_bandwidth
+ self.color_map = {
+ "SwapStageType.INIT": "yellow",
+ "SwapStageType.FWD": "red",
+ "SwapStageType.BWD": "blue",
+ "SwapStageType.OPTIM": "purple",
+ "Delay": "green",
+ }
+
+ def print_with_rank(self, message, print_level=PrintLevel.DEBUG):
+ print_with_rank(message, prefix="SwapEngine", print_level=print_level)
+
+ def set_data(self):
+ time_line = list(np.linspace(0, self.duration_time, len(self.op_list) + 1))[1:]
+ for index, op in enumerate(self.op_list):
+ if not self.stage_data or op.stage != self.stage_data[-1]["stage"]:
+ self.stage_data.append(
+ {
+ "op_id": op.op_id,
+ "stage": op.stage,
+ "stage_type": str(op.stage.stage_type),
+ "start_time": time_line[index],
+ "type": "op_stream",
+ "candidate_index": -1,
+ }
+ )
+ self.stage_data[-1]["end_time"] = time_line[index]
+ for index, row in enumerate(self.stage_data):
+ if row["stage"] in self.stage_map:
+ raise ValueError("Find duplicate stage ...")
+ self.stage_index_map[index] = row["stage"]
+ self.stage_map[row["stage"]] = {
+ "index": index,
+ "start_time": row["start_time"],
+ "end_time": row["end_time"],
+ "layer_time": row["end_time"] - row["start_time"],
+ "time_left": row["end_time"] - row["start_time"],
+ "candidate_list": [],
+ }
+ self.stage_end_time[index] = row["end_time"]
+ self.stage_start_time[index] = row["start_time"]
+
+ def get_swap_time(self, size):
+ swap_out_time = size / 1024 / 1024 / self.D2H_bandwidth
+ swap_in_time = size / 1024 / 1024 / self.H2D_bandwidth
+ return swap_out_time, swap_in_time
+
+ def reset_simulation(self):
+ self.candidate_index = 0
+ for stage in self.stage_map:
+ self.stage_map[stage]["candidate_list"] = []
+ self.stage_map[stage]["time_left"] = self.stage_map[stage]["layer_time"]
+ self.stage_time_left = dict()
+
+ def set_swapin_free_stage_to_candidate(self, cur_time, candidate):
+ swap_in_free_stage_index = self.stage_map[candidate.swap_in_stage_actual]["index"]
+ # 首先swap_in后实际释放的时机设置为swap_in_stage_actual的后一个
+ # 由于在排布时为了减少实际执行中计算流等待swap流swap in的情况,
+ # 所有candidate的swap_in_stage_actual设置都至少比理论上计算流需要的stage提前了一个stage
+ # 因此这里将所有candidate实际swap in释放的stage设置为swap_in_stage_actual的后一个stage,一定不会超出所有stage的边界
+ candidate.swap_in_free_stage = self.stage_index_map[swap_in_free_stage_index + 1]
+ for index, stage in self.stage_index_map.items():
+ # 如果当前candidate在排布中实际swap in结束时间所在stage,
+ # 加上延迟free的stage数后没有超过总stage数边界,
+ # 则将实际swap in 释放stage设置为排布获得的swap in结束时间所在stage再往后延swap_in_free_stage_delay个stage
+ if (
+ index < len(self.stage_index_map) - swap_policy_config.swap_in_free_stage_delay
+ and cur_time < self.stage_end_time[index]
+ ):
+ candidate.swap_in_free_stage = self.stage_index_map[index + swap_policy_config.swap_in_free_stage_delay]
+ return
+
+ def set_free_stage_to_candidate(self, cur_time, candidate):
+ candidate.free_stage = candidate.swap_in_stage_actual
+ for index, _ in self.stage_index_map.items():
+ if (
+ index < len(self.stage_index_map) - swap_policy_config.free_stage_delay
+ and cur_time < self.stage_end_time[index]
+ ):
+ candidate.free_stage = self.stage_index_map[index + swap_policy_config.free_stage_delay]
+ return
+
+ def set_free_stage(self):
+ for index, stage in self.stage_index_map.items():
+ value = self.stage_map[stage]
+ time_left = self.stage_time_left[value["index"]]
+ start_time = self.stage_start_time[index] - time_left + value["time_left"]
+ cur_time = start_time
+
+ # Initialize an empty list to store swap information for each candidate
+ swap_list_out_opid = []
+
+ # Iterate through each item in the candidate list
+ for swap_stage, swap_time, stream_type, candidate_index, candidate in value["candidate_list"]:
+ # Determine operation ID based on candidate type
+ if candidate.is_optimizer_or_weight:
+ op_id = self.profiler_op_step.layer_start_opid[candidate.swap_out_stage]
+ else:
+ op_id = candidate.swap_out_op.op_id
+
+ # Append a tuple with the relevant information to the list
+ swap_list_out_opid.append((swap_stage, swap_time, stream_type, candidate_index, candidate, op_id))
+
+ swap_list_out_opid = sorted(swap_list_out_opid, key=lambda item: (item[-1], -item[-2].tensor.info.size))
+ value["candidate_list"] = [
+ (swap_stage, swap_time, stream_type, candidate_index, candidate)
+ for swap_stage, swap_time, stream_type, candidate_index, candidate, _ in swap_list_out_opid
+ ]
+
+ for swap_stage, swap_time, stream_type, candidate_index, candidate in value["candidate_list"]:
+ cur_time += swap_time
+ if stream_type == "swap_out_stream":
+ self.set_free_stage_to_candidate(cur_time, candidate)
+ elif stream_type == "swap_in_stream":
+ self.set_swapin_free_stage_to_candidate(cur_time, candidate)
+
+
+class TensorArranger(TensorArrangerBase):
+ def __init__(self, profiler_op_step: ProfilerDataOneStep, output_file_path, duration_time):
+ super(TensorArranger, self).__init__(profiler_op_step, output_file_path, duration_time)
+ self.profiler_op_step = profiler_op_step
+ self.stage_time_left = dict()
+
+ def calculate_time_left(self, find_index):
+ time_left = 0
+ for index in range(find_index + 1):
+ time_left = min(0, time_left)
+ time_left += self.stage_map[self.stage_index_map[index]]["time_left"]
+ return time_left
+
+ def save_stage_time_left(self):
+ time_left = 0
+ for index, stage in self.stage_index_map.items():
+ time_left = min(0, time_left)
+ time_left += self.stage_map[stage]["time_left"]
+ self.stage_time_left[index] = time_left
+
+ def get_layer_time_excess(self, layer: SwapStage, swap_time):
+ return self.stage_map[layer]["time_left"] - swap_time
+
+ def cause_delay(self, candidate: SwapPolicyCandidate):
+ swap_out_time, swap_in_time = self.get_swap_time(candidate.tensor.info.size)
+ swap_out_affected = self.get_layer_time_excess(candidate.swap_out_stage, swap_out_time)
+ swap_in_stage_index = self.stage_map[candidate.swap_in_stage]["index"]
+ swap_in_stage_index = swap_in_stage_index - 1
+ swap_in_stage = self.stage_index_map[swap_in_stage_index]
+ swap_in_affected = self.get_layer_time_excess(swap_in_stage, swap_in_time)
+ return swap_out_affected < 0 or swap_in_affected < 0
+
+ def run(self, candidates: List[SwapPolicyCandidate], _: List[SwapPolicyCandidate], delay=False):
+ """
+ delay: if False, then items in candidates would not cause delay in current simulation
+ """
+ for cand in candidates:
+ swap_out_stage = cand.swap_out_stage
+ swap_in_stage = cand.swap_in_stage
+ swap_out_stage_index = self.stage_map[swap_out_stage]["index"]
+ swap_in_stage_index = self.stage_map[swap_in_stage]["index"]
+ swap_out_time, swap_in_time = self.get_swap_time(cand.tensor.info.size)
+ swap_in_stage_index = swap_in_stage_index - 1
+ swap_in_stage = self.stage_index_map[swap_in_stage_index]
+ self.stage_map[swap_out_stage]["candidate_list"].append(
+ (swap_out_stage, swap_out_time, "swap_out_stream", self.candidate_index, cand)
+ )
+ self.stage_map[swap_out_stage]["time_left"] -= swap_out_time
+ if delay:
+ find_flag = False
+ for find_swap_in_index in range(swap_in_stage_index, swap_out_stage_index, -1):
+ time_left = self.calculate_time_left(find_swap_in_index)
+ if time_left > swap_in_time:
+ swap_in_stage = self.stage_index_map[find_swap_in_index]
+ find_flag = True
+ break
+ if not find_flag:
+ swap_in_stage = self.stage_index_map[swap_in_stage_index]
+ cand.swap_in_stage_actual = swap_in_stage
+ self.stage_map[swap_in_stage]["candidate_list"].append(
+ (swap_in_stage, swap_in_time, "swap_in_stream", self.candidate_index, cand)
+ )
+ self.stage_map[swap_in_stage]["time_left"] -= swap_in_time
diff --git a/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_cpp_adaptor.py b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_cpp_adaptor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e7a8c087cc658944435ea66caf07e314721c75b
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_cpp_adaptor.py
@@ -0,0 +1,1206 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+from collections import Counter
+from enum import Enum
+from typing import List, Dict
+from dataclasses import dataclass
+from bisect import bisect_right, bisect_left
+
+import torch
+
+from .swap_policy_config import swap_policy_config
+from .swap_utils import print_with_rank, PrintLevel
+from .swap_adaptor import load_smart_swap_module
+
+
+def get_smart_swap_cpp():
+ return load_smart_swap_module()
+
+
+class SwapTensorType(Enum):
+ MODEL = get_smart_swap_cpp().SwapTensorType.MODEL
+ OPTIM = get_smart_swap_cpp().SwapTensorType.OPTIM
+ SHARED_MEMORY = get_smart_swap_cpp().SwapTensorType.SHARED_MEMORY
+ OTHERS = get_smart_swap_cpp().SwapTensorType.OTHERS
+ RESERVED = get_smart_swap_cpp().SwapTensorType.RESERVED
+
+
+class SwapStageType(Enum):
+ INIT = get_smart_swap_cpp().SwapStageType.INIT
+ FWD = get_smart_swap_cpp().SwapStageType.FWD
+ BWD = get_smart_swap_cpp().SwapStageType.BWD
+ OPTIM = get_smart_swap_cpp().SwapStageType.OPTIM
+ RESERVED = get_smart_swap_cpp().SwapStageType.RESERVED
+
+
+def record_tensor_ptr_with_types(
+ tensors: List[torch.Tensor], tensor_type: SwapTensorType, update_weak_ptr_map=0, is_update_blacklist=False
+):
+ # 调用下面的函数时,当前在c++侧会自动clear其维护的map
+ return get_smart_swap_cpp().recordTensorPtrWithTypes(tensors, tensor_type.value, update_weak_ptr_map, is_update_blacklist)
+
+
+class SwapStage:
+ def __init__(self, cpp_instance=None, stage_type=None, micro_batch_index=None, layer_index=None):
+ self.stage_type: SwapStageType = None
+ self.micro_batch_index = None
+ self.layer_index = None
+
+ if cpp_instance:
+ self.from_cpp(cpp_instance)
+ if stage_type is not None:
+ self.stage_type = stage_type
+ if micro_batch_index is not None:
+ self.micro_batch_index = micro_batch_index
+ if layer_index is not None:
+ self.layer_index = layer_index
+
+ def __eq__(self, other):
+ if not isinstance(other, SwapStage):
+ return NotImplemented
+ return (
+ self.stage_type == other.stage_type
+ and self.micro_batch_index == other.micro_batch_index
+ and self.layer_index == other.layer_index
+ )
+
+ def __ne__(self, other):
+ if not isinstance(other, SwapStage):
+ return NotImplemented
+ return not self.__eq__(other)
+
+ def __hash__(self):
+ return hash((self.stage_type, self.micro_batch_index, self.layer_index))
+
+ def copy(self):
+ # return a python SwapStage copy
+ instance = SwapStage()
+ instance.stage_type = self.stage_type
+ instance.micro_batch_index = self.micro_batch_index
+ instance.layer_index = self.layer_index
+ return instance
+
+ def from_cpp(self, instance):
+ self.stage_type = SwapStageType(instance.stageType)
+ self.micro_batch_index = instance.microBatchIndex
+ self.layer_index = instance.layerIndex
+
+ def to_cpp(self, instance):
+ instance.stageType = self.stage_type.value
+ instance.microBatchIndex = self.micro_batch_index
+ instance.layerIndex = self.layer_index
+
+ def __str__(self):
+ ret = dict(stage_type=self.stage_type.name, mbi=self.micro_batch_index, li=self.layer_index)
+ return str(ret)
+
+ def calculate_layer_index(self, stage_op_idx, fwd_layer_info, bwd_layer_info):
+ # stage_op_idx: op_idx starting from the current stage
+ # op_layer_info: fwd_op_layer_info, or bwd_op_layer_info
+ self.layer_index = 0
+ if self.stage_type == SwapStageType.FWD:
+ op_layer_info = fwd_layer_info
+ elif self.stage_type == SwapStageType.BWD:
+ op_layer_info = bwd_layer_info
+ elif self.stage_type == SwapStageType.OPTIM or self.stage_type == SwapStageType.INIT:
+ self.layer_index = 0
+ return self.layer_index
+ else:
+ raise RuntimeError(f"calculate_layer_index error, stage_type={self.stage_type}")
+
+ for i, op_layer_info_value in enumerate(op_layer_info):
+ if stage_op_idx <= op_layer_info_value:
+ self.layer_index = i + 1 # layerIndex 从1开始
+ break
+ if self.layer_index == 0:
+ self.layer_index = len(op_layer_info) + 1
+ return self.layer_index
+
+
+class SwapConfig:
+ def __init__(self):
+ self.cpp_config = get_smart_swap_cpp().NPUSwapManager.GetInstance().config
+
+ def dict(self):
+ return dict(
+ micro_batch_num=self.micro_batch_num,
+ layer_num=self.layer_num,
+ is_oom=self.is_oom,
+ stage=str(self.stage),
+ step=self.step,
+ one_step_duration=self.one_step_duration,
+ policy_step=self.policy_step,
+ current_stage_op_id=self.current_stage_op_id,
+ enable_profiler=self.enable_profiler,
+ enable_executor=self.enable_executor,
+ fwd_op_layer_info=self.fwd_op_layer_info,
+ bwd_op_layer_info=self.bwd_op_layer_info,
+ enable_custom_record_stream=self.enable_custom_record_stream,
+ )
+
+ @property
+ def micro_batch_num(self):
+ return self.cpp_config.microBatchNum
+
+ @micro_batch_num.setter
+ def micro_batch_num(self, value):
+ self.cpp_config.microBatchNum = value
+
+ @property
+ def layer_num(self):
+ return self.cpp_config.layerNum
+
+ @layer_num.setter
+ def layer_num(self, value):
+ self.cpp_config.layerNum = value
+
+ @property
+ def is_oom(self):
+ return self.cpp_config.isOOM
+
+ @is_oom.setter
+ def is_oom(self, value):
+ self.cpp_config.isOOM = value
+
+ @property
+ def stage(self) -> SwapStage:
+ stage = SwapStage()
+ stage.from_cpp(self.cpp_config.stage)
+ return stage
+
+ @stage.setter
+ def stage(self, value: SwapStage):
+ value.to_cpp(self.cpp_config.stage)
+
+ @property
+ def step(self):
+ return self.cpp_config.step
+
+ @property
+ def next_step(self):
+ return self.step + 1
+
+ @step.setter
+ def step(self, value):
+ self.cpp_config.step = value
+
+ @property
+ def one_step_duration(self):
+ return self.cpp_config.oneStepDuration
+
+ @one_step_duration.setter
+ def one_step_duration(self, value):
+ self.cpp_config.oneStepDuration = value
+
+ @property
+ def policy_step(self):
+ return self.cpp_config.policyStep
+
+ @policy_step.setter
+ def policy_step(self, value):
+ self.cpp_config.policyStep = value
+
+ @property
+ def current_stage_op_id(self):
+ return self.cpp_config.currentStageOpId
+
+ @current_stage_op_id.setter
+ def current_stage_op_id(self, value):
+ self.cpp_config.currentStageOpId = value
+
+ @property
+ def enable_profiler(self):
+ return self.cpp_config.enableProfiler
+
+ @enable_profiler.setter
+ def enable_profiler(self, value):
+ self.cpp_config.enableProfiler = value
+
+ @property
+ def enable_executor(self):
+ return self.cpp_config.enableExecutor
+
+ @enable_executor.setter
+ def enable_executor(self, value):
+ self.cpp_config.enableExecutor = value
+
+ @property
+ def enable_custom_record_stream(self):
+ return self.cpp_config.enableCustomRecordStream
+
+ @enable_custom_record_stream.setter
+ def enable_custom_record_stream(self, value):
+ self.cpp_config.enableCustomRecordStream = value
+
+ @property
+ def tensor_size_thresh(self):
+ return self.cpp_config.tensorSizeThresh
+
+ @tensor_size_thresh.setter
+ def tensor_size_thresh(self, value):
+ self.cpp_config.tensorSizeThresh = value
+
+ @property
+ def fwd_op_layer_info(self):
+ return self.cpp_config.fwdOpLayerInfo
+
+ @fwd_op_layer_info.setter
+ def fwd_op_layer_info(self, value):
+ self.cpp_config.fwdOpLayerInfo = value
+
+ @property
+ def bwd_op_layer_info(self):
+ return self.cpp_config.bwdOpLayerInfo
+
+ @bwd_op_layer_info.setter
+ def bwd_op_layer_info(self, value):
+ self.cpp_config.bwdOpLayerInfo = value
+
+
+class UniqueSwapPtr:
+ def __init__(self, cpp_instance=None, ptr_base=None, index=None):
+ self.ptr_base = None
+ self.index = None
+
+ if cpp_instance:
+ self.from_cpp(cpp_instance)
+ if ptr_base:
+ self.ptr_base = ptr_base
+ if index:
+ self.index = index
+
+ def from_cpp(self, instance):
+ self.ptr_base = instance.ptrBase
+ self.index = instance.index
+
+ def to_cpp(self, instance):
+ instance.ptrBase = self.ptr_base
+ instance.index = self.index
+
+ def __str__(self):
+ return f"{self.ptr_base}_{self.index}"
+
+ def __eq__(self, other):
+ if not isinstance(other, UniqueSwapPtr):
+ return NotImplemented
+ return self.ptr_base == other.ptr_base and self.index == other.index
+
+ def __ne__(self, other):
+ if not isinstance(other, UniqueSwapPtr):
+ return NotImplemented
+ return not self.__eq__(other)
+
+ def __hash__(self):
+ return hash((self.ptr_base, self.index))
+
+
+class ProfilerTensorInfo:
+ def __init__(self, tensor_dict):
+ self.origi_ptr = UniqueSwapPtr(cpp_instance=tensor_dict["ptr"])
+ self.ptr = UniqueSwapPtr(cpp_instance=tensor_dict["ptr"])
+ self.size = tensor_dict["size"]
+ self.shape = tensor_dict["shape"]
+ self.dtype = tensor_dict["dtype"]
+ self.tensor_type = SwapTensorType(tensor_dict["tensorType"])
+
+ def get_dict(self):
+ ret = dict(
+ ptr=str(self.ptr), size=self.size, shape=self.shape, dtype=self.dtype, tensor_type=self.tensor_type.name
+ )
+ return ret
+
+ def __str__(self):
+ return str(self.get_dict())
+
+
+class ProfilerOpInfo:
+ def __init__(self, op_dict):
+ self.op_name = op_dict["opName"]
+ self.op_id = op_dict["opId"]
+ self.stage = SwapStage(cpp_instance=op_dict["stage"])
+ self.step = op_dict["step"]
+ self.allocated_bytes = op_dict["allocated_bytes"]
+ self.reserved_bytes = op_dict["reserved_bytes"]
+ self.active_bytes = op_dict["active_bytes"]
+ self.tensor_list = []
+
+ tensor_list = op_dict["tensor"]
+ for tensor in tensor_list:
+ self.tensor_list.append(ProfilerTensorInfo(tensor))
+
+ def print_dict(self):
+ return dict(
+ name=self.op_name,
+ op_id=self.op_id,
+ stage=str(self.stage),
+ tensor_list=[str(tensor) for tensor in self.tensor_list],
+ )
+
+ def print_dict_brief(self):
+ return dict(name=self.op_name, op_id=self.op_id, stage=str(self.stage))
+
+ def __str__(self) -> str:
+ return str(
+ dict(
+ name=self.op_name,
+ op_id=self.op_id,
+ stage=str(self.stage),
+ tensor_list=[str(tensor) for tensor in self.tensor_list],
+ )
+ )
+
+ def get_brief_dict(self):
+ return str(dict(name=self.op_name, op_id=self.op_id, stage=str(self.stage)))
+
+ def __eq__(self, other):
+ if not isinstance(other, ProfilerOpInfo):
+ return NotImplemented
+ return self.op_name == other.op_name and self.op_id == other.op_id and self.stage == other.stage
+
+ def __ne__(self, other):
+ if not isinstance(other, ProfilerOpInfo):
+ return NotImplemented
+ return not self.__eq__(other)
+
+ def __hash__(self):
+ return hash((self.op_name, self.op_id, self.stage))
+
+ def __lt__(self, other):
+ if not isinstance(other, ProfilerOpInfo):
+ return NotImplemented
+ return self.op_id < other.op_id
+
+ def __gt__(self, other):
+ if not isinstance(other, ProfilerOpInfo):
+ return NotImplemented
+ return self.op_id > other.op_id
+
+
+class ProfilerSwapInfo:
+ def __init__(self, swap_dict):
+ self.op_id = swap_dict["opId"]
+ self.swap_name = swap_dict["swapName"]
+ self.size = swap_dict["size"]
+ self.is_oom = swap_dict["isOOM"]
+ self.src_ptr = UniqueSwapPtr(swap_dict["srcPtr"])
+ self.dst_ptr = UniqueSwapPtr(swap_dict["dstPtr"])
+
+ def print_dict(self):
+ return dict(
+ op_id=self.op_id,
+ swap_name=self.swap_name,
+ size=str(self.size),
+ is_oom=self.is_oom,
+ src_ptr=str(self.src_ptr),
+ dst_ptr=str(self.dst_ptr),
+ )
+
+ def __str__(self) -> str:
+ return str(
+ dict(
+ op_id=self.op_id,
+ swap_name=self.swap_name,
+ size=str(self.size),
+ is_oom=self.is_oom,
+ src_ptr=str(self.src_ptr),
+ dst_ptr=str(self.dst_ptr),
+ )
+ )
+
+
+class MemoryReductionInfo:
+ # 适用于1.去除OOM 2.通过策略下降xxxG峰值内存 两种情况
+ def __init__(self, op, memory_reduction_total):
+ self.op = op
+ self.op_id = op.op_id
+ self.memory_reduction_need = memory_reduction_total
+ self.memory_reduction_total = memory_reduction_total
+ self.intersect_candidate_list: List[SwapPolicyCandidate] = []
+
+ def __str__(self):
+ return (
+ f"Reduction_need:{self.memory_reduction_need}, "
+ f"Reduction_total:{self.memory_reduction_total}, "
+ f"OP: {self.op.get_brief_dict()}"
+ )
+
+ def update_memory_reduction_need(self, amount):
+ self.memory_reduction_need += amount
+
+ def cleared(self):
+ return self.memory_reduction_need <= 0
+
+ def check_in_list(self, memory_reduction_list):
+ # precondition: memory_reduction_list is sorted according to op_id
+ if not memory_reduction_list or len(memory_reduction_list) == 0:
+ return False
+ if memory_reduction_list[0].op_id > self.op_id:
+ return False
+ if memory_reduction_list[-1].op_id < self.op_id:
+ return False
+ return True
+
+ def print_dict(self):
+ ret = dict(
+ op_id=str(self.op_id),
+ op_name=str(self.op.op_name),
+ memory_reduction_need=str(self.memory_reduction_need),
+ memory_reduction_total=str(self.memory_reduction_total),
+ )
+ return ret
+
+ def print_dict_op(self):
+ return self.op.print_dict()
+
+
+@dataclass
+class MemoryPeakInfo:
+ """
+ 模型运行中根据内存曲线进行抽象得到的数据结构。
+ 以bwd-fwd的交替为标志,每个MemoryPeak为相邻两次bwd-fwd交替之间的op序列,
+ 代表内存曲线从一个local minima升至local maxima再降至local minima的区间
+ 例如:在非PP的1F1B场景下,每个microbatch(一次前向后一次反向)为一个MemoryPeak;
+ 在PP场景中,如果stage序列为fwd1->fwd2->bwd1->fwd3->bwd2->fwd4->bwd3->bwd4,
+ 则第一个MemoryPeak为fwd1至bwd1, 第二个MemoryPeak为fwd3->bwd2, 等等
+
+ MemoryPeakInfo记录每个MemoryPeak的信息
+ start_opid: 当前MemoryPeak开始的opid (这个MemoryPeak区间中第一个前向阶段的第一个op的opid)
+ end_opid: 当前MemoryPeak结束的opid (这个MemoryPeak中最后一个反向阶段的最后一个op的opid)
+ mp_mri_start_opid: 在这个MemoryPeak区间内第一处需要降内存(MemoryReductionInfo)的opid
+ mp_mri_end_opid: 在这个MemoryPeak区间内最后一处需要降内存(MemoryReductionInfo)的opid
+ """
+
+ start_opid: int
+ end_opid: int
+ mp_mri_start_opid: int = -1
+ mp_mri_end_opid: int = -1
+
+ def print_with_rank(self, message, print_level=PrintLevel.DEBUG):
+ print_with_rank(message, prefix="MemoryPeakInfo", print_level=print_level)
+
+
+class ProfilerLayerInfo:
+ def __init__(self, op_list: List[ProfilerOpInfo]):
+ self.op_list = op_list
+ self.logical_layer_num = swap_policy_config.logical_layer_num
+
+ self.stage_data = []
+ self.fwd_op_layer_info = []
+ self.bwd_op_layer_info = []
+ self.layer_start_opid: Dict[SwapStage, int] = {}
+ self.layer_to_index_map: Dict[SwapStage, int] = {}
+ self.index_to_layer_map: Dict[int, SwapStage] = {}
+ self.memory_peaks = []
+ self.generate_layer_info()
+
+ def print_with_rank(self, message, print_level=PrintLevel.DEBUG):
+ print_with_rank(message, prefix="ProfilerLayerInfo", print_level=print_level)
+
+ def generate_layer_info(self):
+ self.stage_data.clear()
+ self.fwd_op_layer_info.clear()
+ self.bwd_op_layer_info.clear()
+ self.layer_start_opid.clear()
+ self.layer_to_index_map.clear()
+ self.index_to_layer_map.clear()
+ self.memory_peaks.clear()
+ self.logical_layer_num = swap_policy_config.logical_layer_num
+
+ self.calculate_layer_info()
+ self.set_layer_info()
+ self.create_layer_mapping()
+ self.get_memory_peaks()
+
+ def calculate_layer_info(self):
+ op_fwd_sequence = []
+ op_bwd_sequence = []
+ for op in self.op_list:
+ if op.stage.micro_batch_index == 1:
+ if op.stage.stage_type == SwapStageType.FWD:
+ op_fwd_sequence.append(op)
+ elif op.stage.stage_type == SwapStageType.BWD:
+ op_bwd_sequence.append(op)
+ if self.logical_layer_num < 0: # use per op level layer info
+ self.fwd_op_layer_info = list(range(len(op_fwd_sequence)))
+ self.bwd_op_layer_info = list(range(len(op_bwd_sequence)))
+ else: # layer divided by logical layer num
+ for i in range(self.logical_layer_num - 1):
+ self.fwd_op_layer_info.append(len(op_fwd_sequence) // self.logical_layer_num * (i + 1))
+ self.bwd_op_layer_info.append(len(op_bwd_sequence) // self.logical_layer_num * (i + 1))
+
+ def set_layer_info(self):
+ cur_stage = SwapStage()
+ stage_start_idx = 0
+ for op in self.op_list:
+ stage = op.stage
+ # 将layerindex的信息更新到stage中, 同时更新model_info.model_stage_seq
+ if stage.stage_type != cur_stage.stage_type or stage.micro_batch_index != cur_stage.micro_batch_index:
+ cur_stage.stage_type = stage.stage_type
+ cur_stage.micro_batch_index = stage.micro_batch_index
+ stage_start_idx = op.op_id
+ stage_op_idx = op.op_id - stage_start_idx
+ stage.calculate_layer_index(stage_op_idx, self.fwd_op_layer_info, self.bwd_op_layer_info)
+
+ def create_layer_mapping(self):
+ for index, op in enumerate(self.op_list):
+ if not self.stage_data or op.stage != self.stage_data[-1]["stage"]:
+ self.stage_data.append(
+ {
+ "op_id": op.op_id,
+ "stage": op.stage,
+ "stage_type": op.stage.stage_type,
+ }
+ )
+ self.print_with_rank(
+ (f"op_id:: {str(op.op_id)}, stage: {str(op.stage)}, stage_type: {str(op.stage.stage_type)}")
+ )
+ for index, row in enumerate(self.stage_data):
+ if row["stage"] in self.layer_to_index_map:
+ raise ValueError("Find duplicate stage ...")
+ self.index_to_layer_map[index] = row["stage"]
+ self.layer_to_index_map[row["stage"]] = index
+ self.layer_start_opid[row["stage"]] = row["op_id"]
+
+ def get_memory_peaks(self):
+ """
+ 建立MemoryPeakInfo数据结构, 每次由反向阶段进入前向阶段时进入新的MemoryPeak区间
+ 仅将前反向进行MemoryPeakInfo的划分抽象, 不包含优化器和INIT阶段
+ """
+ self.memory_peaks = []
+ cur_peak_start = -1
+ cur_peak_end = -1
+ for index, layer in self.index_to_layer_map.items():
+ if cur_peak_start == -1:
+ if index == 0 and layer.stage_type == SwapStageType.FWD:
+ cur_peak_start = self.layer_start_opid[layer]
+ elif index > 0:
+ prev_layer = self.index_to_layer_map[index - 1]
+ if layer.stage_type == SwapStageType.FWD and prev_layer.stage_type != SwapStageType.FWD:
+ cur_peak_start = self.layer_start_opid[layer]
+ if cur_peak_end == -1:
+ if index == -len(self.layer_to_index_map) - 1 and layer.stage_type == SwapStageType.BWD:
+ cur_peak_end = self.layer_end_opid[layer]
+ elif index < len(self.layer_to_index_map) - 1:
+ next_layer = self.index_to_layer_map[index + 1]
+ if layer.stage_type == SwapStageType.BWD and next_layer.stage_type != SwapStageType.BWD:
+ cur_peak_end = self.layer_start_opid[next_layer] - 1
+ if cur_peak_start != -1 and cur_peak_end != -1:
+ cur_memory_peak = MemoryPeakInfo(cur_peak_start, cur_peak_end)
+ self.memory_peaks.append(cur_memory_peak)
+ cur_peak_start = -1
+ cur_peak_end = -1
+ self.print_with_rank(
+ f"current profiler step has {len(self.memory_peaks)} memory peaks", print_level=PrintLevel.INFO
+ )
+
+ def get_prev_layer(self, layer: SwapStage):
+ idx = self.layer_to_index_map[layer]
+ if idx - 1 not in self.index_to_layer_map:
+ return None
+ else:
+ return self.index_to_layer_map[idx - 1]
+
+ def get_next_layer(self, layer: SwapStage):
+ idx = self.layer_to_index_map[layer]
+ if idx + 1 not in self.index_to_layer_map:
+ return None
+ else:
+ return self.index_to_layer_map[idx + 1]
+
+
+class ProfilerDataOneStep:
+ def __init__(self, duration_time, step, is_oom, enable_profiler=True):
+ self.op_list: List[ProfilerOpInfo] = []
+ self.swap_list: List[ProfilerSwapInfo] = []
+ self.memory_reduction_list: List[MemoryReductionInfo] = []
+ self.layer_start_opid: Dict[SwapStage, int] = dict()
+ self.layer_info: ProfilerLayerInfo = None
+ self.duration_time = duration_time
+ self.step = step
+ self.max_memory = None
+ self.target_memory = None
+ self.is_oom = is_oom
+
+ if enable_profiler:
+ self.acquire_data()
+ self.layer_info = ProfilerLayerInfo(self.op_list)
+ self.layer_start_opid = self.layer_info.layer_start_opid
+ self.memory_peaks = self.layer_info.memory_peaks
+ self.init_memory_reduction_list()
+ self.get_memory_peak_mri()
+
+ self.__stage_list: List[SwapStage] = []
+ self.__stage_map: Dict[SwapStage, int] = {}
+ self.__parse_stage_info()
+ self.__op_info_cache: Dict[str, List[ProfilerOpInfo]] = {} # {op_name, List[ProfilerOpInfo]}
+
+ def __parse_stage_info(self):
+ for op in self.op_list:
+ if not self.__stage_list or op.stage != self.__stage_list[-1]:
+ self.__stage_list.append(op.stage)
+ self.__stage_map[op.stage] = self.__stage_list.index(op.stage)
+
+ def __get_op_info_from_list(self, from_op: ProfilerOpInfo, op_name: str, direction: str) -> ProfilerOpInfo:
+ # Determine the bisect function based on the direction
+ if direction == "next":
+ bisect_fn = bisect_right
+ op_name_check = op_name
+ idx_adjustment = 0
+ elif direction == "prev":
+ bisect_fn = bisect_left
+ op_name_check = op_name
+ idx_adjustment = -1
+ else:
+ raise ValueError("direction must be 'next' or 'prev'")
+
+ if op_name_check == "": # when search op is not specified
+ begin_idx = bisect_fn(self.op_list, from_op) + idx_adjustment
+ if begin_idx < 0 or begin_idx >= len(self.op_list):
+ return None
+ return self.op_list[begin_idx]
+
+ # Cache logic: Fetch or cache the op_info_list
+ if op_name_check not in self.__op_info_cache:
+ op_info_list = [op for op in self.op_list if op.op_name == op_name_check]
+ self.__op_info_cache[op_name_check] = op_info_list
+ else:
+ op_info_list = self.__op_info_cache[op_name_check]
+
+ # Determine the index to start searching from
+ begin_idx = bisect_fn(self.op_list, from_op) + idx_adjustment
+ if begin_idx < 0 or begin_idx >= len(self.op_list):
+ return None
+
+ # Search within the cached op_info_list
+ target_idx = bisect_fn(op_info_list, from_op) + idx_adjustment
+ if target_idx < 0 or target_idx >= len(op_info_list):
+ return None
+ return op_info_list[target_idx]
+
+ def group_op_info_by(self, op_info_list: List[ProfilerOpInfo], method="") -> List[List[ProfilerOpInfo]]:
+ if not all(isinstance(item, ProfilerOpInfo) for item in op_info_list):
+ raise TypeError("op_info_list can only contain elements with ProfilerOpInfo type.")
+ if method == "microbatch":
+ result_op_info_list = []
+ mb_group = []
+ curr_mb = None
+ for op in op_info_list:
+ if curr_mb is None:
+ curr_mb = op.stage.micro_batch_index
+ mb_group.append(op)
+ else:
+ if op.stage.micro_batch_index != curr_mb:
+ curr_mb = op.stage.micro_batch_index
+ result_op_info_list.append(mb_group.copy())
+ mb_group.clear()
+ mb_group.append(op)
+ return result_op_info_list
+ elif method == "":
+ return op_info_list
+ else:
+ raise NotImplementedError('group_by method other than "microbatch" is not implemented yet.')
+
+ def get_all_op_info(self, op_names: List[str] = None) -> List[ProfilerOpInfo]:
+ if op_names is None or len(op_names) == 0:
+ return self.op_list
+ op_info_list = []
+ for op_name in op_names:
+ if op_name in self.__op_info_cache:
+ op_info_list.extend(self.__op_info_cache[op_name])
+ else:
+ op = self.get_first_op_info(op_name)
+ while op is not None:
+ op_info_list.append(op)
+ op = self.get_next_op_info(op, op_name)
+ return op_info_list
+
+ def get_next_op_info(self, from_op: ProfilerOpInfo, next_op_name: str = "") -> ProfilerOpInfo:
+ if from_op is None:
+ return None
+ return self.__get_op_info_from_list(from_op, next_op_name, "next")
+
+ def get_prev_op_info(self, from_op: ProfilerOpInfo, prev_op_name: str = "") -> ProfilerOpInfo:
+ if from_op is None:
+ return None
+ return self.__get_op_info_from_list(from_op, prev_op_name, "prev")
+
+ def get_first_op_info(self, op_name: str = "") -> ProfilerOpInfo:
+ if len(self.op_list) == 0:
+ return None
+ first_op = self.op_list[0]
+ if op_name == "":
+ return first_op
+ return self.get_next_op_info(first_op, op_name)
+
+ def get_last_op_info(self, op_name: str = "") -> ProfilerOpInfo:
+ if len(self.op_list) == 0:
+ return None
+ last_op = self.op_list[-1]
+ if op_name == "":
+ return last_op
+ return self.get_prev_op_info(last_op, op_name)
+
+ def __get_adjacent_stage(self, stage: SwapStage, op_name: str, direction: str) -> SwapStage:
+ # Determine whether we are looking for the next or previous stage
+ if direction == "next":
+ stage_index_adjustment = 1
+ get_op_fn = self.get_first_op_info
+ get_adj_op_fn = self.get_next_op_info
+ elif direction == "prev":
+ stage_index_adjustment = -1
+ get_op_fn = self.get_last_op_info
+ get_adj_op_fn = self.get_prev_op_info
+ else:
+ raise ValueError("direction must be 'next' or 'prev'")
+
+ # Get the stage index from the stage map
+ if stage is None:
+ return None
+ stage_index = self.__stage_map.get(stage, None)
+ if stage_index is None:
+ return None
+
+ # If op_name is empty, handle the simple case of getting the next or previous stage
+ if op_name == "":
+ adjacent_stage_index = stage_index + stage_index_adjustment
+ if adjacent_stage_index < 0 or adjacent_stage_index >= len(self.__stage_list):
+ return None
+ return self.__stage_list[adjacent_stage_index]
+
+ # If op_name is specified, traverse the operations to find the adjacent stage
+ result_stage = None
+ curr_op = get_op_fn(op_name)
+ while curr_op is not None:
+ curr_stage_idx = self.__stage_map.get(curr_op.stage, None)
+ if curr_stage_idx is None:
+ break # Avoid infinite loop if stage is not found
+
+ is_valid_in_next_direction = direction == "next" and curr_stage_idx > stage_index
+ is_valid_in_prev_direction = direction == "prev" and curr_stage_idx < stage_index
+
+ if is_valid_in_next_direction or is_valid_in_prev_direction:
+ result_stage = curr_op.stage
+ break
+ curr_op = get_adj_op_fn(curr_op, op_name)
+ return result_stage
+
+ def get_next_stage(self, stage: SwapStage, op_name: str = "") -> SwapStage:
+ return self.__get_adjacent_stage(stage, op_name, "next")
+
+ def get_prev_stage(self, stage: SwapStage, op_name: str = "") -> SwapStage:
+ return self.__get_adjacent_stage(stage, op_name, "prev")
+
+ def print_with_rank(self, message, print_level=PrintLevel.DEBUG):
+ print_with_rank(message, prefix="ProfilerDataOneStep", print_level=print_level)
+
+ def __str__(self):
+ ret = "=" * 20 + "ProfilerDataOneStep SHOW BEGIN" + "=" * 20 + "\n"
+
+ ret += f"The length of op_list is {len(self.op_list)}\n"
+ for index, op_info in enumerate(self.op_list):
+ ret += f"op_info-{index}, {str(op_info.print_dict())}\n"
+
+ ret += f"The length of swap_list is {len(self.swap_list)}\n"
+ for index, swap_info in enumerate(self.swap_list):
+ ret += f"swap_info-{index}, {str(swap_info.print_dict())}\n"
+
+ for index, memory_reduction in enumerate(self.memory_reduction_list):
+ ret += f"memory_reduction-{index}, {str(memory_reduction.print_dict())}\n"
+
+ ret += "=" * 20 + "ProfilerDataOneStep SHOW END" + "=" * 20 + "\n"
+ return ret
+
+ @property
+ def length(self):
+ return len(self.op_list)
+
+ def acquire_data(self):
+ op_list = get_smart_swap_cpp().getProfilerOpInfoData()
+ swap_list = get_smart_swap_cpp().getProfilerSwapInfoData()
+ self.op_list = [ProfilerOpInfo(i) for i in op_list]
+ self.swap_list = [ProfilerSwapInfo(i) for i in swap_list]
+ get_smart_swap_cpp().updateProfiler()
+
+ def filter_swap_list(self):
+ """
+ 修正内存建模曲线:将swap_list中有swap_out但是没有对应swap_in的tensor单独记录,
+ 以MemoryPeakInfo为单位记录当前MemoryPeakInfo中上述多余swap out的tensor的总size
+ """
+ swap_in_list = [item for item in self.swap_list if item.swap_name == "swapIn"]
+ swap_in_total_size = sum([item.size for item in swap_in_list])
+ self.print_with_rank(
+ f"original swap in: {len(swap_in_list)} swap_in items with total size {swap_in_total_size}",
+ print_level=PrintLevel.INFO,
+ )
+ swap_out_list = [item for item in self.swap_list if item.swap_name == "swapOut"]
+ swap_out_total_size = sum([item.size for item in swap_out_list])
+ self.print_with_rank(
+ f"original swap out: {len(swap_out_list)} swap_out items with total size {swap_out_total_size}",
+ print_level=PrintLevel.INFO,
+ )
+ if swap_in_total_size == swap_out_total_size:
+ return None
+ swap_in_dict = dict([item.src_ptr, item] for item in swap_in_list)
+ extra_swap_out_dict = dict([(i, []) for i in range(len(self.memory_peaks))])
+ swap_out_list.sort(key=lambda item: item.op_id)
+ cur_mp_idx = 0
+ for item in swap_out_list:
+ if item.dst_ptr not in swap_in_dict:
+ while cur_mp_idx < len(self.memory_peaks) and item.op_id > self.memory_peaks[cur_mp_idx].end_opid:
+ cur_mp_idx += 1
+ if cur_mp_idx < len(self.memory_peaks):
+ extra_swap_out_dict[cur_mp_idx].append(item)
+ else:
+ self.print_with_rank(
+ f"current swap out at op_id {item.op_id} happens at OPTIM stage", print_level=PrintLevel.INFO
+ )
+ return extra_swap_out_dict
+
+ def get_max_memory(self, extra_swap_out_dict=None):
+ if extra_swap_out_dict:
+ for i, item in extra_swap_out_dict.items():
+ self.print_with_rank(f"extra_swap_out_dict has {len(item)} at {i}-th mp", print_level=PrintLevel.INFO)
+ swap_list_dict = {}
+ for swap_info in self.swap_list:
+ swap_list_dict.setdefault(swap_info.op_id, []).append(swap_info)
+
+ theoretical_minus_actual = 0
+ cur_mp_idx = 0
+ for op in self.op_list:
+ swap_info_list = swap_list_dict.get(op.op_id, [])
+ # 可能一个opid对应了多个swap
+ swap_out_size = sum(info.size for info in swap_info_list if info.swap_name == "swapOut")
+ swap_in_size = sum(info.size for info in swap_info_list if info.swap_name == "swapIn")
+
+ # 以MemmoryPeakInfo为单位进行内存曲线校正:进入每个新的MemoryPeakInfo时都去除swap out但没swap in的tensor的总size
+ if extra_swap_out_dict:
+ if cur_mp_idx < len(self.memory_peaks) and op.op_id > self.memory_peaks[cur_mp_idx].end_opid:
+ extra_swap_out_size = sum([item.size for item in extra_swap_out_dict[cur_mp_idx]])
+ while cur_mp_idx < len(self.memory_peaks) and op.op_id > self.memory_peaks[cur_mp_idx].end_opid:
+ cur_mp_idx += 1
+ theoretical_minus_actual -= extra_swap_out_size
+
+ theoretical_minus_actual = theoretical_minus_actual + swap_out_size - swap_in_size
+ op.theoretical_active_bytes = op.active_bytes + theoretical_minus_actual
+ return max(
+ (
+ op.theoretical_active_bytes
+ for op in self.op_list
+ if op.stage.stage_type not in [SwapStageType.INIT, SwapStageType.OPTIM]
+ ),
+ default=0,
+ )
+
+ def get_target_memory(self):
+ self.print_with_rank(f"is current step oom? {self.is_oom}", print_level=PrintLevel.INFO)
+ max_memory = max((op.active_bytes for op in self.op_list), default=0)
+ if self.is_oom:
+ return max_memory - swap_policy_config.redundant_memory
+ elif self.swap_list:
+ return max_memory
+
+ if swap_policy_config.target_mode:
+ target_memory = swap_policy_config.target_memory
+ else:
+ target_memory = self.max_memory - swap_policy_config.reduction_memory
+ return target_memory
+
+ def init_memory_reduction_list(self):
+ self.memory_reduction_list = []
+ extra_swap_out_dict = self.filter_swap_list()
+ self.max_memory = self.get_max_memory(extra_swap_out_dict=extra_swap_out_dict)
+ self.target_memory = self.get_target_memory()
+ self.print_with_rank(
+ f"max_memory={self.max_memory}, target_memory={self.target_memory}", print_level=PrintLevel.INFO
+ )
+ for op in self.op_list:
+ if op.theoretical_active_bytes > self.target_memory:
+ if op.stage.stage_type == SwapStageType.INIT:
+ self.print_with_rank("Skip init ... ")
+ continue
+ if op.stage.stage_type == SwapStageType.OPTIM:
+ self.print_with_rank("Memory Bound at Optim Stage ...")
+ break
+ memory_reduction_info = MemoryReductionInfo(op, op.theoretical_active_bytes - self.target_memory)
+ self.memory_reduction_list.append(memory_reduction_info)
+ # new data structure:build a map from index to opid of memory_reduction_info
+ self.mri_opid2idx = dict(
+ [(self.memory_reduction_list[i].op_id, i) for i in range(len(self.memory_reduction_list))]
+ )
+
+ def reset_memory_reduction_list(self):
+ for memory_info in self.memory_reduction_list:
+ memory_info.memory_reduction_need = memory_info.memory_reduction_total
+
+ def get_memory_peak_mri(self):
+ """
+ 建立每个MemoryPeakInfo对应的MemoryReductionInfo的开始和结束信息(mp_mri_start_opid, mp_mri_end_opid)
+ """
+ self.print_with_rank(
+ f"current memory_reduction_list has len {len(self.memory_reduction_list)}", print_level=PrintLevel.INFO
+ )
+ if len(self.memory_reduction_list) == 0:
+ return
+ cur_mri = 0
+ for idx, mp in enumerate(self.memory_peaks):
+ mp.mp_mri_start_opid = -1
+ mp.mp_mri_end_opid = -1
+ while (
+ cur_mri < len(self.memory_reduction_list)
+ and self.memory_reduction_list[cur_mri].op_id >= mp.start_opid
+ and self.memory_reduction_list[cur_mri].op_id <= mp.end_opid
+ ):
+ if mp.mp_mri_start_opid == -1:
+ mp.mp_mri_start_opid = self.memory_reduction_list[cur_mri].op_id
+ cur_mri += 1
+ if mp.mp_mri_start_opid > -1:
+ mp.mp_mri_end_opid = self.memory_reduction_list[cur_mri - 1].op_id
+ self.print_with_rank(
+ f"current mp {idx} starts at opid {mp.mp_mri_start_opid} and ends at opid {mp.mp_mri_end_opid}",
+ print_level=PrintLevel.INFO,
+ )
+
+ def get_sorted_op_names(self, sort_by="frequency") -> List[str]:
+ op_name_sequence = [item.op_name for item in self.op_list]
+ op_names_frequency_map = Counter(op_name_sequence)
+ if sort_by == "frequency":
+ op_names_frequency_list = sorted(
+ op_names_frequency_map.keys(), key=lambda name: op_names_frequency_map[name], reverse=True
+ )
+ elif sort_by == "alphabetical":
+ op_names_frequency_list = sorted(op_names_frequency_map.keys())
+ else:
+ raise NotImplementedError('sort methods other than "frequency" and "alphabetical" are not supported.')
+ return op_names_frequency_list
+
+ def map_unique_ptr_as_latest(self):
+ map_old2new = {}
+ for swap_row in self.swap_list:
+ for key, value in map_old2new.items():
+ if value == swap_row.src_ptr:
+ map_old2new[key] = swap_row.dst_ptr
+ map_old2new[swap_row.src_ptr] = swap_row.dst_ptr
+ for op in self.op_list:
+ for tensor in op.tensor_list:
+ if tensor.ptr in map_old2new:
+ tensor.ptr = map_old2new[tensor.ptr]
+
+ def update_tensor_types(self, map_ptr2type: Dict[UniqueSwapPtr, SwapTensorType]):
+ for op in self.op_list:
+ for tensor in op.tensor_list:
+ if tensor.ptr in map_ptr2type:
+ tensor.tensor_type = map_ptr2type[tensor.ptr]
+
+
+class TensorInfoDetail:
+ def __init__(self, profiler_tensor_info):
+ self.info: ProfilerTensorInfo = profiler_tensor_info
+ self.used_op_list: List[ProfilerOpInfo] = []
+ self.policy_candidate_list: List[SwapPolicyCandidate] = [] # 一个Tensor可能被多次Swap
+
+ def update_op(self, op: ProfilerOpInfo):
+ if len(self.used_op_list) != 0 and self.used_op_list[-1].op_id == op.op_id:
+ return
+ self.used_op_list.append(op)
+
+ def is_used_multiple_times(self): # 如果Tensor只被使用了一次,不需要Swap
+ return len(self.used_op_list) >= 2
+
+
+class SwapPolicyCandidate:
+ def __init__(
+ self,
+ tensor: TensorInfoDetail,
+ is_optimizer_or_weight: bool = False,
+ swap_out_op: ProfilerOpInfo = None,
+ swap_in_op: ProfilerOpInfo = None,
+ swap_out_stage: SwapStage = None,
+ swap_in_stage: SwapStage = None,
+ free_stage: SwapStage = None,
+ swap_in_free_stage: SwapStage = None,
+ ):
+ self.tensor: TensorInfoDetail = tensor
+ self.covered_reductions: List[MemoryReductionInfo] = [] # 可删除
+ self.num_covered_reductions = 0
+ self.start_mri_opid = -1 # 能覆盖的第一个mri的opid
+ self.end_mri_opid = -1 # 能覆盖的最后一个mri的opid
+ self.is_optimizer_or_weight = is_optimizer_or_weight
+ if not is_optimizer_or_weight:
+ self.swap_out_op = swap_out_op
+ self.swap_in_op = swap_in_op
+ self.swap_out_stage = swap_out_op.stage
+ self.swap_in_stage = swap_in_op.stage
+ self.swap_out_stage_actual = self.swap_out_stage
+ self.swap_in_stage_actual = self.swap_in_stage
+ else:
+ self.swap_out_stage = swap_out_stage
+ self.swap_in_stage = swap_in_stage
+ self.swap_out_stage_actual = self.swap_out_stage
+ self.swap_in_stage_actual = self.swap_in_stage
+ self.free_stage = free_stage
+ self.swap_in_free_stage = swap_in_free_stage
+
+ def set_device_to_host_stage(self, stage: SwapStage):
+ self.swap_out_stage = stage
+ self.swap_out_stage_actual = stage
+
+ def get_device_to_host_stage(self):
+ return self.swap_out_stage_actual
+
+ def set_device_to_host_free_stage(self, stage: SwapStage):
+ self.free_stage = stage
+
+ def set_host_to_device_stage(self, stage: SwapStage):
+ self.swap_in_stage = stage
+ self.swap_in_stage_actual = stage
+
+ def get_host_to_device_stage(self):
+ return self.swap_in_stage_actual
+
+ def set_host_to_device_free_stage(self, stage: SwapStage):
+ self.swap_in_free_stage = stage
+
+ def to_cpp(self):
+ instance = get_smart_swap_cpp().SwapPolicyInfo()
+ instance.executorNeedMatch = not self.is_optimizer_or_weight
+ if not self.is_optimizer_or_weight:
+ self.tensor.info.origi_ptr.to_cpp(instance.ptr)
+ instance.swapOutOpId = self.swap_out_op.op_id
+ instance.swapInOpId = self.swap_in_op.op_id
+ else:
+ self.tensor.info.ptr.to_cpp(instance.ptr)
+ self.swap_out_stage.to_cpp(instance.swapOutStage)
+ self.swap_in_stage_actual.to_cpp(instance.swapInStage)
+ self.free_stage.to_cpp(instance.freeStage)
+ self.swap_in_free_stage.to_cpp(instance.swapInFreeStage)
+ return instance
+
+ def __str__(self):
+ return str(
+ dict(
+ tensor=str(self.tensor.info),
+ is_optimizer_or_weight=str(self.is_optimizer_or_weight),
+ swap_out_op=self.swap_out_op.print_dict_brief() if hasattr(self, "swap_out_op") else "None",
+ swap_in_op=self.swap_in_op.print_dict_brief() if hasattr(self, "swap_in_op") else "None",
+ swap_out_stage=str(self.swap_out_stage),
+ swap_in_stage=str(self.swap_in_stage),
+ swap_out_stage_actual=str(
+ self.swap_out_stage_actual if hasattr(self, "swap_out_stage_actual") else "None"
+ ),
+ swap_in_stage_actual=str(
+ self.swap_in_stage_actual if hasattr(self, "swap_in_stage_actual") else "None"
+ ),
+ free_stage=str(self.free_stage),
+ swap_in_free_stage=str(self.swap_in_free_stage),
+ )
+ )
+
+
+class SwapPolicy:
+ def __init__(self, swap_policy_candidates: List[SwapPolicyCandidate], profiler_data: ProfilerDataOneStep):
+ self.__swap_policy_candidates: List[SwapPolicyCandidate] = swap_policy_candidates
+ self.__profiler_data: ProfilerDataOneStep = profiler_data
+ self.__stage_list: List[SwapStage] = []
+ self.__stage_map: Dict[SwapStage, int] = {}
+ self.__parse_stage_info()
+
+ def __parse_stage_info(self):
+ for op in self.__profiler_data.op_list:
+ if not self.__stage_list or op.stage != self.__stage_list[-1]:
+ self.__stage_list.append(op.stage)
+ self.__stage_map[op.stage] = self.__stage_list.index(op.stage)
+
+ def __auto_lint(self, policy: List[SwapPolicyCandidate]):
+ # remove candidates with identical swap out and swap in stages.
+ cand_remove_list = []
+ for cand in policy:
+ swap_out_stage = cand.swap_out_stage_actual
+ swap_in_stage = cand.swap_in_stage_actual
+ if swap_out_stage == swap_in_stage:
+ cand_remove_list.append(cand)
+ continue
+ for cand in policy.copy():
+ if cand in cand_remove_list:
+ policy.remove(cand)
+
+ # offset free stage by one if overlap.
+ for cand in policy:
+ swap_in_stage_actual = cand.swap_in_stage_actual
+ swap_in_free_stage = cand.swap_in_free_stage
+ if swap_in_stage_actual == swap_in_free_stage:
+ cand.swap_in_free_stage = self.__profiler_data.get_next_stage(swap_in_free_stage)
+ swap_out_stage_actual = cand.swap_out_stage_actual
+ swap_out_free_stage = cand.free_stage
+ if swap_out_stage_actual == swap_out_free_stage:
+ cand.free_stage = self.__profiler_data.get_next_stage(swap_out_free_stage)
+
+ def get_candidates(self) -> List[SwapPolicyCandidate]:
+ return self.__swap_policy_candidates
+
+ def set_candidates(self, candidates: List[SwapPolicyCandidate]):
+ self.__auto_lint(candidates)
+ self.__swap_policy_candidates = candidates
+
+ def get_profiler_data(self) -> ProfilerDataOneStep:
+ return self.__profiler_data
+
+
+class PolicyResult:
+ MAX_OP_NAMES_LENGTH = 64
+
+ def __init__(self):
+ self.policy_list: List[SwapPolicyCandidate] = None # 用于SwapOut和SwapIn的Tensor信息列表
+ self.policy_step = None # 用第几个Step的Profiling结果进行匹配
+ self.tensor_size_thresh = None # 最小可能被Swap的Tensor的size大小
+ self.fwd_op_layer_info = None # 当前policy_step的Profiling对应的前向层信息
+ self.bwd_op_layer_info = None # 当前policy_step的Profiling对应的反向层信息
+ self.op_names_frequency_list = None # 当前policy_step的Profiling的OpName的频次列表,由高到低,最多有64个元素
+
+ def clear(self):
+ self.policy_list = None
+ self.policy_step = None
+ self.tensor_size_thresh = None
+ self.fwd_op_layer_info = None
+ self.bwd_op_layer_info = None
+ self.op_names_frequency_list = None
+
+ def __str__(self):
+ info = dict(
+ policy_step=self.policy_step,
+ tensor_size_thresh=self.tensor_size_thresh,
+ fwd_op_layer_info=self.fwd_op_layer_info,
+ bwd_op_layer_info=self.bwd_op_layer_info,
+ )
+ ret = f"Basic policy is {info}\n"
+ ret += f"A total number of {len(self.policy_list)} swaps are selected.\n"
+ for index, item in enumerate(self.policy_list):
+ ret += f"policy-{index}: \t\t{item}\n"
+ return ret
+
+ def set_py_swap_policy_to_cpp(self, config: SwapConfig):
+ # 设置候选swap的tensor到c++侧
+ swap_policy_info_list = []
+ if self.policy_list is not None:
+ for candidate in self.policy_list:
+ try:
+ swap_policy_info_list.append(candidate.to_cpp())
+ except Exception as e:
+ raise RuntimeError(f"candidate.to_cpp() error ! \n{candidate}") from e
+
+ if self.fwd_op_layer_info is not None:
+ config.fwd_op_layer_info = self.fwd_op_layer_info
+ if self.bwd_op_layer_info is not None:
+ config.bwd_op_layer_info = self.bwd_op_layer_info
+
+ if self.policy_step:
+ # 设置config相关
+ config.tensorSizeThresh = self.tensor_size_thresh
+ config.policy_step = self.policy_step
+ # 设置op_names出现的频率
+ get_smart_swap_cpp().setFrequentOpNameData(self.op_names_frequency_list[: self.MAX_OP_NAMES_LENGTH])
+
+ else:
+ config.tensorSizeThresh = swap_policy_config.tensor_size_thresh
+ config.policy_step = 0
+ get_smart_swap_cpp().setFrequentOpNameData([])
+
+ get_smart_swap_cpp().setPolicyInfoData(swap_policy_info_list)
diff --git a/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_engine.py b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..d11551b474da58330a88af05a7b556b05a4d687e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_engine.py
@@ -0,0 +1,294 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import os
+import stat
+import time
+import pickle
+from typing import Dict
+
+import pandas
+
+from .policy_generator import PolicyGenerator
+from .swap_policy_config import swap_policy_config
+from .swap_utils import print_with_rank, PrintLevel, timer
+from .swap_cpp_adaptor import (
+ SwapConfig,
+ ProfilerDataOneStep,
+ PolicyResult,
+ SwapTensorType,
+ SwapPolicyCandidate,
+ UniqueSwapPtr,
+ TensorInfoDetail,
+ record_tensor_ptr_with_types,
+ SwapPolicy,
+)
+
+
+class SwapEngine:
+ def __init__(self, models, optimizer, get_optimizer_tensors_fcn, config: SwapConfig, custom_policy_fcn):
+ # 相关模块
+ self.models = models
+ self.optimizer = optimizer
+ self.get_optimizer_tensors_fcn = get_optimizer_tensors_fcn
+ self.custom_policy_fcn = custom_policy_fcn
+
+ # 控制参数
+ self.config = config
+ self.rank = swap_policy_config.rank
+ self.output_root_path = swap_policy_config.output_root_path
+ if swap_policy_config.save_policy or swap_policy_config.save_profiler_data:
+ if not os.path.exists(self.output_root_path) and self.rank == 0:
+ os.makedirs(self.output_root_path)
+ self.duration_time = None
+ self.step_parameters = {}
+ self.all_step_duration = {}
+
+ # profiling 数据
+ self.profiler_op_step: ProfilerDataOneStep = None
+ self.profiler_all_step: Dict[int, ProfilerDataOneStep] = dict() # 目前为止所有step的profiler数据
+
+ # 处理后的数据,用于生成策略
+ self.tensor_info_dict: Dict[UniqueSwapPtr, TensorInfoDetail] = dict()
+
+ # 当前生成的最新policy
+ self.newest_policy_result: PolicyResult = PolicyResult()
+ self.map_unique_ptr2tensor_type = dict()
+
+ # 用户policy策略函数
+ if self.custom_policy_fcn is None:
+ print_with_rank("User policy is missing, skip user policy.", print_level=PrintLevel.INFO)
+ self.use_custom_policy = False
+ else:
+ print_with_rank("Found user policy.", print_level=PrintLevel.INFO)
+ self.use_custom_policy = True
+
+ @property
+ def step(self):
+ return self.config.step
+
+ def print_with_rank(self, message, print_level=PrintLevel.DEBUG):
+ print_with_rank(message, prefix="SwapEngine", print_level=print_level)
+
+ def clear_policy(self):
+ self.newest_policy_result.clear()
+
+ def append_profiler_data(self, profiler_op_step: ProfilerDataOneStep):
+ self.profiler_all_step[profiler_op_step.step] = profiler_op_step
+ self.forced_swap_list = [i for i in profiler_op_step.swap_list if i.is_oom]
+ if swap_policy_config.save_profiler_data:
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ profiler_all_step_file = os.path.join(self.output_root_path, f"profiler_all_step_{self.rank}.pkl")
+ with os.fdopen(os.open(profiler_all_step_file, flags, mode=mode), "wb") as file:
+ pickle.dump(self.profiler_all_step, file)
+
+ def save_policy_list(self, swap_list):
+ swap_list_pd = pandas.DataFrame([i.tensor.info.get_dict() for i in swap_list])
+ flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+ mode = stat.S_IWUSR | stat.S_IRUSR
+ policy_file = os.path.join(self.output_root_path, f"Policy_{self.rank}.csv")
+ with os.fdopen(os.open(policy_file, flags, mode=mode), "wb") as file:
+ swap_list_pd.to_csv(file)
+
+ def record_tensor_types(self):
+ self.map_unique_ptr2tensor_type.clear()
+ # 针对优化器状态的特殊类tensor,将其记录在C++侧的map映射中,方便其执行匹配
+ if self.optimizer and self.get_optimizer_tensors_fcn:
+ tensors = self.get_optimizer_tensors_fcn(self.optimizer)
+ unique_ptrs = record_tensor_ptr_with_types(tensors, SwapTensorType.OPTIM, 1, False)
+ for unique_ptr in unique_ptrs:
+ self.map_unique_ptr2tensor_type[UniqueSwapPtr(unique_ptr)] = SwapTensorType.OPTIM
+
+ def is_similar_with_policy_profiler(self, profiler_op_step: ProfilerDataOneStep):
+ if self.newest_policy_result.policy_step is None:
+ ret = True
+ self.print_with_rank("The policy step is None, maybe initial stage ...")
+ else:
+ ret = self.is_equal_op_sequence(
+ swap_policy_config, profiler_op_step, self.profiler_all_step[self.newest_policy_result.policy_step]
+ )
+ self.print_with_rank(
+ (
+ f"now: {len(profiler_op_step.op_list)}, "
+ f"last: {self.profiler_all_step[self.newest_policy_result.policy_step].length}, "
+ f"ret: {ret}"
+ )
+ )
+ if ret:
+ self.print_with_rank("The sequence is similar with the policy one ...")
+ return ret
+
+ @timer
+ def process_profiler_data(self):
+ self.print_with_rank("Processing data ... ", print_level=PrintLevel.INFO)
+ # 获取特殊类tensor的unique_ptr信息
+ self.record_tensor_types()
+ # 将profiler_op_step中的UniquePtr全部映射为最新的ptr
+ self.profiler_op_step.map_unique_ptr_as_latest()
+ # 刷新tensor type
+ self.profiler_op_step.update_tensor_types(self.map_unique_ptr2tensor_type)
+ self.print_with_rank(str(self.profiler_op_step))
+
+ self.newest_policy_result.policy_step = self.step
+ self.newest_policy_result.op_names_frequency_list = self.profiler_op_step.get_sorted_op_names()
+
+ def run(self, profiler_op_step: ProfilerDataOneStep, is_new_op_sequence) -> PolicyResult:
+ self.current_profiler_step = profiler_op_step
+ self.profiler_op_step = (
+ profiler_op_step if is_new_op_sequence else self.profiler_all_step[self.newest_policy_result.policy_step]
+ )
+
+ # 汇总参数 上一步的参数,运行时间,policy结果
+ # 自适应迭代 fun,分优先级
+ # 更新参数
+ if is_new_op_sequence:
+ self.process_profiler_data()
+
+ policy_candidates, tensor_size_thresh = self.make_policy()
+ self.newest_policy_result.tensor_size_thresh = tensor_size_thresh
+ self.newest_policy_result.policy_list = policy_candidates
+ self.newest_policy_result.fwd_op_layer_info = self.profiler_op_step.layer_info.fwd_op_layer_info
+ self.newest_policy_result.bwd_op_layer_info = self.profiler_op_step.layer_info.bwd_op_layer_info
+
+ return self.newest_policy_result
+
+ @staticmethod
+ def is_equal_op_sequence(
+ policy_config, cur_sequence: ProfilerDataOneStep, target_sequence: ProfilerDataOneStep = None
+ ) -> bool:
+ """
+ Compare how different cur_sequence is from target_sequence, and return a ratio.
+ 暂时先只比较长度
+ """
+ if target_sequence is None:
+ return False
+ target_len = cur_sequence.length
+ cur_len = target_sequence.length
+ return abs(target_len - cur_len) / cur_len < policy_config.op_diff_thresh
+
+ def record_parameters(self):
+ self.step_parameters[self.step] = {
+ "duration_time": swap_policy_config.duration_time,
+ "size_coverage_weight": swap_policy_config.size_coverage_weight,
+ "redundant_memory": swap_policy_config.redundant_memory,
+ }
+
+ def set_parameters(self):
+ swap_step = list(self.step_parameters.keys())
+ min_duration = min(self.all_step_duration[i] for i in swap_step)
+ best_step = [key for key, value in self.all_step_duration.items() if value == min_duration][0]
+ swap_policy_config.duration_time = self.step_parameters[best_step]["duration_time"]
+ swap_policy_config.size_coverage_weight = self.step_parameters[best_step]["size_coverage_weight"]
+ swap_policy_config.redundant_memory = self.step_parameters[best_step]["redundant_memory"]
+
+ def adjust_parameters(self):
+ setattr(
+ swap_policy_config,
+ "duration_time",
+ min(
+ getattr(swap_policy_config, "duration_time", float("inf")),
+ self.current_profiler_step.duration_time * swap_policy_config.adjust_step_duration,
+ ),
+ )
+
+ if self.forced_swap_list:
+ swap_policy_config.redundant_memory += swap_policy_config.adjust_memory
+ self.profiler_op_step.init_memory_reduction_list()
+ self.record_parameters()
+ return
+
+ swap_policy_config.size_coverage_weight += swap_policy_config.adjust_size_coverage_weight
+ self.record_parameters()
+
+ def check_policy_valid(self, candidate: SwapPolicyCandidate):
+ # swap out free stage: (swap out op, swap in stage actual)
+ # swap in stage actual: (swap out free stage, swap in op)
+ # swap in free stage: (swap in op, )
+ if not candidate.is_optimizer_or_weight:
+ free_stage_opid = self.profiler_op_step.layer_start_opid[candidate.free_stage]
+ swap_in_stage_actual_opid = self.profiler_op_step.layer_start_opid[candidate.swap_in_stage_actual]
+ swap_in_free_stage_opid = self.profiler_op_step.layer_start_opid[candidate.swap_in_free_stage]
+ swap_out_opid = (
+ candidate.swap_out_op.op_id
+ if not candidate.is_optimizer_or_weight
+ else self.profiler_op_step.layer_start_opid[candidate.swap_out_stage]
+ )
+ swap_in_opid = (
+ candidate.swap_in_op.op_id
+ if not candidate.is_optimizer_or_weight
+ else self.profiler_op_step.layer_start_opid[candidate.swap_in_stage]
+ )
+ if not (free_stage_opid > swap_out_opid and free_stage_opid < swap_in_stage_actual_opid):
+ print(
+ f"Error! swap_out_free_stage_opid [{free_stage_opid}] should be > swap_out_opid [{swap_out_opid}] and < swap_in_stage_actual_opid [{swap_in_stage_actual_opid}]"
+ )
+ return False
+ if not (swap_in_stage_actual_opid < swap_in_opid):
+ print(
+ f"Error! swap_in_stage_actual_opid [{swap_in_stage_actual_opid}] should be < swap_in_opid [{swap_in_opid}]"
+ )
+ return False
+ if not (swap_in_free_stage_opid > swap_in_stage_actual_opid):
+ print(
+ f"Error! swap_in_free_stage_opid [{swap_in_free_stage_opid}] should be > swap_in_stage_actual_opid [{swap_in_stage_actual_opid}]"
+ )
+ return False
+ return True
+
+ @timer
+ def make_policy(self):
+ self.print_with_rank("Making policy ...", print_level=PrintLevel.INFO)
+ self.adjust_parameters()
+ self.profiler_op_step.reset_memory_reduction_list()
+ policy_generator = PolicyGenerator(self.profiler_op_step)
+ policy_generator.select_candidate()
+
+ start_time = time.time()
+ if self.use_custom_policy:
+ policy_generator.simulation(use_custom_policy=True)
+ else:
+ policy_generator.compute_score()
+ while not policy_generator.reduction_target_satisfied():
+ # 寻找能降内存的policy
+ policy_generator.get_intersect_candidates()
+ # 选不出来就退出
+ if not policy_generator.intersect_candidates:
+ self.print_with_rank(f"Fail to reach reduction target ...", print_level=PrintLevel.INFO)
+ break
+ policy_generator.simulation()
+ end_time = time.time()
+ self.print_with_rank(f"policy generate takes {end_time - start_time} seconds.", print_level=PrintLevel.INFO)
+
+ policy_generator.swap_arranger.save_stage_time_left()
+ policy_generator.swap_arranger.set_free_stage()
+
+ if self.use_custom_policy:
+ # create SwapPolicy by providing existing swap list and profiler info
+ curr_swap_policy = SwapPolicy(policy_generator.swap_list, self.profiler_op_step)
+ self.custom_policy_fcn(curr_swap_policy)
+ policy_generator.swap_list = curr_swap_policy.get_candidates()
+ swap_list = policy_generator.get_sorted_swap_list()
+ tensor_size_thresh = (
+ min([candidate.tensor.info.size for candidate in swap_list])
+ if swap_list
+ else swap_policy_config.tensor_size_thresh
+ )
+
+ self.print_with_rank(
+ (
+ f"\n\tCurrent Step: {self.current_profiler_step.step}, "
+ f"Policy Step: {self.profiler_op_step.step}, "
+ f"Max Memory: {self.profiler_op_step.max_memory}, "
+ f"Target Memory: {self.profiler_op_step.target_memory}, "
+ f"Duration Time: {swap_policy_config.duration_time}, "
+ f"Size Cov Weight: {swap_policy_config.size_coverage_weight}, "
+ f"\n\tCandidate Num: {len(policy_generator.policy_candidate_list)}, "
+ f"Policy Num: {len(swap_list)}, "
+ f"Optim Num: {len([i for i in swap_list if i.tensor.info.tensor_type == SwapTensorType.OPTIM])}, "
+ f"Model Num: {len([i for i in swap_list if i.tensor.info.tensor_type != SwapTensorType.OPTIM])}, "
+ f"Min Tensor Size: {tensor_size_thresh}"
+ ),
+ print_level=PrintLevel.INFO,
+ )
+ if swap_policy_config.save_policy:
+ self.save_policy_list(swap_list)
+ return swap_list, tensor_size_thresh
diff --git a/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_manager.py b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..e53c32bd03a2de707a939de431088842bdf48b64
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_manager.py
@@ -0,0 +1,235 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import time
+from enum import Enum
+from collections.abc import Iterable
+
+import torch
+
+from .hooks import register_swap_hooks_to_modules
+from .swap_policy_config import swap_policy_config
+from .swap_utils import print_with_rank, PrintLevel
+from .swap_cpp_adaptor import (
+ SwapConfig,
+ ProfilerDataOneStep,
+ SwapStageType,
+ SwapTensorType,
+ record_tensor_ptr_with_types,
+ get_smart_swap_cpp,
+)
+from .swap_engine import SwapEngine
+
+
+def singleton(cls):
+ instances = {}
+
+ def get_instance(*args, **kwargs):
+ if cls not in instances:
+ instances[cls] = cls(*args, **kwargs)
+ return instances[cls]
+
+ return get_instance
+
+
+class SwapRunningStage(Enum):
+ WARMUP_STAGE = 0 # Warmup阶段:opSequence不稳定
+ SEARCHING_POLICY_STAGE = 1 # 迭代策略阶段:opSequence稳定,可能有OOM,策略不稳定
+ STABLE_STAGE = 2 # 策略稳定阶段:opSequence稳定,策略稳定
+ RESERVED = 3
+
+
+@singleton
+class SwapManager:
+ def __init__(
+ self,
+ num_micro_batch_fcn,
+ models,
+ num_layers,
+ optimizer=None,
+ get_optimizer_tensors_fcn=None,
+ get_shared_tensors_fcn=None,
+ custom_policy_fcn=None,
+ ):
+ if torch.distributed.is_initialized():
+ swap_policy_config.rank = torch.distributed.get_rank()
+
+ option = {"OP_HOOK_ENABLE": "enable"}
+ torch.npu.set_option(option)
+
+ self.smart_swap_cpp = get_smart_swap_cpp()
+ self.smart_swap_cpp.init_cpp_manager()
+ self.smart_swap_cpp.NPUSwapManager.GetInstance().swap_enable = True
+ self.smart_swap_cpp.NPUSwapManager.GetInstance().swap_oom_enable = True
+ self.config = SwapConfig()
+ self.num_micro_batch_fcn = num_micro_batch_fcn
+ self.models = models
+ self.get_shared_tensors_fcn = get_shared_tensors_fcn
+ self.swap_hook_registers: list = []
+ self.swap_engine = SwapEngine(models, optimizer, get_optimizer_tensors_fcn, self.config, custom_policy_fcn)
+ self.start_time = time.time()
+ self.cur_warmup_step = 0
+ self.running_stage = SwapRunningStage.RESERVED
+ self.is_new_op_sequence = True
+ self.model_num_layers = num_layers
+ self.global_initialize()
+
+ def __del__(self):
+ option = {"OP_HOOK_ENABLE": "disable"}
+ torch.npu.set_option(option)
+ self.smart_swap_cpp.deinit_cpp_manager()
+
+ def __check_layer_param(self, model_num_layers):
+ if not isinstance(model_num_layers, int):
+ raise ValueError("model_num_layers must be an integer.")
+ if model_num_layers != -1 and model_num_layers <= 0:
+ raise ValueError("model_num_layers must be a positive integer or -1.")
+
+ def print_with_rank(self, message, print_level=PrintLevel.DEBUG):
+ print_with_rank(message, prefix="SwapManager", print_level=print_level)
+
+ def global_initialize(self):
+ stage = self.config.stage
+ stage.stage_type = SwapStageType.INIT
+ self.config.stage = stage
+ self.config.step = 0
+ self.config.micro_batch_num = self.num_micro_batch_fcn()
+ self.config.fwd_op_layer_info = []
+ self.config.bwd_op_layer_info = []
+ self.register_model_hooks(self.models)
+ self.record_shared_memory(self.models)
+ self.start_time = time.time()
+ self.init_for_new_op_seq()
+ self.config.enable_profiler = True
+ self.config.enable_executor = False
+ self.config.enable_custom_record_stream = swap_policy_config.enable_custom_record_stream
+ self.__check_layer_param(self.model_num_layers)
+ swap_policy_config.logical_layer_num = (
+ -1 if self.model_num_layers < 0 else (10 // self.model_num_layers + 1) * self.model_num_layers
+ )
+
+ def record_model_tensor_type(self, models):
+ tensors = []
+ for model in models:
+ # MODEL
+ for name, param in model.named_parameters():
+ tensors.append(param.data)
+
+ record_tensor_ptr_with_types(tensors, SwapTensorType.MODEL, 0, False)
+
+ def record_shared_memory(self, models):
+ if models and self.get_shared_tensors_fcn:
+ tensors = self.get_shared_tensors_fcn(models)
+ record_tensor_ptr_with_types(tensors, SwapTensorType.SHARED_MEMORY, 0, True)
+
+ def init_for_new_op_seq(self):
+ self.print_with_rank("Call init_for_new_op_seq")
+ self.running_stage = SwapRunningStage.WARMUP_STAGE
+ self.swap_engine.clear_policy()
+ self.is_new_op_sequence = True
+ self.cur_warmup_step = 0
+
+ def step(self):
+ end_time = time.time()
+ self.config.one_step_duration = end_time - self.start_time
+ for swap_hook_register in self.swap_hook_registers:
+ swap_hook_register.reset()
+ self.config.micro_batch_num = self.num_micro_batch_fcn()
+ profiler_data_one_step = ProfilerDataOneStep(
+ self.config.one_step_duration, self.config.step, self.config.is_oom, self.config.enable_profiler
+ )
+ self.swap_engine.append_profiler_data(profiler_data_one_step)
+ self.swap_engine.all_step_duration[self.swap_engine.step] = self.config.one_step_duration
+
+ self.print_with_rank(
+ (
+ f"Step: {self.config.step}, Time elapsed: {self.config.one_step_duration}, "
+ f"Logical layer num: {swap_policy_config.logical_layer_num}, "
+ f"Op num: {len(profiler_data_one_step.op_list)}, "
+ f"Current running stage: {self.running_stage.name}, OOM state: {self.config.is_oom}"
+ ),
+ print_level=PrintLevel.INFO,
+ )
+ self.print_with_rank(
+ ("OOM swap: \n" + "\n".join(str(i) for i in profiler_data_one_step.swap_list if i.is_oom)),
+ print_level=PrintLevel.INFO,
+ )
+ self.print_with_rank(f"{str(profiler_data_one_step)}")
+
+ if self.running_stage == SwapRunningStage.WARMUP_STAGE:
+ if self.swap_engine.is_similar_with_policy_profiler(profiler_data_one_step):
+ self.cur_warmup_step += 1
+ if self.cur_warmup_step == swap_policy_config.warmup_step:
+ self.running_stage = SwapRunningStage.SEARCHING_POLICY_STAGE
+ elif self.running_stage == SwapRunningStage.SEARCHING_POLICY_STAGE:
+ self.cur_warmup_step += 1
+ if not self.swap_engine.is_similar_with_policy_profiler(profiler_data_one_step):
+ self.init_for_new_op_seq()
+ elif self.cur_warmup_step == swap_policy_config.stable_step:
+ self.running_stage = SwapRunningStage.STABLE_STAGE
+ elif self.running_stage == SwapRunningStage.STABLE_STAGE:
+ if self.swap_engine.forced_swap_list:
+ self.init_for_new_op_seq()
+ else:
+ raise RuntimeError(f"Get incorrect running_stage: {self.running_stage.name}")
+
+ self.print_with_rank(f"Change running stage to: {self.running_stage.name}", print_level=PrintLevel.INFO)
+ if self.running_stage == SwapRunningStage.WARMUP_STAGE:
+ self.config.enable_profiler = True
+ self.config.enable_executor = False
+ elif self.running_stage == SwapRunningStage.SEARCHING_POLICY_STAGE:
+ self.config.enable_profiler = True
+ self.config.enable_executor = True
+ policy_result = self.swap_engine.run(profiler_data_one_step, self.is_new_op_sequence)
+ policy_result.set_py_swap_policy_to_cpp(self.config)
+ self.smart_swap_cpp.updateStep()
+ self.is_new_op_sequence = False
+ self.print_with_rank(f"Policy result:\n{policy_result}", print_level=PrintLevel.DEBUG)
+ elif self.running_stage == SwapRunningStage.STABLE_STAGE:
+ self.config.enable_profiler = False
+ self.config.enable_executor = True
+ self.smart_swap_cpp.updateStep()
+ else:
+ raise RuntimeError(f"Get incorrect running_stage: {self.running_stage.name}")
+
+ self.print_with_rank(
+ (
+ f"All step duration: "
+ f"{[(step, time) for step, time in self.swap_engine.all_step_duration.items()]}\n\n"
+ ),
+ print_level=PrintLevel.INFO,
+ )
+
+ self.config.step += 1
+ self._update_config_for_step_hook(SwapStageType.INIT, 0, 0, 0)
+ self.start_time = time.time()
+
+ def _update_config_for_step_hook(
+ self, stage_type: SwapStageType, layer_index, micro_batch_index, current_stage_op_id
+ ):
+ stage = self.config.stage
+ stage.stage_type = stage_type
+ stage.layer_index = layer_index
+ stage.micro_batch_index = micro_batch_index
+
+ self.config.stage = stage
+ self.config.current_stage_op_id = current_stage_op_id
+
+ def fwd_pre_hook_custom_func(self, _, fwd_idx):
+ self._update_config_for_step_hook(SwapStageType.FWD, 1, fwd_idx, 0)
+
+ def bwd_pre_hook_custom_func(self, _, bwd_idx):
+ self._update_config_for_step_hook(SwapStageType.BWD, 1, bwd_idx, 0)
+
+ def bwd_post_hook_custom_func(self, _, bwd_idx):
+ if bwd_idx == self.num_micro_batch_fcn():
+ self._update_config_for_step_hook(SwapStageType.OPTIM, 0, 0, 0)
+
+ def register_model_hooks(self, models):
+ if not isinstance(models, Iterable):
+ models = [models]
+ for model in models:
+ swap_hook_register = register_swap_hooks_to_modules(model)
+ swap_hook_register.register_custom_func(
+ self.fwd_pre_hook_custom_func, None, self.bwd_pre_hook_custom_func, self.bwd_post_hook_custom_func
+ )
+ self.swap_hook_registers.append(swap_hook_register)
+ self.print_with_rank("Register model swap hooks completed.")
diff --git a/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_megatron_adaptor.py b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_megatron_adaptor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d8edc3d47b63ff7fa19480e1cc77b8cca15f3a7
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_megatron_adaptor.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+from functools import wraps
+
+from megatron.training.training import get_num_microbatches
+from megatron.training import get_args
+
+from .swap_manager import SwapManager
+
+
+def megatron_get_optimizer_tensors_fcn(optimizer):
+ results = []
+ for group in optimizer.optimizer.param_groups:
+ amsgrad = group["amsgrad"]
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+ results.append(p.data)
+
+ state = optimizer.optimizer.state[p]
+ if len(state) > 0:
+ results.append(state["exp_avg"])
+ results.append(state["exp_avg_sq"])
+ if amsgrad:
+ results.append(state["max_exp_avg_sq"])
+
+ return results
+
+
+def megatron_get_shared_tensors_fcn(models):
+ results = []
+ for model in models:
+ # SHARED_MEMORY
+ if model.buffers is not None:
+ for buffer in model.buffers:
+ if buffer.grad_data is not None:
+ results.append(buffer.grad_data)
+ if buffer.param_data is not None:
+ results.append(buffer.param_data)
+ return results
+
+
+def MegatronSwapManager(train_step_args, cmd_args):
+ """
+ Adapter to the megatron's train_step function.
+ train_step_args is from the arguments of train_step.
+ cmd_args is obtained from get_args() from megatron.
+ """
+ if len(train_step_args) < 4:
+ raise ValueError("The length of arguments should be more than 4")
+ model = train_step_args[2]
+ optimizer = train_step_args[3]
+ num_layers = cmd_args.num_layers
+ return SwapManager(
+ get_num_microbatches,
+ model,
+ cmd_args.num_layers,
+ optimizer=optimizer,
+ get_optimizer_tensors_fcn=megatron_get_optimizer_tensors_fcn,
+ get_shared_tensors_fcn=megatron_get_shared_tensors_fcn,
+ )
+
+
+def train_step_wrapper(train_step):
+ @wraps(train_step)
+ def wrapper(*args, **kwargs):
+ args_ = get_args()
+ manager = MegatronSwapManager(args, args_)
+ ret = train_step(*args, **kwargs)
+ manager.step()
+ return ret
+
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_policy_config.py b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_policy_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0efce52c44291b4d90bbd4ba90b45e9b716362e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_policy_config.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+class SwapPolicyConfig:
+ def __init__(self):
+ # utils
+ self.rank = 0 # 获取当前rank
+
+ self.save_policy = False
+ self.save_profiler_data = False
+
+ self.print_level = 1 # 设置print级别 DEBUG=0, INFO=1, NONE=2
+ self.print_rank = 0 # 设置打印信息的卡, -1打印所有卡
+ self.output_root_path = "./swap_output"
+
+ # 执行
+ self.warmup_step = 2 # 多少步之后进入SEARCHING_POLICY_STAGE
+ self.stable_step = 10 # 多少步之后进入STABLE_STAGE
+
+ self.op_diff_thresh = 0.05
+ self.tensor_size_thresh = 2**31 - 1
+
+ self.enable_custom_record_stream = True
+ self.free_stage_delay = 4 # 表示将swap out任务的内存延后N个stage强制释放
+ self.swap_in_free_stage_delay = 2 # 表示将swap in任务的内存延后N个stage强制释放
+
+ # 带宽设置
+ self.D2H_bandwidth = 64 / 2.5 * 1000
+ self.H2D_bandwidth = 64 / 2.5 * 1000
+
+ # 内存目标设置
+ # OOM场景: 降低到 device最大内存 - redundant_memory 内存目标
+ # 如果后续迭代中仍触发OOM swap, target_memory 将每步减少 adjust_memory 大小
+ # 非OOM场景: target_mode = True 指降低至 target_memory 内存目标
+ # target_mode = False 指仅降低 reduction_memory 内存目标
+ self.target_mode = False
+ self.reduction_memory = 3 * 1024 * 1024 * 1024 # 手动设置目标内存
+ self.target_memory = 40 * 1024 * 1024 * 1024 # 手动设置目标内存
+ self.tensor_size_filter = 20 * 1024 * 1024 # 设置tensor size的过滤, 小于20MB的不会被选为candidate
+
+ self.redundant_memory = 2 * 1024 * 1024 * 1024
+ self.size_coverage_weight = 2 # 以coverage weight为1, size比之的比例
+ self.adjust_memory = 300 * 1024 * 1024 # 自动化调整 redundant_memory
+ self.adjust_step_duration = 1 # 自动化调整duration time, 将得到的step duration乘以这个数值, 并与历史的取最小值
+ self.adjust_size_coverage_weight = 0 # size_coverage_weight 每次递增这个数值
+
+ def __str__(self):
+ return str(self.__dict__)
+
+
+swap_policy_config = SwapPolicyConfig()
diff --git a/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_utils.py b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ede692962cec51c09c67af6927a7e7693dd2837a
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/memory/smart_swap/swap_utils.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import time
+from enum import Enum
+
+from .swap_policy_config import swap_policy_config
+
+
+class PrintLevel(Enum):
+ DEBUG = 0
+ INFO = 1
+ NONE = 2
+
+
+def print_with_rank(message, prefix="", print_level=PrintLevel.DEBUG):
+ if swap_policy_config.print_level > print_level.value:
+ return
+
+ rank = swap_policy_config.rank
+ print_rank = swap_policy_config.print_rank
+ if print_rank == -1:
+ print(f"[{print_level.name}] rank[{rank}] [{prefix}]: {message}", flush=True)
+ else:
+ if rank == print_rank:
+ print(f"[{print_level.name}] rank[{rank}] [{prefix}]: {message}", flush=True)
+
+
+def timer(func):
+ def wrapper(*args, **kwargs):
+ start_time = time.time()
+ result = func(*args, **kwargs)
+ end_time = time.time()
+ print_with_rank(
+ f"Function {func.__name__} takes {end_time - start_time} seconds to execute.",
+ prefix="timer",
+ print_level=PrintLevel.INFO,
+ )
+ return result
+
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/mindspeed_parallel_group.py b/model/train/yoco_moe/mindspeed/core/mindspeed_parallel_group.py
new file mode 100644
index 0000000000000000000000000000000000000000..b60a1f91a425f46c683450ccf23a825b58bff94c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/mindspeed_parallel_group.py
@@ -0,0 +1,75 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import abc
+
+import torch
+import torch.distributed
+from mindspeed.core.simple_parallel_cfg import SimpleParallelCfg
+
+
+class MindspeedParallelGroup:
+ def __init__(
+ self,
+ parallel_cfg: SimpleParallelCfg = None,
+ pg_name: str = None,
+ overlap_gp_name: str = None,
+ nccl_comm_cfgs=None,
+ ):
+ """Parallel group interface, any type of parallelism class can implement this class.
+
+ :param parallel_cfg: Parallel configuration.
+ :param pg_name: parallel process group name.
+ :param overlap_gp_name: overlap process name, for the send/rcv parallel.
+ :param nccl_comm_cfgs:
+ """
+ self._pg_name = pg_name
+ self._overlap_pg_name = overlap_gp_name
+ self._group, self._global_ranks, self._overlap_group = self.init_group(
+ parallel_cfg, pg_name, overlap_gp_name, nccl_comm_cfgs
+ )
+
+ @staticmethod
+ @abc.abstractmethod
+ def init_group(
+ parallel_cfg: SimpleParallelCfg,
+ pg_name: str,
+ overlap_gp_name: str = None,
+ nccl_comm_cfgs=None,
+ ):
+ raise NotImplementedError
+
+ @property
+ def group(self):
+ return self._group
+
+ @property
+ def overlap_group(self):
+ return self._overlap_group
+
+ @property
+ def global_ranks(self):
+ return self._global_ranks
+
+ def get_parallel_rank(self):
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ return torch.distributed.get_rank(group=self.group)
+ else:
+ raise AssertionError("The distribution is not available or not initialized.")
+
+ def get_parallel_group_world_size(self):
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ return torch.distributed.get_world_size(group=self.group)
+ else:
+ return 0
diff --git a/model/train/yoco_moe/mindspeed/core/models/common/embeddings/language_model_embedding.py b/model/train/yoco_moe/mindspeed/core/models/common/embeddings/language_model_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..58aa7a99ac0f8e62a2016ed0eec89e1c1a845ac1
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/models/common/embeddings/language_model_embedding.py
@@ -0,0 +1,34 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import wraps
+from megatron.training import get_args
+from mindspeed.core.tensor_parallel.comm_autograd_function import (
+ auto_grad_scatter_along_first_dim_then_last_dim,
+)
+from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm
+from mindspeed.core.tensor_parallel.comm_group_api import TPYCollectiveComm
+
+
+def language_model_embedding_forward_wrapper(forward):
+ @wraps(forward)
+ def wrapper(self, *args, **kwargs):
+ encoder_input = forward(self, *args, **kwargs)
+ if get_args().tp_2d:
+ encoder_input = auto_grad_scatter_along_first_dim_then_last_dim(
+ encoder_input, TPXCollectiveComm, TPYCollectiveComm
+ )
+ return encoder_input
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/models/common/embeddings/rotary_pos_embedding.py b/model/train/yoco_moe/mindspeed/core/models/common/embeddings/rotary_pos_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee99257a2c751baa658d14644cdd438b6945cc5f
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/models/common/embeddings/rotary_pos_embedding.py
@@ -0,0 +1,386 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import math
+import torch
+from torch import Tensor
+from functools import wraps
+
+from megatron.core.models.common.embeddings.rotary_pos_embedding import _rotate_half
+from megatron.training import get_args
+from megatron.core import parallel_state
+from mindspeed.utils import get_position_ids, generate_rearrange_idx_tensor
+from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding
+
+from mindspeed.core.parallel_state import (get_context_parallel_for_hybrid_ulysses_world_size,
+ get_context_parallel_for_hybrid_ulysses_rank,
+ get_context_parallel_for_hybrid_ring_world_size,
+ get_context_parallel_for_hybrid_ring_rank)
+from mindspeed.core.context_parallel.utils import get_remapped_seq_order
+from mindspeed.core.tensor_parallel_y_union_cp import TensorParallelYUnionCP
+
+
+def yarn_find_correction_dim(
+ num_rotations, dim, base=10000, max_position_embeddings=2048
+):
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
+ 2 * math.log(base)
+ )
+
+
+def yarn_find_correction_range(
+ low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
+):
+ low = math.floor(
+ yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
+ )
+ high = math.ceil(
+ yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
+ )
+ return max(low, 0), min(high, dim - 1) # Clamp values just in case
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+ if scale <= 1:
+ return 1.0
+ return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def yarn_linear_ramp_mask(min_, max_, dim):
+ if min_ == max_:
+ max_ += 0.001 # Prevent singularity
+
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min_) / (max_ - min_)
+ ramp_func = torch.clamp(linear_func, 0, 1)
+ return ramp_func
+
+
+def apply_rotary_pos_emb_bshd(t: Tensor, freqs: Tensor, rotary_interleaved: bool = False) -> Tensor:
+ args = get_args()
+ _mscale = 1.0
+ if args.rope_scaling_type == "yarn":
+ _mscale = float(
+ yarn_get_mscale(args.rope_scaling_factor, args.rope_scaling_mscale)
+ / yarn_get_mscale(args.rope_scaling_factor, args.rope_scaling_mscale_all_dim)
+ )
+
+ rot_dim = freqs.shape[-1]
+ t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
+ cos_ = (torch.cos(freqs) * _mscale).to(t.dtype)
+ sin_ = (torch.sin(freqs) * _mscale).to(t.dtype)
+
+ if args.use_fused_rotary_pos_emb:
+ mode = 1 if rotary_interleaved else 0
+ t = npu_rotary_position_embedding(t.contiguous(), cos_, sin_, mode).to(t.dtype)
+ else:
+ t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_)
+
+ return torch.cat((t, t_pass), dim=-1)
+
+
+def apply_yarn_scaling(freqs: torch.Tensor):
+ args = get_args()
+
+ scaling_factor = args.rope_scaling_factor
+ dim = args.qk_rope_head_dim if args.multi_head_latent_attention else (args.hidden_size // args.num_attention_heads)
+ rotary_ratio = args.rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=freqs.device) / dim)
+ freq_extra = 1.0 / rotary_ratio
+ freq_inter = 1.0 / (scaling_factor * rotary_ratio)
+ low, high = yarn_find_correction_range(
+ args.rope_scaling_beta_fast,
+ args.rope_scaling_beta_slow,
+ dim,
+ args.rotary_base,
+ args.rope_scaling_original_max_position_embeddings,
+ )
+
+ inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
+ device=freqs.device, dtype=torch.float32
+ )
+
+ inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
+
+ return inv_freq
+
+
+def rotary_embedding_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ _args = get_args()
+ if _args.rotary_base and ("rotary_base" not in kwargs or kwargs["rotary_base"] == 10000): # default value
+ kwargs["rotary_base"] = _args.rotary_base
+ fn(self, *args, **kwargs)
+ if hasattr(_args, "rope_scaling_type") and _args.rope_scaling_type == "yarn":
+ self.inv_freq = apply_yarn_scaling(self.inv_freq)
+
+ return wrapper
+
+
+def rotary_forward(self, max_seq_len: int, offset: int = 0) -> Tensor:
+ """Forward pass of RoPE embedding.
+
+ Args:
+ max_seq_len (int): Maximum size of sequence
+ offset (int, optional): _description_. Defaults to 0.
+
+ Returns:
+ Tensor: Embeddings after applying RoPE.
+ """
+ if self.inv_freq.device.type == 'cpu':
+ # move `inv_freq` to GPU once at the first micro-batch forward pass
+ self.inv_freq = self.inv_freq.to(device=torch.cuda.current_device())
+ seq = (
+ torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ + offset
+ )
+
+ if self.seq_len_interpolation_factor is not None:
+ seq *= 1 / self.seq_len_interpolation_factor
+
+ freqs = torch.outer(seq, self.inv_freq)
+ # first part even vector components, second part odd vector components,
+ # 2 * dim in dimension size
+ if not self.rotary_interleaved:
+ emb = torch.cat((freqs, freqs), dim=-1)
+ else:
+ emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
+ freqs.shape[0], -1
+ )
+ # emb [seq_length, .., dim]
+ emb = emb[:, None, None, :]
+
+ return emb
+
+
+def apply_rotary_pos_emb_thd(
+ t: Tensor, cu_seqlens: Tensor, freqs: Tensor, rotary_interleaved: bool = False
+) -> Tensor:
+
+ """A baseline implementation of applying RoPE for `thd` format.
+
+ Args:
+ t (Tensor): Input tensor T is of shape [t, h, d]
+ cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`,
+ with shape [b + 1] and dtype torch.int32.
+ freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]
+
+ Returns:
+ Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
+ """
+ args = get_args()
+
+ position_ids = cu_seqlens.position_ids
+ block_size, bsz = position_ids.shape
+ freqs = freqs[position_ids.view(-1)].reshape(block_size, bsz, 1, -1)
+
+ return apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved)
+
+
+def get_pos_emb_on_this_cp_rank(pos_emb, seq_dim):
+ args = get_args()
+
+ cp_expanded_by_2d_tp = args.tp_y > 1
+ if args.context_parallel_algo == 'megatron_cp_algo':
+ if args.attention_mask_type == 'general':
+ pos_emb = _get_pos_emb_on_this_cp_rank_in_ulysses_cp(pos_emb, seq_dim)
+ elif cp_expanded_by_2d_tp:
+ pos_emb = _get_pos_emb_on_this_tp_y_cp_rank_in_megatron_cp(pos_emb, seq_dim)
+ elif args.reset_position_ids and args.attention_mask_type == 'causal':
+ return pos_emb
+ else:
+ pos_emb = _get_pos_emb_on_this_cp_rank_in_megatron_cp(pos_emb, seq_dim)
+ elif args.context_parallel_algo == 'ulysses_cp_algo':
+ if cp_expanded_by_2d_tp:
+ pos_emb = _get_pos_emb_on_this_tp_y_cp_rank_in_ulysses_cp(pos_emb, seq_dim)
+ else:
+ pos_emb = _get_pos_emb_on_this_cp_rank_in_ulysses_cp(pos_emb, seq_dim)
+ elif args.context_parallel_algo == 'hybrid_cp_algo':
+ if args.attention_mask_type == 'general':
+ pos_emb = _get_pos_emb_on_this_cp_rank_in_hybrid_cp_general(pos_emb, seq_dim)
+ else:
+ pos_emb = _get_pos_emb_on_this_cp_rank_in_hybrid_cp(pos_emb, seq_dim)
+ elif args.context_parallel_algo == 'adaptive_cp_algo':
+ pos_emb = _get_pos_emb_on_this_cp_rank_in_adaptive_cp(pos_emb, seq_dim)
+ elif args.context_parallel_algo == 'hybrid_adaptive_cp_algo':
+ pos_emb = _get_pos_emb_on_this_cp_rank_in_hybrid_adaptive_cp(pos_emb, seq_dim)
+ return pos_emb
+
+
+def _get_pos_emb_on_this_cp_rank_in_megatron_cp(pos_emb, seq_dim):
+ cp_size = parallel_state.get_context_parallel_world_size()
+ cp_rank = parallel_state.get_context_parallel_rank()
+ cp_idx = torch.tensor(
+ [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
+ ).cuda(non_blocking=True)
+ pos_emb = pos_emb.view(
+ *pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :]
+ )
+ pos_emb = pos_emb.index_select(seq_dim, cp_idx)
+ pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :])
+ return pos_emb
+
+
+def _get_pos_emb_on_this_tp_y_cp_rank_in_megatron_cp(pos_emb, seq_dim):
+ origin_pos_emb_shape = pos_emb.shape
+ tp_y_cp_group = TensorParallelYUnionCP()
+ tp_y_cp_size = tp_y_cp_group.get_parallel_group_world_size()
+ # [s, 1, 1, head_dim] ---> [2*tp_y_cp_size, s/(2*tp_y_cp_size), 1, 1, head_dim]
+ pos_emb = pos_emb.view(
+ *pos_emb.shape[:seq_dim], 2 * tp_y_cp_size, -1, *pos_emb.shape[(seq_dim + 1) :]
+ )
+ rearrange_idx_tensor = generate_rearrange_idx_tensor(tp_y_cp_size)
+
+ # Reorder pos embedding according dataset handling.
+ # selected res shape: [2 * tp_y_cp_size, s / (2 * tp_y_cp_size), 1, 1, head_dim]
+ pos_emb = pos_emb.index_select(seq_dim, index=rearrange_idx_tensor)
+ pos_emb = pos_emb.view(*origin_pos_emb_shape)
+ # viewed res shape: [tp_y_cp_sz, s/tp_y_cp_sz, 1, head_dim]
+ pos_emb = pos_emb.view(
+ *pos_emb.shape[0:seq_dim],
+ tp_y_cp_size,
+ pos_emb.shape[seq_dim] // tp_y_cp_size,
+ *pos_emb.shape[(seq_dim + 1):],
+ )
+ # cur_rank_pos_emb shape: [s/cp, 1, 1, head_dim]
+ tp_y_cp_rank = tp_y_cp_group.get_parallel_rank()
+ cur_rank_pos_emb = pos_emb[tp_y_cp_rank].squeeze(axis=0)
+ return cur_rank_pos_emb
+
+
+def _get_pos_emb_on_this_cp_rank_in_ulysses_cp(pos_emb, seq_dim):
+ cp_size = parallel_state.get_context_parallel_world_size()
+ cp_rank = parallel_state.get_context_parallel_rank()
+ pos_emb = pos_emb.chunk(cp_size, dim=seq_dim)[cp_rank]
+
+ return pos_emb
+
+
+def _get_pos_emb_on_this_cp_rank_in_hybrid_cp(pos_emb, seq_dim):
+ u_size = get_context_parallel_for_hybrid_ulysses_world_size()
+ r_size = get_context_parallel_for_hybrid_ring_world_size()
+ u_rank = get_context_parallel_for_hybrid_ulysses_rank()
+ r_rank = get_context_parallel_for_hybrid_ring_rank()
+
+ cp_idx = torch.tensor(
+ [r_rank, (2 * r_size - r_rank - 1)], device="cpu", pin_memory=True
+ ).cuda(non_blocking=True)
+ pos_emb = pos_emb.view(
+ *pos_emb.shape[:seq_dim], 2 * r_size, -1, *pos_emb.shape[(seq_dim + 1) :]
+ )
+ pos_emb = pos_emb.index_select(seq_dim, cp_idx)
+ pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :])
+
+ pos_emb = pos_emb.chunk(u_size, dim=seq_dim)[u_rank]
+
+ return pos_emb
+
+
+def _get_pos_emb_on_this_cp_rank_in_hybrid_cp_general(pos_emb, seq_dim):
+ u_size = get_context_parallel_for_hybrid_ulysses_world_size()
+ r_size = get_context_parallel_for_hybrid_ring_world_size()
+ u_rank = get_context_parallel_for_hybrid_ulysses_rank()
+ r_rank = get_context_parallel_for_hybrid_ring_rank()
+
+ pos_emb = pos_emb.chunk(r_size, dim=seq_dim)[r_rank]
+ pos_emb = pos_emb.chunk(u_size, dim=seq_dim)[u_rank]
+
+ return pos_emb
+
+
+def _get_pos_emb_on_this_cp_rank_in_adaptive_cp(pos_emd, seq_dim):
+ cp_size = parallel_state.get_context_parallel_world_size()
+ cp_rank = parallel_state.get_context_parallel_rank()
+
+ remapped_seq_order = get_remapped_seq_order()
+ if remapped_seq_order is not None:
+ per = pos_emd.shape[seq_dim] // cp_size
+ index = torch.tensor(remapped_seq_order[cp_rank * per:(cp_rank + 1) * per], dtype=torch.int,
+ device=pos_emd.device)
+ pos_emd = pos_emd.index_select(seq_dim, index)
+
+ return pos_emd
+
+
+def _get_pos_emb_on_this_cp_rank_in_hybrid_adaptive_cp(pos_emd, seq_dim):
+ ulys_size = get_context_parallel_for_hybrid_ulysses_world_size()
+ adap_size = get_context_parallel_for_hybrid_ring_world_size()
+ ulys_rank = get_context_parallel_for_hybrid_ulysses_rank()
+ adap_rank = get_context_parallel_for_hybrid_ring_rank()
+
+ remapped_seq_order = get_remapped_seq_order()
+ if remapped_seq_order is not None:
+ per = pos_emd.shape[seq_dim] // adap_size // ulys_size
+ which_per = adap_rank * ulys_size + ulys_rank
+ index = torch.tensor(remapped_seq_order[which_per * per:(which_per + 1) * per], dtype=torch.int,
+ device=pos_emd.device)
+ pos_emd = pos_emd.index_select(seq_dim, index)
+
+ return pos_emd
+
+
+def rotary_embedding_forward(self, max_seq_len: int, offset: int = 0) -> Tensor:
+ """Forward pass of RoPE embedding.
+
+ Args:
+ max_seq_len (int): Maximum size of sequence
+ offset (int, optional): _description_. Defaults to 0.
+
+ Returns:
+ Tensor: Embeddings after applying RoPE.
+ """
+ seq = (
+ torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ + offset
+ )
+
+ if self.seq_len_interpolation_factor is not None:
+ seq *= 1 / self.seq_len_interpolation_factor
+
+ freqs = torch.outer(seq, self.inv_freq)
+ # first part even vector components, second part odd vector components,
+ # 2 * dim in dimension size
+ if not self.rotary_interleaved:
+ emb = torch.cat((freqs, freqs), dim=-1)
+ else:
+ emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
+ freqs.shape[0], -1
+ )
+ # emb [seq_length, .., dim]
+ emb = emb[:, None, None, :]
+ global_args = get_args()
+ cp = global_args.context_parallel_size
+ if global_args.tp_2d:
+ tp_y_cp_sz = cp * global_args.tp_y
+ else:
+ tp_y_cp_sz = cp
+ if tp_y_cp_sz > 1:
+ # slice rotary_pos_emb along sequence dimension and select the parition of the current CP rank
+ emb = get_pos_emb_on_this_cp_rank(emb, 0)
+ return emb
+
+
+def rotary_embedding_forward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, max_seq_len: int, offset: int = 0):
+ return rotary_embedding_forward(self, max_seq_len, offset)
+
+ return wrapper
+
+
+def _get_pos_emb_on_this_tp_y_cp_rank_in_ulysses_cp(pos_emb, seq_dim):
+ tp_y_cp_group = TensorParallelYUnionCP()
+ tp_y_cp_size = tp_y_cp_group.get_parallel_group_world_size()
+
+ cp_rank = tp_y_cp_group.get_parallel_rank()
+ pos_emb = pos_emb.chunk(tp_y_cp_size, dim=seq_dim)[cp_rank]
+ return pos_emb
+
+
+def rotary_embedding_get_rotary_seq_len_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, inference_params, transformer, transformer_input, transformer_config,):
+ rotary_seq_len = fn(self, inference_params, transformer, transformer_input, transformer_config,)
+ global_args = get_args()
+ if global_args.tp_2d:
+ rotary_seq_len *= global_args.tp_x
+ return rotary_seq_len
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/models/gpt/gpt_layer_specs.py b/model/train/yoco_moe/mindspeed/core/models/gpt/gpt_layer_specs.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9f5f16cc6977fae161e17205f64349db0a19e22
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/models/gpt/gpt_layer_specs.py
@@ -0,0 +1,101 @@
+from functools import wraps
+from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
+from megatron.core.transformer.attention import SelfAttentionSubmodules
+from megatron.core.transformer.dot_product_attention import DotProductAttention
+from megatron.core.transformer.identity_op import IdentityOp
+from megatron.core.transformer.mlp import MLP, MLPSubmodules
+from megatron.training import get_args
+from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.custom_layers.transformer_engine import TENorm
+from megatron.core.transformer.spec_utils import ModuleSpec
+from mindspeed.core.transformer.transformer import norm_recompute_forward
+from mindspeed.core.transformer.transformer_block import NoopTransformerLayer
+from mindspeed.model.transformer import should_recompute_norm
+from mindspeed.core.transformer.moe.tp_2d.moe_layer_2d import MoELayer2D
+import types
+
+
+def get_gpt_layer_local_spec_wrapper(fn):
+ @wraps(fn)
+ def wrapper(num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False):
+ res = fn(num_experts, moe_grouped_gemm, qk_layernorm)
+ args = get_args()
+ if args.multi_head_latent_attention:
+ res.submodules.self_attention.submodules = SelfAttentionSubmodules(
+ linear_qkv=ColumnParallelLinear,
+ core_attention=DotProductAttention,
+ linear_proj=RowParallelLinear,
+ q_layernorm=TENorm if args.qk_layernorm else IdentityOp,
+ k_layernorm=TENorm if args.qk_layernorm else IdentityOp,
+ linear_qb=ColumnParallelLinear,
+ linear_kvb=ColumnParallelLinear
+ )
+ else:
+ if qk_layernorm:
+ res.submodules.self_attention.submodules.q_layernorm = TENorm
+ res.submodules.self_attention.submodules.k_layernorm = TENorm
+ res.submodules.input_layernorm = TENorm
+ res.submodules.pre_mlp_layernorm = TENorm
+ return res
+
+ return wrapper
+
+
+def build_layers_wrapper(fn, column_forward, row_forward):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ fn(self, *args, **kwargs)
+ for layer in self.layers:
+ if isinstance(getattr(layer, 'mlp', None), MoELayer):
+ for local_expert in layer.mlp.experts.local_experts:
+ local_expert.linear_fc1.forward = types.MethodType(column_forward, local_expert.linear_fc1)
+ local_expert.linear_fc2.forward = types.MethodType(row_forward, local_expert.linear_fc2)
+ return wrapper
+
+
+def build_norm_recompute_layer_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ fn(self, *args, **kwargs)
+ for layer in self.layers:
+ if isinstance(layer, NoopTransformerLayer):
+ continue
+ if should_recompute_norm(layer):
+ layer.forward = types.MethodType(norm_recompute_forward, layer)
+ return wrapper
+
+
+def get_mlp_module_spec_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ use_te, num_experts, moe_grouped_gemm = kwargs['use_te'], kwargs['num_experts'], kwargs['moe_grouped_gemm']
+ if num_experts is None:
+ # Dense MLP w/ or w/o TE modules.
+ return ModuleSpec(
+ module=MLP,
+ submodules=MLPSubmodules(
+ linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
+ linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
+ ),
+ )
+ else:
+ # Mixture of experts with modules in megatron core.
+ if use_te and moe_grouped_gemm:
+ linear_fc1 = TEColumnParallelGroupedLinear
+ linear_fc2 = TERowParallelGroupedLinear
+ else:
+ linear_fc1 = ColumnParallelLinear
+ linear_fc2 = RowParallelLinear
+
+ use_te_grouped_gemm = use_te and TEColumnParallelGroupedLinear is not None
+
+ return ModuleSpec(
+ module=MoELayer2D,
+ submodules=(
+ MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2)
+ if not moe_grouped_gemm or use_te_grouped_gemm
+ else None
+ ),
+ )
+
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/models/gpt/gpt_model.py b/model/train/yoco_moe/mindspeed/core/models/gpt/gpt_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..642d21c214dae1135dcd2db5dcab3a73c0c6aa42
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/models/gpt/gpt_model.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+from functools import wraps
+from torch import Tensor
+from megatron.core.packed_seq_params import PackedSeqParams
+from mindspeed.utils import get_actual_seq_len, compute_qkv_index, get_position_ids
+
+
+def gpt_forward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ actual_seq_len = get_actual_seq_len()
+
+ packed_seq_params = PackedSeqParams(
+ cu_seqlens_q=actual_seq_len,
+ cu_seqlens_kv=actual_seq_len
+ )
+
+ q_index, kv_index = compute_qkv_index(actual_seq_len.clone().tolist())
+ packed_seq_params.q_index = q_index
+ packed_seq_params.kv_index = kv_index
+ packed_seq_params.position_ids = get_position_ids()
+
+ kwargs['packed_seq_params'] = packed_seq_params
+ return fn(*args, **kwargs)
+
+ return wrapper
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/parallel_state.py b/model/train/yoco_moe/mindspeed/core/parallel_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..f83afae0ea761e4ccb29898a616e8a4f7a9b299f
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/parallel_state.py
@@ -0,0 +1,1223 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Expert parallel groups."""
+import os
+from functools import wraps
+from typing import Optional
+from datetime import timedelta
+
+import torch
+import torch_npu
+import megatron
+from megatron.core.parallel_state import get_context_parallel_world_size, get_nccl_options
+
+from mindspeed.core.simple_parallel_cfg import SimpleParallelCfg
+from mindspeed.core.tensor_parallel_y_union_cp import TensorParallelYUnionCP
+
+_CONTEXT_PARALLEL_GROUP_FOR_SEND_RECV_OVERLAP = None
+_CONTEXT_PARALLEL_GROUP_FOR_HYBRID_ULYSSES = None
+_CONTEXT_PARALLEL_GROUP_FOR_HYBRID_RING = None
+_PIPELINE_MODEL_PARALLEL_GROUP_FOR_NEW_STREAM = None
+
+_CONTEXT_PARALLEL_RANKS_FOR_HYBRID_ULYSSES = None
+_CONTEXT_PARALLEL_RANKS_FOR_HYBRID_RING = None
+
+_CONTEXT_PARALLEL_RANKS_FOR_RING_INTRA_WINDOW = None
+_CONTEXT_PARALLEL_RANKS_FOR_RING_INTER_WINDOW_KV = None
+_CONTEXT_PARALLEL_RANKS_FOR_RING_INTER_WINDOW_DKV = None
+_CONTEXT_PARALLEL_GROUP_FOR_RING_INTRA_WINDOW = None
+_CONTEXT_PARALLEL_GROUP_FOR_RING_INTRA_WINDOW_SEND_RECV_OVERLAP = None
+
+_TP_X_EP_GROUP = None
+_TP_X_EP_GROUP_WORLD_SIZE = None
+_TP_X_EP_GROUP_RANK = None
+_TP_X_PARALLEL_RING_RANKS = None
+_TP_Y_PARALLEL_RING_RANKS = None
+
+_TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1 = None
+_TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2 = None
+_TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM1 = None
+_TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM2 = None
+_TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM1 = None
+_TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM2 = None
+_TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM1 = None
+_TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM2 = None
+_TP_X_SD_RCV_OVERLAP_GROUP = None
+_TP_Y_SD_RCV_OVERLAP_GROUP = None
+_TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_RANK = None
+_TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_RANK = None
+_TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_WORLD_SIZE = None
+_TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_WORLD_SIZE = None
+
+_TENSOR_AND_CONTEXT_PARALLEL_GROUP = None
+_TENSOR_AND_CONTEXT_PARALLEL_GLOBAL_RANKS = None
+
+_HCCL_GROUP_BUFFER = None
+
+
+def parse_hccl_buffer_string(hccl_group_buffer):
+ global _HCCL_GROUP_BUFFER
+
+ if hccl_group_buffer == None:
+ return
+
+ allowed_keys = ["dp", "dp_cp", "cp", "mp", "mp_exp", "tp", "pp", "embd", "tp_dp_cp",
+ "tp_dp", "tp_cp", "tp_exp", "exp", "dp_modulo_exp", "pp_new_stream",
+ "cp2", "cp_ulysses", "cp_ring", "cp_ring_intra", "cp_ring_intra_overlap", "nd1_dim1", "ag_x_sd_rcv_overlap",
+ "nd1_dim2", "ag_y_sd_rcv_overlap", "nd2_dim1", "nd2_dim2"]
+
+ parts = hccl_group_buffer.split(';')
+ for part in parts:
+ key_value = part.split(':')
+ if len(key_value) == 2:
+ key = key_value[0].strip()
+ value_str = key_value[1].strip()
+ key = key.replace(' ', '')
+ value_str = value_str.replace(' ', '')
+ if key in allowed_keys:
+ try:
+ value = int(value_str)
+ if value <= 0:
+ raise RuntimeError(f"Value {value} must be greater than 0")
+ _HCCL_GROUP_BUFFER[key] = value
+ except ValueError:
+ raise RuntimeError(f"{value_str} is not a valid positive integer")
+ else:
+ raise RuntimeError(f"Key {key} is not allowed")
+ else:
+ raise RuntimeError("The str of hccl-group-buffer is not valid")
+
+
+def hccl_buffer_auto_adaptive():
+ import math
+ from megatron.training import get_args
+ args = get_args()
+
+ seq_length = args.seq_length
+ micro_batch_size = args.micro_batch_size
+ hidden_size = args.hidden_size
+
+ context_parallel_size = args.context_parallel_size
+ tensor_model_parallel_size = args.tensor_model_parallel_size
+ expert_model_parallel_size = args.expert_model_parallel_size
+
+ moe_router_topk = args.moe_router_topk
+ moe_token_dispatcher_type = args.moe_token_dispatcher_type
+
+ context_parallel_algo = args.context_parallel_algo
+ num_attention_heads = args.num_attention_heads
+ group_query_attention = args.group_query_attention
+
+ global _HCCL_GROUP_BUFFER
+ #The DP group, DP-CP group, and DP-EP group .Here, we take the default value of 200M.
+
+ #Calculation of the maximum communication volume of the TP group.
+ if moe_token_dispatcher_type is not None and moe_token_dispatcher_type == 'alltoall':
+ #No MOE + No SP, AllReduce MaxComm: S/CP * B * H * 2;No MOE + SP, AllGather MaxComm: S/CP * B * H
+ hccl_tp_buffer_size_mlp = 2 * math.ceil(seq_length / context_parallel_size * micro_batch_size * hidden_size / 1024 / 1024)
+ if args.sequence_parallel:
+ _HCCL_GROUP_BUFFER['tp'] = hccl_tp_buffer_size_mlp
+ else:
+ _HCCL_GROUP_BUFFER['tp'] = hccl_tp_buffer_size_mlp * 2
+ #MOE and AlltoAll MaxComm: (S/CP/TP * B * H * topK).
+ if args.hccl_ep_group_buffer_adaptive_factor > 0:
+ hccl_tp_buffer_size_moe = 2 * math.ceil(args.hccl_ep_group_buffer_adaptive_factor * seq_length / context_parallel_size / tensor_model_parallel_size * micro_batch_size * hidden_size / 1024 / 1024 * moe_router_topk)
+ else:
+ hccl_tp_buffer_size_moe = 200
+ _HCCL_GROUP_BUFFER['tp'] = max(hccl_tp_buffer_size_moe, _HCCL_GROUP_BUFFER['tp'])
+ else:
+ #MOE + SP, AllReduce MaxComm: S/CP * B * H * 2;No MOE + SP, AllGather MaxComm: S/CP * B * H
+ hccl_tp_buffer_size_mlp = 2 * math.ceil(seq_length / context_parallel_size * micro_batch_size * hidden_size / 1024 / 1024)
+ if args.sequence_parallel:
+ _HCCL_GROUP_BUFFER['tp'] = hccl_tp_buffer_size_mlp
+ else:
+ _HCCL_GROUP_BUFFER['tp'] = hccl_tp_buffer_size_mlp * 2
+
+ #Calculation of the maximum communication volume of the PP group.
+ #P2P MaxComm::S/CP/TP * B *H
+ if args.sequence_parallel:
+ hccl_pp_buffer_size = 2 * math.ceil(seq_length / context_parallel_size / tensor_model_parallel_size * micro_batch_size * hidden_size / 1024 / 1024)
+ else:
+ hccl_pp_buffer_size = 2 * math.ceil(seq_length / context_parallel_size * micro_batch_size * hidden_size / 1024 / 1024)
+ _HCCL_GROUP_BUFFER['pp'] = hccl_pp_buffer_size
+ _HCCL_GROUP_BUFFER['pp_new_stream'] = hccl_pp_buffer_size
+
+ #MP & MP-EXP groups for optimizer, based on num of zero gradients and max grad_norm. Just set a constant (default 10M).
+ #It won't be used after the distributed optimizer is enabled.
+ _HCCL_GROUP_BUFFER['mp'] = 10
+ _HCCL_GROUP_BUFFER['mp_exp'] = 10
+
+ #Calculation of the maximum communication volume of the EP group.
+ #Moe of alltoall, MaxComm:S/CP/TP * B * H * Topk
+ if args.hccl_ep_group_buffer_adaptive_factor > 0:
+ hccl_ep_buffer_size = 2 * math.ceil(seq_length / context_parallel_size / tensor_model_parallel_size * micro_batch_size * hidden_size / 1024 / 1024 * moe_router_topk)
+ else:
+ hccl_ep_buffer_size = 200
+ _HCCL_GROUP_BUFFER['exp'] = hccl_ep_buffer_size
+
+ #Calculation of the maximum communication volume of the EP-TP group.
+ #Moe of allgather, MaxComm:S/CP/TP * B * H * EP * TP
+ #Moe of alltoall + moe-tp-extend-ep , MaxComm:S/CP/TP * B * H * topK
+ if moe_token_dispatcher_type is not None and moe_token_dispatcher_type == 'allgather':
+ if args.hccl_ep_group_buffer_adaptive_factor > 0:
+ hccl_tp_ep_buffer_size = 2 * math.ceil(args.hccl_ep_group_buffer_adaptive_factor * seq_length / context_parallel_size * micro_batch_size * hidden_size * expert_model_parallel_size / 1024 / 1024)
+ else:
+ hccl_tp_ep_buffer_size = 200
+ _HCCL_GROUP_BUFFER['tp_exp'] = hccl_ep_buffer_size
+ elif moe_token_dispatcher_type is not None and moe_token_dispatcher_type == 'alltoall' and args.moe_tp_extend_ep:
+ if args.hccl_ep_group_buffer_adaptive_factor > 0:
+ hccl_tp_ep_buffer_size = 2 * math.ceil(args.hccl_ep_group_buffer_adaptive_factor * seq_length / context_parallel_size / tensor_model_parallel_size * micro_batch_size * hidden_size * moe_router_topk / 1024 / 1024)
+ else:
+ hccl_tp_ep_buffer_size = 200
+ _HCCL_GROUP_BUFFER['tp_exp'] = hccl_ep_buffer_size
+
+ #TP-CP group in 8.0 for seq count by experts & Router bal_loss. Small comm vol, set const (default 10M).
+ _HCCL_GROUP_BUFFER['tp_cp'] = 10
+
+ #Calculation of the maximum communication volume of the CP、CP2、CP_Ring、CP_Ulysess group.
+ #CP of RingAttention,SendRecv,MaxComm:S/CP * B * (H / headcount * GQA /TP ) * 2
+ #CP of Ulysess,All2All,MaxComm:S/CP * B * (H / TP)
+ #CP_ulysess & CP_ring like CP in max comm. CP2 is half of CP.
+ if context_parallel_algo == 'ulysses_cp_algo' or context_parallel_algo is None:
+ hccl_cp_buffer_size = 2 * math.ceil(seq_length / context_parallel_size * micro_batch_size * hidden_size / tensor_model_parallel_size / 1024 / 1024)
+ _HCCL_GROUP_BUFFER['cp'] = hccl_cp_buffer_size
+ elif context_parallel_algo == 'megatron_cp_algo' :
+ hccl_cp2_buffer_size = 2 * math.ceil(seq_length / context_parallel_size * micro_batch_size * hidden_size / num_attention_heads * group_query_attention / tensor_model_parallel_size / 1024 / 1024)
+ hccl_cp_buffer_size = 2 * 2 * math.ceil(seq_length / context_parallel_size * micro_batch_size * hidden_size / num_attention_heads * group_query_attention / tensor_model_parallel_size / 1024 / 1024)
+ if args.cp_window_size > 1:
+ if args.use_cp_send_recv_overlap:
+ _HCCL_GROUP_BUFFER['cp2'] = hccl_cp2_buffer_size
+ _HCCL_GROUP_BUFFER['cp'] = hccl_cp2_buffer_size
+ _HCCL_GROUP_BUFFER['cp_ring_intra'] = hccl_cp2_buffer_size
+ _HCCL_GROUP_BUFFER['cp_ring_intra_overlap'] = hccl_cp2_buffer_size
+ else:
+ _HCCL_GROUP_BUFFER['cp'] = hccl_cp_buffer_size
+ _HCCL_GROUP_BUFFER['cp_ring_intra'] = hccl_cp_buffer_size
+ else:
+ if args.use_cp_send_recv_overlap:
+ _HCCL_GROUP_BUFFER['cp2'] = hccl_cp2_buffer_size
+ _HCCL_GROUP_BUFFER['cp'] = hccl_cp2_buffer_size
+ else:
+ _HCCL_GROUP_BUFFER['cp'] = hccl_cp_buffer_size
+ elif context_parallel_algo == 'hybrid_cp_algo':
+ ulysses_context_parallel_size = args.ulysses_degree_in_cp
+ ring_context_parallel_size = context_parallel_size / ulysses_context_parallel_size
+ hccl_cp_ulysess_buffer_size = 2 * math.ceil(seq_length / ulysses_context_parallel_size * micro_batch_size * hidden_size / tensor_model_parallel_size / 1024 / 1024)
+ hccl_cp_ring_buffer_size = 2 * math.ceil(seq_length / ring_context_parallel_size * micro_batch_size * hidden_size / num_attention_heads * group_query_attention / tensor_model_parallel_size / 1024 / 1024)
+ if args.cp_window_size > 1:
+ if args.use_cp_send_recv_overlap:
+ _HCCL_GROUP_BUFFER['cp_ulysses'] = hccl_cp_ulysess_buffer_size
+ _HCCL_GROUP_BUFFER['cp_ring'] = hccl_cp_ring_buffer_size
+ _HCCL_GROUP_BUFFER['cp2'] = hccl_cp_ring_buffer_size
+ _HCCL_GROUP_BUFFER['cp_ring_intra'] = hccl_cp_ring_buffer_size
+ _HCCL_GROUP_BUFFER['cp_ring_intra_overlap'] = hccl_cp_ring_buffer_size
+ #The CP group is used to calculate losses. The traffic volume is very small and is given a fixed value of 10M.
+ _HCCL_GROUP_BUFFER['cp'] = 10
+ else:
+ _HCCL_GROUP_BUFFER['cp_ulysses'] = hccl_cp_ulysess_buffer_size
+ _HCCL_GROUP_BUFFER['cp_ring'] = hccl_cp_ring_buffer_size * 2
+ _HCCL_GROUP_BUFFER['cp_ring_intra'] = hccl_cp_ring_buffer_size * 2
+ #The CP group is used to calculate losses. The traffic volume is very small and is given a fixed value of 10M.
+ _HCCL_GROUP_BUFFER['cp'] = 10
+ else:
+ if args.use_cp_send_recv_overlap:
+ _HCCL_GROUP_BUFFER['cp_ulysses'] = hccl_cp_ulysess_buffer_size
+ _HCCL_GROUP_BUFFER['cp_ring'] = hccl_cp_ring_buffer_size
+ _HCCL_GROUP_BUFFER['cp2'] = hccl_cp_ring_buffer_size
+ #The CP group is used to calculate losses. The traffic volume is very small and is given a fixed value of 10M.
+ _HCCL_GROUP_BUFFER['cp'] = 10
+ else:
+ _HCCL_GROUP_BUFFER['cp_ulysses'] = hccl_cp_ulysess_buffer_size
+ _HCCL_GROUP_BUFFER['cp_ring'] = hccl_cp_ring_buffer_size * 2
+ #The CP group is used to calculate losses. The traffic volume is very small and is given a fixed value of 10M.
+ _HCCL_GROUP_BUFFER['cp'] = 10
+
+
+def get_nccl_options_wrapper(get_nccl_options):
+ @wraps(get_nccl_options)
+ def wrapper(pg_name, nccl_comm_cfgs):
+ from megatron.training import get_args
+ args = get_args()
+ if args.hccl_group_buffer is not None or args.hccl_group_buffer_adaptive:
+ global _HCCL_GROUP_BUFFER
+ if _HCCL_GROUP_BUFFER.get(pg_name) is not None:
+ options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
+ options.hccl_config = {"hccl_buffer_size":_HCCL_GROUP_BUFFER[pg_name]}
+ return options
+ return get_nccl_options(pg_name, nccl_comm_cfgs)
+ return wrapper
+
+
+def initialize_model_parallel_wrapper(initialize_model_parallel):
+ @wraps(initialize_model_parallel)
+ def wrapper(
+ tensor_model_parallel_size: int = 1,
+ pipeline_model_parallel_size: int = 1,
+ virtual_pipeline_model_parallel_size: Optional[int] = None,
+ pipeline_model_parallel_split_rank: Optional[int] = None,
+ use_sharp: bool = False,
+ context_parallel_size: int = 1,
+ expert_model_parallel_size: int = 1,
+ nccl_communicator_config_path: Optional[str] = None,
+ distributed_timeout_minutes: int = 30,
+ order: str = "tp-cp-ep-dp-pp",
+ ):
+ from megatron.training.utils import print_rank_0
+ from megatron.training import get_args
+ args = get_args()
+
+ global _HCCL_GROUP_BUFFER
+ _HCCL_GROUP_BUFFER = {}
+
+ if args.hccl_group_buffer_adaptive:
+ hccl_buffer_auto_adaptive()
+ print_rank_0(f"hccl_group_buffer_adaptive: {_HCCL_GROUP_BUFFER}")
+
+ if args.hccl_group_buffer is not None:
+ parse_hccl_buffer_string(args.hccl_group_buffer)
+
+ data_parallel_size = 1 # dp 1
+ rank = torch.distributed.get_rank()
+ all_ep_groups = []
+ if order == "tp-cp-ep-dp-pp":
+ # Megatron doesn't allow ep & cp combination, set ep to 1 to bypass that, ep related groups will be regenerated
+ initialize_model_parallel(
+ tensor_model_parallel_size,
+ pipeline_model_parallel_size,
+ virtual_pipeline_model_parallel_size,
+ pipeline_model_parallel_split_rank,
+ use_sharp,
+ context_parallel_size,
+ 1,
+ nccl_communicator_config_path,
+ distributed_timeout_minutes,
+ order
+ )
+
+ world_size: int = torch.distributed.get_world_size()
+ num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
+ num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
+ data_parallel_size: int = world_size // (
+ tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
+ )
+
+ if data_parallel_size * context_parallel_size % expert_model_parallel_size != 0:
+ raise RuntimeError(
+ f"data_parallel_size * context_parallel_size ({data_parallel_size * context_parallel_size}) is not "
+ f"divisible by expert_model_parallel_size "
+ )
+
+ nccl_comm_cfgs = {}
+ if nccl_communicator_config_path is not None:
+ import yaml
+
+ with open(nccl_communicator_config_path, "r") as stream:
+ nccl_comm_cfgs = yaml.safe_load(stream)
+
+ all_data_parallel_group_ranks = []
+ all_data_parallel_group_ranks_with_cp = []
+ for i in range(pipeline_model_parallel_size):
+ start_rank = i * num_pipeline_model_parallel_groups
+ end_rank = (i + 1) * num_pipeline_model_parallel_groups
+ for j in range(context_parallel_size * tensor_model_parallel_size):
+ ranks = range(
+ start_rank + j, end_rank, context_parallel_size * tensor_model_parallel_size
+ )
+ all_data_parallel_group_ranks.append(list(ranks))
+ for j in range(tensor_model_parallel_size):
+ ranks_with_cp = range(
+ start_rank + j, end_rank, tensor_model_parallel_size
+ )
+ all_data_parallel_group_ranks_with_cp.append(list(ranks_with_cp))
+
+ timeout = timedelta(minutes=distributed_timeout_minutes)
+
+ # # Regenerate ep related groups because ep is set to 1 in initialize_model_parallel func
+ rank_generator = megatron.core.parallel_state.RankGenerator(
+ tp=tensor_model_parallel_size,
+ ep=expert_model_parallel_size,
+ dp=data_parallel_size * context_parallel_size,
+ pp=pipeline_model_parallel_size,
+ cp=1,
+ order=order,
+ )
+ for ranks in rank_generator.get_ranks('tp-ep-pp', independent_ep=True):
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout,
+ pg_options=get_nccl_options('mp_exp', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ megatron.core.parallel_state._MODEL_AND_EXPERT_PARALLEL_GROUP = group
+
+ all_tensor_and_expert_group_ranks = []
+ for ranks in rank_generator.get_ranks('tp-ep', independent_ep=True):
+ all_tensor_and_expert_group_ranks.append(list(ranks))
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('tp_exp', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ megatron.core.parallel_state._TENSOR_AND_EXPERT_PARALLEL_GROUP = group
+
+ for ranks in rank_generator.get_ranks('ep', independent_ep=True):
+ all_ep_groups.append(list(ranks))
+ group = torch.distributed.new_group(
+ ranks, pg_options=get_nccl_options('exp', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ megatron.core.parallel_state._EXPERT_MODEL_PARALLEL_GROUP = group
+
+ all_dp_modulo_exp_group_ranks = []
+ for ranks in rank_generator.get_ranks('dp', independent_ep=True):
+ all_dp_modulo_exp_group_ranks.append(list(ranks))
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('dp_modulo_exp', nccl_comm_cfgs)
+ )
+ group_gloo = torch.distributed.new_group(ranks, backend="gloo")
+ if rank in ranks:
+ megatron.core.parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP = group
+ megatron.core.parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = group_gloo
+
+ for ranks in rank_generator.get_ranks('dp-cp', independent_ep=True):
+ # Lazy initialization of the group
+ if get_context_parallel_world_size() > 1:
+ group = torch.distributed.new_group(
+ ranks,
+ timeout=timeout,
+ pg_options=get_nccl_options('dp_modulo_exp_cp', nccl_comm_cfgs),
+ )
+ group_gloo = torch.distributed.new_group(ranks, backend="gloo")
+ else:
+ group = megatron.core.parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP
+ group_gloo = megatron.core.parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
+ if rank in ranks:
+ megatron.core.parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = group
+ megatron.core.parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = group_gloo
+
+ all_tp_groups = []
+ for i in range(num_tensor_model_parallel_groups):
+ ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
+ all_tp_groups.append(list(ranks))
+
+ print_rank_0(f"all tp gourps {all_tp_groups}")
+ print_rank_0(f"all ep groups {all_ep_groups}")
+ print_rank_0(f"all dp groups {all_data_parallel_group_ranks}")
+ print_rank_0(f"all_dp_modulo_exp_group_ranks {all_dp_modulo_exp_group_ranks}")
+ print_rank_0(f"all_tensor_and_expert_group_ranks {all_tensor_and_expert_group_ranks}")
+ print_rank_0(f"all_data_parallel_group_ranks_with_cp {all_data_parallel_group_ranks_with_cp}")
+
+ else:
+ initialize_model_parallel(
+ tensor_model_parallel_size,
+ pipeline_model_parallel_size,
+ virtual_pipeline_model_parallel_size,
+ pipeline_model_parallel_split_rank,
+ use_sharp,
+ context_parallel_size,
+ expert_model_parallel_size,
+ nccl_communicator_config_path,
+ distributed_timeout_minutes,
+ order
+ )
+
+ initialize_context_parallel_group_for_send_recv_overlap(
+ tensor_model_parallel_size,
+ pipeline_model_parallel_size,
+ context_parallel_size,
+ nccl_comm_cfgs
+ )
+
+ initialize_context_parallel_group_for_hybrid_cp(
+ tensor_model_parallel_size,
+ pipeline_model_parallel_size,
+ context_parallel_size,
+ nccl_comm_cfgs
+ )
+
+ initialize_context_parallel_group_for_double_ring(
+ tensor_model_parallel_size,
+ pipeline_model_parallel_size,
+ context_parallel_size,
+ nccl_comm_cfgs
+ )
+
+ global _PIPELINE_MODEL_PARALLEL_GROUP_FOR_NEW_STREAM
+ if _PIPELINE_MODEL_PARALLEL_GROUP_FOR_NEW_STREAM is not None:
+ raise AttributeError('Pipeline parallel group for new stream is already initialized')
+ num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
+ for i in range(num_pipeline_model_parallel_groups):
+ ranks = range(i, world_size, num_pipeline_model_parallel_groups)
+ group = torch.distributed.new_group(
+ ranks, pg_options=megatron.core.parallel_state.get_nccl_options('pp_new_stream', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ _PIPELINE_MODEL_PARALLEL_GROUP_FOR_NEW_STREAM = group
+
+ from megatron.training import get_args
+ args = get_args()
+ nd1_dim1_sz = args.nd1_dim1_size if args.use_nd_matmul else args.tp_x
+ nd2_dim1_sz = args.nd2_dim1_size if args.use_nd_matmul else args.tp_y
+ tp_x_groups = initialize_ndmm_parallel_group(
+ nccl_comm_cfgs,
+ tensor_model_parallel_size=tensor_model_parallel_size,
+ nd1_dim1_size=nd1_dim1_sz,
+ nd2_dim1_size=nd2_dim1_sz,
+ )
+
+ if args.tp_2d:
+ from mindspeed.core.tensor_parallel_x_union_cp import TensorParallelXUnionCP
+
+ tp_y_cp_group = TensorParallelYUnionCP(
+ parallel_cfg=SimpleParallelCfg(
+ dp=data_parallel_size,
+ pp=pipeline_model_parallel_size,
+ tp=tensor_model_parallel_size,
+ cp=context_parallel_size,
+ ep=expert_model_parallel_size,
+ tp_x=get_args().tp_x,
+ tp_y=get_args().tp_y,
+ ),
+ pg_name="tp-y-cp",
+ overlap_gp_name="tp-y-cp-overlap",
+ nccl_comm_cfgs=nccl_comm_cfgs
+ )
+ print(f'tp_y_cp_group.global_ranks={tp_y_cp_group.global_ranks} for rank {rank}')
+
+ tp_x_cp_group = TensorParallelXUnionCP(
+ parallel_cfg=SimpleParallelCfg(
+ dp=data_parallel_size,
+ pp=pipeline_model_parallel_size,
+ tp=tensor_model_parallel_size,
+ cp=context_parallel_size,
+ ep=expert_model_parallel_size,
+ tp_x=get_args().tp_x,
+ tp_y=get_args().tp_y,
+ ),
+ pg_name="tp-x-cp",
+ overlap_gp_name=None,
+ nccl_comm_cfgs=nccl_comm_cfgs
+ )
+ print(f'tp_x_cp_group.global_ranks={tp_x_cp_group.global_ranks} for rank {rank}')
+
+ if expert_model_parallel_size > 1:
+ all_tp_x_ep_groups = set()
+ print(f'all_ep_groups={all_ep_groups}')
+ for tp_x_ranks in tp_x_groups:
+ tp_x_ep_ranks_set = set()
+ for ep_ranks in all_ep_groups:
+ tp_x_ranks_set = set(tp_x_ranks)
+ ep_ranks_set = set(ep_ranks)
+ if not tp_x_ranks_set.intersection(ep_ranks_set):
+ continue
+
+ cur_tp_x_ep_ranks_set = tp_x_ranks_set.union(ep_ranks_set)
+ tp_x_ep_ranks_set = tp_x_ep_ranks_set.union(cur_tp_x_ep_ranks_set)
+
+ all_tp_x_ep_groups.add(tuple(sorted(list(tp_x_ep_ranks_set))))
+
+ print(f'{all_tp_x_ep_groups=}')
+ all_tp_x_ep_groups = [tp_x_ep_ranks for tp_x_ep_ranks in all_tp_x_ep_groups]
+ timeout = timedelta(minutes=distributed_timeout_minutes)
+
+ global _TP_X_EP_GROUP
+ for tp_x_ep_ranks in all_tp_x_ep_groups:
+ group = torch.distributed.new_group(
+ tp_x_ep_ranks, timeout=timeout,
+ pg_options=get_nccl_options('tp_x_ep', nccl_comm_cfgs)
+ )
+ if rank in tp_x_ep_ranks:
+ _TP_X_EP_GROUP = group
+
+ print(f'{all_tp_x_ep_groups=}')
+
+ return wrapper
+
+
+def get_ring_group_for_intra_window():
+ global _CONTEXT_PARALLEL_GROUP_FOR_RING_INTRA_WINDOW
+ return _CONTEXT_PARALLEL_GROUP_FOR_RING_INTRA_WINDOW
+
+
+def get_ring_group_for_intra_window_send_recv_overlap():
+ global _CONTEXT_PARALLEL_GROUP_FOR_RING_INTRA_WINDOW_SEND_RECV_OVERLAP
+ return _CONTEXT_PARALLEL_GROUP_FOR_RING_INTRA_WINDOW_SEND_RECV_OVERLAP
+
+
+def get_ring_ranks_for_intra_window():
+ global _CONTEXT_PARALLEL_RANKS_FOR_RING_INTRA_WINDOW
+ assert _CONTEXT_PARALLEL_RANKS_FOR_RING_INTRA_WINDOW is not None
+ return _CONTEXT_PARALLEL_RANKS_FOR_RING_INTRA_WINDOW
+
+
+def get_ring_ranks_for_inter_window_kv():
+ global _CONTEXT_PARALLEL_RANKS_FOR_RING_INTER_WINDOW_KV
+ assert _CONTEXT_PARALLEL_RANKS_FOR_RING_INTER_WINDOW_KV is not None
+ return _CONTEXT_PARALLEL_RANKS_FOR_RING_INTER_WINDOW_KV
+
+
+def get_ring_ranks_for_inter_window_dkv():
+ global _CONTEXT_PARALLEL_RANKS_FOR_RING_INTER_WINDOW_DKV
+ assert _CONTEXT_PARALLEL_RANKS_FOR_RING_INTER_WINDOW_DKV is not None
+ return _CONTEXT_PARALLEL_RANKS_FOR_RING_INTER_WINDOW_DKV
+
+
+def initialize_context_parallel_group_for_send_recv_overlap(
+ tensor_model_parallel_size,
+ pipeline_model_parallel_size,
+ context_parallel_size,
+ nccl_comm_cfgs
+):
+ from megatron.training import get_args
+ if not get_args().use_cp_send_recv_overlap:
+ return
+ # when tp_y > 1, use TensorParallelYUnionCP
+ if get_args().tp_2d and get_args().tp_y > 1:
+ return
+ rank = torch.distributed.get_rank()
+ world_size: int = torch.distributed.get_world_size()
+ num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
+ data_parallel_size: int = world_size // (
+ tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
+ )
+ global _CONTEXT_PARALLEL_GROUP_FOR_SEND_RECV_OVERLAP
+ for i in range(pipeline_model_parallel_size):
+ for j in range(data_parallel_size):
+ start_rank = (
+ i * num_pipeline_model_parallel_groups
+ + j * tensor_model_parallel_size * context_parallel_size
+ )
+ end_rank = (
+ i * num_pipeline_model_parallel_groups
+ + (j + 1) * tensor_model_parallel_size * context_parallel_size
+ )
+ for k in range(tensor_model_parallel_size):
+ ranks = range(start_rank + k, end_rank, tensor_model_parallel_size)
+ group_send_recv_overlap = torch.distributed.new_group(
+ ranks, pg_options=megatron.core.parallel_state.get_nccl_options('cp2', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ _CONTEXT_PARALLEL_GROUP_FOR_SEND_RECV_OVERLAP = group_send_recv_overlap
+
+
+def initialize_context_parallel_group_for_hybrid_cp(
+ tensor_model_parallel_size,
+ pipeline_model_parallel_size,
+ context_parallel_size,
+ nccl_comm_cfgs
+):
+ from megatron.training import get_args
+ if (not hasattr(get_args(), 'context_parallel_algo') or
+ (
+ get_args().context_parallel_algo != 'hybrid_cp_algo' and get_args().context_parallel_algo != 'hybrid_adaptive_cp_algo')):
+ return
+
+ rank = torch.distributed.get_rank()
+ world_size: int = torch.distributed.get_world_size()
+ num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
+ data_parallel_size: int = world_size // (
+ tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
+ )
+
+ ulysses_degree = get_args().ulysses_degree_in_cp
+ assert (context_parallel_size > ulysses_degree and context_parallel_size % ulysses_degree == 0)
+ ring_degree = context_parallel_size // ulysses_degree
+
+ global _CONTEXT_PARALLEL_GROUP_FOR_HYBRID_ULYSSES
+ global _CONTEXT_PARALLEL_RANKS_FOR_HYBRID_ULYSSES
+ global _CONTEXT_PARALLEL_GROUP_FOR_HYBRID_RING
+ global _CONTEXT_PARALLEL_RANKS_FOR_HYBRID_RING
+ for i in range(pipeline_model_parallel_size):
+ for j in range(data_parallel_size):
+ start_rank = (
+ i * num_pipeline_model_parallel_groups
+ + j * tensor_model_parallel_size * context_parallel_size
+ )
+ end_rank = (
+ i * num_pipeline_model_parallel_groups
+ + (j + 1) * tensor_model_parallel_size * context_parallel_size
+ )
+ for k in range(tensor_model_parallel_size):
+ # cp ranks
+ ranks = list(range(start_rank + k, end_rank, tensor_model_parallel_size))
+ # ulysses cp ranks.
+ # Ulysses need higher communication bandwidth than Ring.
+ # Try to put Ulysses ranks in the same node.
+ for m in range(ring_degree):
+ ulysses_ranks = [ranks[idx] for idx in range(m * ulysses_degree, (m + 1) * ulysses_degree)]
+ ulysses_group = torch.distributed.new_group(
+ ulysses_ranks,
+ pg_options=megatron.core.parallel_state.get_nccl_options('cp_ulysses', nccl_comm_cfgs)
+ )
+ if rank in ulysses_ranks:
+ _CONTEXT_PARALLEL_GROUP_FOR_HYBRID_ULYSSES = ulysses_group
+ _CONTEXT_PARALLEL_RANKS_FOR_HYBRID_ULYSSES = ulysses_ranks
+
+ # ring cp ranks
+ for m in range(ulysses_degree):
+ ring_ranks = [ranks[idx] for idx in range(m, len(ranks), ulysses_degree)]
+ ring_group = torch.distributed.new_group(
+ ring_ranks, pg_options=megatron.core.parallel_state.get_nccl_options('cp_ring', nccl_comm_cfgs)
+ )
+ if rank in ring_ranks:
+ _CONTEXT_PARALLEL_GROUP_FOR_HYBRID_RING = ring_group
+ _CONTEXT_PARALLEL_RANKS_FOR_HYBRID_RING = ring_ranks
+
+
+def initialize_context_parallel_group_for_double_ring(
+ tensor_model_parallel_size,
+ pipeline_model_parallel_size,
+ context_parallel_size,
+ nccl_comm_cfgs,
+):
+ from megatron.training import get_args
+ import megatron.core.parallel_state as ps
+ args = get_args()
+ if args.tp_2d:
+ return
+ if context_parallel_size == 1 or args.context_parallel_algo not in ['megatron_cp_algo', 'hybrid_cp_algo']:
+ return
+
+ use_hybrid_cp = args.context_parallel_algo == 'hybrid_cp_algo' and args.ulysses_degree_in_cp > 1
+
+ rank = torch.distributed.get_rank()
+ world_size: int = torch.distributed.get_world_size()
+ num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
+ data_parallel_size: int = world_size // (
+ tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
+ )
+
+ def _initialize_helper(
+ rank,
+ ring_global_ranks,
+ window_size
+ ):
+ from megatron.training import get_args
+ global _CONTEXT_PARALLEL_RANKS_FOR_RING_INTRA_WINDOW
+ global _CONTEXT_PARALLEL_RANKS_FOR_RING_INTER_WINDOW_KV
+ global _CONTEXT_PARALLEL_RANKS_FOR_RING_INTER_WINDOW_DKV
+ global _CONTEXT_PARALLEL_GROUP_FOR_RING_INTRA_WINDOW
+ global _CONTEXT_PARALLEL_GROUP_FOR_RING_INTRA_WINDOW_SEND_RECV_OVERLAP
+
+ ring_size = len(ring_global_ranks)
+ inter_size = ring_size // window_size
+ for wid in range(inter_size):
+ intra_ranks = [ring_global_ranks[idx] for idx in range(wid * window_size, (wid + 1) * window_size)]
+ intra_group = torch.distributed.new_group(intra_ranks, pg_options=ps.get_nccl_options('cp_ring_intra', nccl_comm_cfgs))
+ intra_group_for_send_recv_overlap = None
+ if args.use_cp_send_recv_overlap:
+ intra_group_for_send_recv_overlap = torch.distributed.new_group(intra_ranks, pg_options=ps.get_nccl_options('cp_ring_intra_overlap', nccl_comm_cfgs))
+
+ if rank in intra_ranks:
+ _CONTEXT_PARALLEL_RANKS_FOR_RING_INTRA_WINDOW = intra_ranks
+ _CONTEXT_PARALLEL_GROUP_FOR_RING_INTRA_WINDOW = intra_group
+ _CONTEXT_PARALLEL_GROUP_FOR_RING_INTRA_WINDOW_SEND_RECV_OVERLAP = intra_group_for_send_recv_overlap
+
+ for inner_id in range(window_size):
+ inter_ranks = [ring_global_ranks[idx] for idx in range(inner_id, ring_size, window_size)]
+ if rank in inter_ranks:
+ _CONTEXT_PARALLEL_RANKS_FOR_RING_INTER_WINDOW_KV = inter_ranks
+ break
+
+ for inner_id in range(window_size):
+ inter_dkv_ranks = []
+ cur_rank = ring_global_ranks[inner_id]
+ cur_idx = inner_id
+ cur_window = 0
+ while cur_rank not in inter_dkv_ranks:
+ inter_dkv_ranks.append(cur_rank)
+ cur_window = (cur_window + 1) % inter_size
+ window_start = cur_window * window_size
+ cur_idx = window_start + (cur_idx + 1) % window_size
+ cur_rank = ring_global_ranks[cur_idx]
+
+ if rank in inter_dkv_ranks:
+ _CONTEXT_PARALLEL_RANKS_FOR_RING_INTER_WINDOW_DKV = inter_dkv_ranks
+ break
+
+
+ for i in range(pipeline_model_parallel_size):
+ for j in range(data_parallel_size):
+ start_rank = (
+ i * num_pipeline_model_parallel_groups
+ + j * tensor_model_parallel_size * context_parallel_size
+ )
+ end_rank = (
+ i * num_pipeline_model_parallel_groups
+ + (j + 1) * tensor_model_parallel_size * context_parallel_size
+ )
+ for k in range(tensor_model_parallel_size):
+ cp_ranks = range(start_rank + k, end_rank, tensor_model_parallel_size)
+
+ if use_hybrid_cp:
+ ulysses_degree = get_args().ulysses_degree_in_cp
+ assert (context_parallel_size > ulysses_degree and context_parallel_size % ulysses_degree == 0)
+ # ring cp ranks
+ for m in range(ulysses_degree):
+ ring_ranks = [cp_ranks[idx] for idx in range(m, len(cp_ranks), ulysses_degree)]
+
+ _initialize_helper(rank, ring_ranks, args.cp_window_size)
+ else:
+ _initialize_helper(rank, cp_ranks, args.cp_window_size)
+
+
+def get_context_parallel_group_for_send_recv_overlap(check_initialized=True):
+ """Get the context parallel group for send-recv overlap the caller rank belongs to."""
+ if check_initialized:
+ assert (
+ _CONTEXT_PARALLEL_GROUP_FOR_SEND_RECV_OVERLAP is not None
+ ), 'context parallel group for send-recv overlap is not initialized'
+ return _CONTEXT_PARALLEL_GROUP_FOR_SEND_RECV_OVERLAP
+
+
+def get_context_parallel_next_rank():
+ """Return the global rank that follows the caller in the context parallel"""
+ import megatron.core.parallel_state as ps
+ assert ps._CONTEXT_PARALLEL_GLOBAL_RANKS is not None, "Context parallel group is not initialized"
+ rank_in_context = ps.get_context_parallel_rank()
+ world_size = ps.get_context_parallel_world_size()
+ return ps._CONTEXT_PARALLEL_GLOBAL_RANKS[(rank_in_context + 1) % world_size]
+
+
+def get_context_parallel_prev_rank():
+ """Return the global rank that preceeds the caller in the context parallel"""
+ import megatron.core.parallel_state as ps
+ assert ps._CONTEXT_PARALLEL_GLOBAL_RANKS is not None, "Context parallel group is not initialized"
+ rank_in_context = ps.get_context_parallel_rank()
+ world_size = ps.get_context_parallel_world_size()
+ return ps._CONTEXT_PARALLEL_GLOBAL_RANKS[(rank_in_context - 1) % world_size]
+
+
+def get_pipeline_parallel_group_for_new_stream():
+ if _PIPELINE_MODEL_PARALLEL_GROUP_FOR_NEW_STREAM is None:
+ raise AttributeError('Pipeline parallel group of backward is not initialized')
+ return _PIPELINE_MODEL_PARALLEL_GROUP_FOR_NEW_STREAM
+
+
+def get_context_parallel_group_for_hybrid_ulysses(check_initialized=True):
+ """Get the context parallel group for hybrid ulysses the caller rank belongs to."""
+ if check_initialized:
+ assert (
+ _CONTEXT_PARALLEL_GROUP_FOR_HYBRID_ULYSSES is not None
+ ), 'context parallel group for hybrid ulysses is not initialized'
+ return _CONTEXT_PARALLEL_GROUP_FOR_HYBRID_ULYSSES
+
+
+def get_context_parallel_for_hybrid_ulysses_world_size():
+ return torch.distributed.get_world_size(group=get_context_parallel_group_for_hybrid_ulysses())
+
+
+def get_context_parallel_for_hybrid_ulysses_rank():
+ return torch.distributed.get_rank(group=get_context_parallel_group_for_hybrid_ulysses())
+
+
+def get_context_parallel_group_for_hybrid_ring(check_initialized=True):
+ """Get the context parallel group for hybrid ring the caller rank belongs to."""
+ if check_initialized:
+ assert (
+ _CONTEXT_PARALLEL_GROUP_FOR_HYBRID_RING is not None
+ ), 'context parallel group for hybrid ring is not initialized'
+ return _CONTEXT_PARALLEL_GROUP_FOR_HYBRID_RING
+
+
+def get_context_parallel_for_hybrid_ring_world_size():
+ return torch.distributed.get_world_size(group=get_context_parallel_group_for_hybrid_ring())
+
+
+def get_context_parallel_for_hybrid_ring_rank():
+ return torch.distributed.get_rank(group=get_context_parallel_group_for_hybrid_ring())
+
+
+def get_context_parallel_for_hybrid_ring_global_ranks():
+ assert (_CONTEXT_PARALLEL_GROUP_FOR_HYBRID_RING is not None
+ ), 'context parallel group for hybrid ring is not initialized'
+ global _CONTEXT_PARALLEL_RANKS_FOR_HYBRID_RING
+ return _CONTEXT_PARALLEL_RANKS_FOR_HYBRID_RING
+
+
+def get_tp_x_ring_global_ranks():
+ global _TP_X_PARALLEL_RING_RANKS
+ assert (_TP_X_PARALLEL_RING_RANKS is not None), 'TP-X parallel group for ring is not initialized'
+ return _TP_X_PARALLEL_RING_RANKS
+
+
+def get_tp_y_ring_global_ranks():
+ global _TP_Y_PARALLEL_RING_RANKS
+ assert (_TP_Y_PARALLEL_RING_RANKS is not None), 'TP-Y parallel group for ring is not initialized'
+ return _TP_Y_PARALLEL_RING_RANKS
+
+
+def destroy_model_parallel_wrapper(destroy_model_parallel):
+ @wraps(destroy_model_parallel)
+ def wrapper():
+ destroy_model_parallel()
+
+ global _CONTEXT_PARALLEL_GROUP_FOR_SEND_RECV_OVERLAP
+ global _PIPELINE_MODEL_PARALLEL_GROUP_FOR_NEW_STREAM
+ global _CONTEXT_PARALLEL_GROUP_FOR_HYBRID_RING
+ global _CONTEXT_PARALLEL_GROUP_FOR_HYBRID_ULYSSES
+ global _CONTEXT_PARALLEL_RANKS_FOR_HYBRID_RING
+ global _CONTEXT_PARALLEL_RANKS_FOR_HYBRID_ULYSSES
+ global _TP_X_PARALLEL_RING_RANKS
+ global _TP_Y_PARALLEL_RING_RANKS
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1
+ global _TP_X_SD_RCV_OVERLAP_GROUP
+ global _TP_Y_SD_RCV_OVERLAP_GROUP
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_RANK
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_RANK
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_WORLD_SIZE
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_WORLD_SIZE
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM1
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM2
+ global _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM1
+ global _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM2
+ global _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM1
+ global _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM2
+ global _TENSOR_AND_CONTEXT_PARALLEL_GROUP
+ global _TENSOR_AND_CONTEXT_PARALLEL_GLOBAL_RANKS
+ _CONTEXT_PARALLEL_GROUP_FOR_SEND_RECV_OVERLAP = None
+ _PIPELINE_MODEL_PARALLEL_GROUP_FOR_NEW_STREAM = None
+ _CONTEXT_PARALLEL_GROUP_FOR_HYBRID_RING = None
+ _CONTEXT_PARALLEL_GROUP_FOR_HYBRID_ULYSSES = None
+ _CONTEXT_PARALLEL_RANKS_FOR_HYBRID_RING = None
+ _CONTEXT_PARALLEL_RANKS_FOR_HYBRID_ULYSSES = None
+ _TENSOR_AND_CONTEXT_PARALLEL_GROUP = None
+ _TENSOR_AND_CONTEXT_PARALLEL_GLOBAL_RANKS = None
+ _TP_X_PARALLEL_RING_RANKS = None
+ _TP_Y_PARALLEL_RING_RANKS = None
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1 = None
+ _TP_X_SD_RCV_OVERLAP_GROUP = None
+ _TP_Y_SD_RCV_OVERLAP_GROUP = None
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2 = None
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_RANK = None
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_RANK = None
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_WORLD_SIZE = None
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_WORLD_SIZE = None
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM1 = None
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM2 = None
+ _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM1 = None
+ _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM2 = None
+ _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM1 = None
+ _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM2 = None
+
+ return wrapper
+
+
+def get_tensor_model_parallel_group_for_nd1_dim1(check_initialized=True):
+ if check_initialized and _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1 is None:
+ raise AssertionError('tensor model parallel group for nd1 dim1 is not initialized')
+ return _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1
+
+
+def get_tp_x_sd_rcv_overlap_group(check_initialized=True):
+ if check_initialized and _TP_X_SD_RCV_OVERLAP_GROUP is None:
+ raise AssertionError('tp-x send recv overlap group is not initialized')
+ return _TP_X_SD_RCV_OVERLAP_GROUP
+
+
+def get_tp_y_sd_rcv_overlap_group(check_initialized=True):
+ if check_initialized and _TP_Y_SD_RCV_OVERLAP_GROUP is None:
+ raise AssertionError('tp-y send recv overlap group is not initialized')
+ return _TP_Y_SD_RCV_OVERLAP_GROUP
+
+
+def get_tensor_model_parallel_group_for_nd1_dim2(check_initialized=True):
+ if check_initialized and _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2 is None:
+ raise AssertionError('tensor model parallel group for nd1 dim2 is not initialized')
+ return _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2
+
+
+def get_tp_x_ep_group(check_initialized=True):
+ if check_initialized and _TP_X_EP_GROUP is None:
+ return get_tensor_model_parallel_group_for_nd1_dim1()
+ return _TP_X_EP_GROUP
+
+
+def get_tp_x_ep_group_world_size():
+ global _TP_X_EP_GROUP_WORLD_SIZE
+ if _TP_X_EP_GROUP_WORLD_SIZE is None:
+ _TP_X_EP_GROUP_WORLD_SIZE = torch.distributed.get_world_size(group=get_tp_x_ep_group())
+
+ return _TP_X_EP_GROUP_WORLD_SIZE
+
+
+def get_tp_x_ep_group_rank():
+ global _TP_X_EP_GROUP_RANK
+ if _TP_X_EP_GROUP_RANK is None:
+ _TP_X_EP_GROUP_RANK = torch.distributed.get_rank(
+ group=get_tp_x_ep_group())
+
+ return _TP_X_EP_GROUP_RANK
+
+
+def get_tensor_model_parallel_group_for_nd2_dim1(check_initialized=True):
+ if check_initialized and _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM1 is None:
+ raise AssertionError('tensor model parallel group for nd2 dim1 is not initialized')
+ return _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM1
+
+
+def get_tensor_model_parallel_group_for_nd1_dim1_rank():
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_RANK
+ if _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_RANK is None:
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_RANK = torch.distributed.get_rank(
+ group=get_tensor_model_parallel_group_for_nd1_dim1())
+
+ return _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_RANK
+
+
+def get_tensor_model_parallel_group_for_nd1_dim2_rank():
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_RANK
+ if _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_RANK is None:
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_RANK = torch.distributed.get_rank(
+ group=get_tensor_model_parallel_group_for_nd1_dim2())
+
+ return _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_RANK
+
+
+def get_tensor_model_parallel_group_for_nd1_dim1_world_size():
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_WORLD_SIZE
+ if _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_WORLD_SIZE is None:
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_WORLD_SIZE = torch.distributed.get_world_size(
+ group=get_tensor_model_parallel_group_for_nd1_dim1())
+
+ return _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1_WORLD_SIZE
+
+
+def get_tensor_model_parallel_group_for_nd1_dim2_world_size():
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_WORLD_SIZE
+ if _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_WORLD_SIZE is None:
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_WORLD_SIZE = torch.distributed.get_world_size(
+ group=get_tensor_model_parallel_group_for_nd1_dim2())
+
+ return _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2_WORLD_SIZE
+
+
+def get_tensor_model_parallel_group_for_nd2_dim2(check_initialized=True):
+ if check_initialized and _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM2 is None:
+ raise AssertionError('tensor model parallel group for nd2 dim2 is not initialized')
+ return _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM2
+
+
+def get_tensor_model_parallel_world_size_for_nd1_dim1():
+ global _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM1
+ if _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM1 is None:
+ _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM1 = torch.distributed.get_world_size(
+ group=get_tensor_model_parallel_group_for_nd1_dim1()
+ )
+ return _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM1
+
+
+def get_tensor_model_parallel_world_size_for_nd1_dim2():
+ global _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM2
+ if _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM2 is None:
+ _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM2 = torch.distributed.get_world_size(
+ group=get_tensor_model_parallel_group_for_nd1_dim2()
+ )
+ return _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND1_DIM2
+
+
+def get_tensor_model_parallel_world_size_for_nd2_dim1():
+ global _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM1
+ if _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM1 is None:
+ _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM1 = torch.distributed.get_world_size(
+ group=get_tensor_model_parallel_group_for_nd2_dim1()
+ )
+ return _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM1
+
+
+def get_tensor_model_parallel_world_size_for_nd2_dim2():
+ global _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM2
+ if _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM2 is None:
+ _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM2 = torch.distributed.get_world_size(
+ group=get_tensor_model_parallel_group_for_nd2_dim2()
+ )
+ return _TENSOR_MODEL_PARALLEL_WORLD_SIZE_FOR_ND2_DIM2
+
+
+def initialize_ndmm_parallel_group(
+ nccl_comm_cfgs: dict,
+ tensor_model_parallel_size: int = 1,
+ nd1_dim1_size: int = 1,
+ nd2_dim1_size: int = 1,
+):
+ import megatron.core.parallel_state as ps
+ from megatron.training import get_args
+ from megatron.training.global_vars import _ensure_var_is_not_initialized
+
+ args = get_args()
+ if not (args.use_nd_matmul or args.tp_2d):
+ return
+
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1
+ _ensure_var_is_not_initialized(
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1, 'nd1_dim1'
+ )
+
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2
+ _ensure_var_is_not_initialized(
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2, 'nd1_dim2'
+ )
+
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM1
+ _ensure_var_is_not_initialized(
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM1, 'nd2_dim1'
+ )
+
+ global _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM2
+ _ensure_var_is_not_initialized(
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM2, 'nd2_dim2'
+ )
+
+ global _TP_X_PARALLEL_RING_RANKS
+ _ensure_var_is_not_initialized(_TP_X_PARALLEL_RING_RANKS, 'tp_x_ring_ranks')
+
+ global _TP_Y_PARALLEL_RING_RANKS
+ _ensure_var_is_not_initialized(_TP_Y_PARALLEL_RING_RANKS, 'tp_y_ring_ranks')
+
+ global _TP_X_SD_RCV_OVERLAP_GROUP
+ _ensure_var_is_not_initialized(_TP_X_SD_RCV_OVERLAP_GROUP, 'tp_x_overlap_ranks')
+
+ global _TP_Y_SD_RCV_OVERLAP_GROUP
+ _ensure_var_is_not_initialized(_TP_Y_SD_RCV_OVERLAP_GROUP, 'tp_y_overlap_ranks')
+
+ if tensor_model_parallel_size % nd1_dim1_size != 0:
+ raise RuntimeError(
+ f"tensor_model_parallel_size can't divisible by nd1_dim1_size"
+ )
+
+ if tensor_model_parallel_size % nd2_dim1_size != 0:
+ raise RuntimeError(
+ f"tensor_model_parallel_size can't divisible by nd2_dim1_size"
+ )
+
+ rank = torch.distributed.get_rank()
+ world_size: int = torch.distributed.get_world_size()
+ num_tensor_model_parallel_group: int = world_size // tensor_model_parallel_size
+
+ tp_nd1_dim1_groups = [] # TPX-RANKS
+ tp_nd1_dim2_groups = []
+ tp_nd2_dim1_groups = []
+ tp_nd2_dim2_groups = []
+ for i in range(num_tensor_model_parallel_group):
+ for j in range(tensor_model_parallel_size // nd1_dim1_size):
+ ranks = range(
+ i * tensor_model_parallel_size + j * nd1_dim1_size,
+ i * tensor_model_parallel_size + (j + 1) * nd1_dim1_size
+ )
+ tp_nd1_dim1_groups.append(list(ranks))
+ group = torch.distributed.new_group(
+ ranks, pg_options=ps.get_nccl_options('nd1_dim1', nccl_comm_cfgs)
+ )
+ if args.enable_overlap_ag_with_matmul or args.enable_backward_overlap_ag_with_matmul:
+ tp_x_ag_overlap_group = torch.distributed.new_group(
+ ranks, pg_options=ps.get_nccl_options('ag_x_sd_rcv_overlap', nccl_comm_cfgs)
+ )
+ else:
+ tp_x_ag_overlap_group = None
+ if rank in ranks:
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM1 = group
+ _TP_X_SD_RCV_OVERLAP_GROUP = tp_x_ag_overlap_group
+ _TP_X_PARALLEL_RING_RANKS = ranks
+
+ nd1_dim2_size = tensor_model_parallel_size // nd1_dim1_size
+ for j in range(tensor_model_parallel_size // nd1_dim2_size):
+ ranks = range(
+ i * tensor_model_parallel_size + j,
+ (i + 1) * tensor_model_parallel_size,
+ nd1_dim1_size
+ )
+ tp_nd1_dim2_groups.append(list(ranks))
+ group = torch.distributed.new_group(
+ ranks, pg_options=ps.get_nccl_options('nd1_dim2', nccl_comm_cfgs)
+ )
+ if args.enable_overlap_ag_with_matmul or args.enable_backward_overlap_ag_with_matmul:
+ tp_y_ag_overlap_group = torch.distributed.new_group(
+ ranks, pg_options=ps.get_nccl_options('ag_y_sd_rcv_overlap', nccl_comm_cfgs)
+ )
+ else:
+ tp_y_ag_overlap_group = None
+ if rank in ranks:
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND1_DIM2 = group
+ _TP_Y_SD_RCV_OVERLAP_GROUP = tp_y_ag_overlap_group
+ _TP_Y_PARALLEL_RING_RANKS = ranks
+
+ for j in range(tensor_model_parallel_size // nd2_dim1_size):
+ ranks = range(
+ i * tensor_model_parallel_size + j * nd2_dim1_size,
+ i * tensor_model_parallel_size + (j + 1) * nd2_dim1_size
+ )
+ tp_nd2_dim1_groups.append(list(ranks))
+ group = torch.distributed.new_group(
+ ranks, pg_options=ps.get_nccl_options('nd2_dim1', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM1 = group
+
+ nd2_dim2_size = tensor_model_parallel_size // nd2_dim1_size
+ for j in range(tensor_model_parallel_size // nd2_dim2_size):
+ ranks = range(
+ i * tensor_model_parallel_size + j,
+ (i + 1) * tensor_model_parallel_size,
+ nd2_dim1_size
+ )
+ tp_nd2_dim2_groups.append(list(ranks))
+ group = torch.distributed.new_group(
+ ranks, pg_options=ps.get_nccl_options('nd2_dim2', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ _TENSOR_MODEL_PARALLEL_GROUP_FOR_ND2_DIM2 = group
+
+ print(f'tp-x groups: {tp_nd1_dim1_groups}')
+ return tp_nd1_dim1_groups
+
+
+def get_data_parallel_group_gloo_replace(with_context_parallel=False):
+ """Get the data parallel group-gloo the caller rank belongs to."""
+ import megatron.core.parallel_state as ps
+
+ if with_context_parallel:
+ assert (
+ ps._DATA_PARALLEL_GROUP_WITH_CP_GLOO is None
+ ), 'data parallel group-gloo with context parallel combined should be None when args.disable_gloo_group is True'
+ return ps._DATA_PARALLEL_GROUP_WITH_CP
+ else:
+ assert ps._DATA_PARALLEL_GROUP_GLOO is None, 'data parallel group-gloo should be None when args.disable_gloo_group is True'
+ return ps._DATA_PARALLEL_GROUP
+
+
+def get_data_modulo_expert_parallel_group_gloo_replace(with_context_parallel=False):
+ import megatron.core.parallel_state as ps
+
+ if with_context_parallel:
+ assert (
+ ps._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO is None
+ ), 'data modulo expert parallel group-gloo with context parallel is not initialized'
+ return ps._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP
+ else:
+ assert (
+ ps._DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO is None
+ ), 'data modulo expert parallel group-gloo should be None when args.disable_gloo_group is True'
+ return ps._DATA_MODULO_EXPERT_PARALLEL_GROUP
+
+
+def new_group_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ from megatron.training import get_args
+ if get_args().disable_gloo_group:
+ if "backend" in kwargs and kwargs["backend"] == "gloo":
+ return None
+ return fn(*args, **kwargs)
+ return wrapper
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/autopipeline_perf.py b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/autopipeline_perf.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1ad8a5962479bb6a06a6070b342e50e92f857e4
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/autopipeline_perf.py
@@ -0,0 +1,401 @@
+import time
+from functools import partial
+_TRAIN_START_TIME = time.time()
+import json
+import os.path
+import gc
+import copy
+import torch
+import torch.nn
+import torch_npu
+from megatron.training import print_rank_0
+from megatron.training.arguments import parse_args
+from megatron.core.parallel_state import get_embedding_group
+from megatron.training import get_args
+from megatron.training import get_timers
+from megatron.training import training
+from megatron.training.training import print_datetime
+from megatron.core.pipeline_parallel import p2p_communication
+from megatron.core import mpu, tensor_parallel
+from megatron.training.initialize import initialize_megatron
+from megatron.training.initialize import set_jit_fusion_options
+
+
+profile_context = {"fwd_time":[], "bwd_time":[]}
+
+
+class AutoPipeline_Perf:
+ autopipeline_perf = None
+
+ def __init__(self, args):
+ self.args = copy.deepcopy(args)
+ self.context = {
+ 'module': []
+ }
+ self.modules_hooks = []
+ self.profiling_step = 0
+ self.stop_profiling_step = 3
+ self.unit_mb = 1024 * 1024
+
+ @staticmethod
+ def get_memory_status():
+ used_memory = torch.npu.memory_allocated()
+ reserved_memory = torch.npu.memory_reserved()
+ return used_memory, reserved_memory
+
+ def _cal_tensor_size(self, tensor):
+ try:
+ return tensor.numel() * tensor.element_size() / self.unit_mb
+ except ZeroDivisionError:
+ return 0
+
+ def pre_hook_func(self, state, sync: bool, *args, **kargs):
+ used_memory, _ = self.get_memory_status()
+ torch.npu.reset_max_memory_allocated()
+ state['memory'] = used_memory
+ size = 0
+ for arg in args:
+ if isinstance(arg, torch.Tensor):
+ size += self._cal_tensor_size(arg)
+ elif isinstance(arg, tuple) or isinstance(arg, list):
+ for t in arg:
+ if isinstance(t, torch.Tensor):
+ size += self._cal_tensor_size(t)
+ state['input'] = size
+
+ def post_hook_func(self, state, sync: bool, *args, **kargs):
+ used_memory, _ = self.get_memory_status()
+ max_mem = torch.npu.max_memory_allocated()
+ state['peak_memory'] = max_mem - state['memory']
+ state['memory'] = (used_memory - state['memory']) // self.unit_mb
+
+ def forward_pre_hook(self, name, parent_ctx, ctx):
+ if self.profiling_step < self.stop_profiling_step:
+ ctx['name'] = name
+ if 'layers' in parent_ctx:
+ parent_ctx['layers'].append(ctx)
+
+ def hook(module, *args, **kargs):
+ if self.profiling_step < self.stop_profiling_step:
+ if 'module' in self.context:
+ self.context['module'].append(ctx)
+ self.pre_hook_func(ctx, True, *args, **kargs)
+
+ return hook
+
+ def forward_post_hook(self, ctx):
+ def hook(module, *args, **kargs):
+ if self.profiling_step < self.stop_profiling_step:
+ self.post_hook_func(ctx, True, *args)
+ if 'module' in self.context:
+ self.context['module'].pop()
+
+ return hook
+
+ def register_recursive_hook(self, prefix_name, model, ctx):
+ for name, module in model.named_children():
+ if 'layers' not in ctx:
+ ctx['layers'] = []
+ current_ctx = {}
+
+ next_name = prefix_name + "." + name if prefix_name != "" else name
+ if next_name == "module.module":
+ pre_hook = module.register_forward_pre_hook(self.forward_pre_hook(name, ctx, current_ctx))
+ post_hook = module.register_forward_hook(self.forward_post_hook(current_ctx))
+ self.modules_hooks.append(pre_hook)
+ self.modules_hooks.append(post_hook)
+ self.register_recursive_hook(next_name, module, current_ctx)
+
+ def step_hook(self, model):
+ self.profiling_step += 1
+
+ def hook_step_func(self, step_func, models):
+ def custom_step_func(*args, **kargs):
+ result = step_func(*args, **kargs)
+ if self.profiling_step < self.stop_profiling_step:
+ used_memory, reserved_memory = self.get_memory_status()
+ self.context['used_mem'] = used_memory // self.unit_mb
+ if isinstance(models, list):
+ for model in models:
+ self.step_hook(model)
+ else:
+ self.step_hook(models)
+ return result
+
+ return custom_step_func
+
+ def remove_outliers(self, data, m=2):
+ data = sorted(data)
+ median = data[len(data) // 2]
+ deviation = [x for x in data if median - m * median < x < median + m * median]
+ return deviation
+
+ def get_forward_context(self):
+ global profile_context
+ if "fwd_time" in profile_context:
+ fwd_time_list = self.remove_outliers(profile_context["fwd_time"])
+ try:
+ self.context["fwd_time"] = sum(fwd_time_list) / len(fwd_time_list)
+ except ZeroDivisionError:
+ print("[Error] Divided by zero.")
+ else:
+ self.context["fwd_time"] = 0
+
+ def get_backward_context(self):
+ global profile_context
+ if "bwd_time" in profile_context:
+ bwd_time_list = self.remove_outliers(profile_context["bwd_time"])
+ try:
+ self.context["bwd_time"] = sum(bwd_time_list) / len(bwd_time_list)
+ except ZeroDivisionError:
+ print("[Error] Divided by zero.")
+ else:
+ self.context["bwd_time"] = 0
+
+ def clear_global_context(self):
+ global profile_context
+ profile_context["fwd_time"] = []
+ profile_context["bwd_time"] = []
+
+ def get_comm_time(self, config, sync: bool):
+ if torch.distributed.get_rank() == 0:
+ if sync:
+ torch.cuda.synchronize()
+ input_tensor = torch.ones(self.args.seq_length, self.args.micro_batch_size, self.args.hidden_size)
+ start_time = time.time()
+ p2p_communication.send_backward(input_tensor, config)
+ comm_time = (time.time() - start_time) * 1000
+ self.context['comm_time'] = comm_time
+ else:
+ self.context['comm_time'] = 0.028
+
+ def get_peak_memory(self, sync: bool):
+ if sync:
+ torch.cuda.synchronize()
+ max_mem = torch.npu.max_memory_allocated() / (1 << 20)
+ self.context['peak_memory'] = max_mem
+
+ def get_smi_peak_memory(self, sync: bool):
+ if sync:
+ torch.cuda.synchronize()
+ mem_infos = torch.npu.mem_get_info()
+ smi_peak_memory = (mem_infos[1] - mem_infos[0]) / (1 << 20)
+ self.context['smi_peak_memory'] = smi_peak_memory
+
+ def get_smi_left_memory(self, sync: bool):
+ if sync:
+ torch.cuda.synchronize()
+ mem_infos = torch.npu.mem_get_info()
+ smi_left_memory = mem_infos[0] / (1 << 20)
+ self.context['smi_left_memory'] = smi_left_memory
+
+ def get_data_parallel_size(self, data_parallel_size):
+ if data_parallel_size:
+ self.context['data_parallel_size'] = data_parallel_size
+ else:
+ self.context['data_parallel_size'] = 1
+
+ def broadcast_param_in_ranks(self, src_rank, param, init_memory):
+ if torch.distributed.get_rank() == src_rank:
+ try:
+ param = torch.npu.max_memory_allocated() / self.unit_mb - init_memory
+ except ZeroDivisionError:
+ print("[Error] Divided by zero.")
+ tmp_param = torch.cuda.IntTensor([param])
+ torch.distributed.broadcast(tmp_param, src=src_rank)
+ param = tmp_param.item()
+ return param
+
+ def update_args_for_profiling(self, micro_batch_size=None):
+ args = get_args()
+ args.train_iters = self.stop_profiling_step
+ if micro_batch_size:
+ args.micro_batch_size = micro_batch_size
+ args.global_batch_size = args.micro_batch_size * 16
+ args.save = False
+ args.log_interval = 10
+
+ def restore_args_for_training(self):
+ args = get_args()
+ if args.num_layers_per_virtual_pipeline_stage is None:
+ args.num_layers = self.args.num_layers
+ args.encoder_num_layers = self.args.num_layers
+ args.train_iters = self.args.train_iters
+ args.micro_batch_size = self.args.micro_batch_size
+ args.global_batch_size = self.args.global_batch_size
+ args.save = self.args.save
+ args.log_interval = self.args.log_interval
+
+
+def check_equal_model_configs(args, parsed_contents):
+ model_index = 0
+ for model_instance in parsed_contents:
+ if args.hidden_size == model_instance.get("model_configs", {}).get("hidden_size") \
+ and args.ffn_hidden_size == model_instance.get("model_configs", {}).get("ffn_hidden_size") \
+ and args.seq_length == model_instance.get("model_configs", {}).get("seq_length") \
+ and args.num_attention_heads == model_instance.get("model_configs", {}).get("num_attention_heads"):
+ return model_index
+ else:
+ model_index += 1
+ return -1
+
+
+def check_equal_parallel_configs(args, parsed_content):
+ for parallel_instance in parsed_content.get("optimpipeline_policy"):
+ if args.num_layers == parallel_instance.get("num_layers") \
+ and args.pipeline_model_parallel_size == parallel_instance.get("pipeline_model_parallel_size") \
+ and args.tensor_model_parallel_size == parallel_instance.get("tensor_model_parallel_size") \
+ and args.micro_batch_size == parallel_instance.get("micro_batch_size") \
+ and args.global_batch_size == parallel_instance.get("global_batch_size"):
+ return parallel_instance.get("enable_scheduler"), parallel_instance.get("optimized_mbs_list"), parallel_instance.get(
+ "pp_schedule_list"), parallel_instance.get("optimal_layers")
+ return None, None, None, None
+
+
+def check_skip_profiling(args, config_file):
+ if os.path.exists(config_file):
+ with open(config_file) as config_json:
+ config_contents = config_json.read()
+ parsed_contents = json.loads(config_contents)
+ index = check_equal_model_configs(args, parsed_contents)
+ if index != -1:
+ optimized_type, optimized_mbs_list, pp_schedule_list, optimal_layers = check_equal_parallel_configs(args, parsed_contents[index])
+ if optimized_mbs_list or pp_schedule_list:
+ return True, (optimized_type, optimized_mbs_list, pp_schedule_list, optimal_layers)
+ return False, (None, None, None, None)
+
+
+def check_out_of_memory(args, context, mbs_tries):
+ total_memory = torch_npu.npu.get_device_properties(0).total_memory / (1 << 20)
+ per_activation_memory_allocated = context["layers"][0]["memory"] // mbs_tries
+ predict_next_max_memory_allocated = context["smi_peak_memory"] + per_activation_memory_allocated * args.pipeline_model_parallel_size + 1000
+ if predict_next_max_memory_allocated > total_memory:
+ return True
+ else:
+ return False
+
+
+def broadcast_skip_in_ranks(src_rank, policy):
+ is_skip = [False]
+ if torch.distributed.get_rank() == src_rank:
+ is_skip = [policy]
+ tmp_is_skip = torch.cuda.BoolTensor(is_skip)
+ torch.distributed.broadcast(tmp_is_skip, src=src_rank)
+ return tmp_is_skip.item()
+
+
+def calculate_num_of_activations(context):
+ total_memory = torch_npu.npu.get_device_properties(0).total_memory / (1 << 20)
+ activation_memory_allocated = context["layers"][0]["memory"]
+ num_of_activations_left = (total_memory - context["smi_peak_memory"]) // activation_memory_allocated
+ return int(num_of_activations_left)
+
+
+def get_autopipeline_perf(args):
+ AutoPipeline_Perf.autopipeline_perf = AutoPipeline_Perf(args)
+ return AutoPipeline_Perf.autopipeline_perf
+
+
+def autopipelineperf_profiling(mbs_tries, model_provider, model_type, forward_step_func, train_valid_test_dataset_provider,
+ process_non_loss_data_func):
+ initialize_megatron(extra_args_provider=None,
+ args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
+ set_jit_fusion_options()
+ global _TRAIN_START_TIME
+ start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
+ torch.distributed.all_reduce(start_time_tensor,
+ op=torch.distributed.ReduceOp.MIN)
+ _TRAIN_START_TIME = start_time_tensor.item()
+ print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
+ time.time() - _TRAIN_START_TIME))
+ print_datetime('after megatron is initialized')
+ args = get_args()
+ pipelining = get_autopipeline_perf(args)
+ pipelining.update_args_for_profiling(mbs_tries)
+ models, optimizer, lr_scheduler = training.setup_model_and_optimizer(model_provider, model_type)
+ optimizer.step = pipelining.hook_step_func(optimizer.step, models)
+ config = training.get_model_config(models[0])
+
+ if args.virtual_pipeline_model_parallel_size is not None:
+ train_data_iterator = []
+ valid_data_iterator = []
+ for i in range(len(models)):
+ mpu.set_virtual_pipeline_model_parallel_rank(i)
+ iterators = training.build_train_valid_test_data_iterators(
+ train_valid_test_dataset_provider)
+ train_data_iterator.append(iterators[0])
+ valid_data_iterator.append(iterators[1])
+ else:
+ train_data_iterator, valid_data_iterator, _ = training.build_train_valid_test_data_iterators(
+ train_valid_test_dataset_provider)
+ if isinstance(models, list):
+ for model in models:
+ pipelining.register_recursive_hook("module", model, pipelining.context)
+ else:
+ pipelining.register_recursive_hook("module", models, pipelining.context)
+ checkpointing_context = {}
+ training.train(forward_step_func, models, optimizer, lr_scheduler, train_data_iterator, valid_data_iterator,
+ process_non_loss_data_func, config, checkpointing_context)
+ pipelining.get_smi_peak_memory(sync=True)
+ pipelining.get_smi_left_memory(sync=True)
+ pipelining.get_comm_time(config, sync=True)
+ pipelining.get_peak_memory(sync=True)
+ pipelining.get_data_parallel_size(args.data_parallel_size)
+ pipelining.get_forward_context()
+ pipelining.get_backward_context()
+ pipelining.clear_global_context()
+
+ timers = get_timers()
+ if timers('interval-time'):
+ timers('interval-time').stop(barrier=True)
+
+ for hook_handle in pipelining.modules_hooks:
+ hook_handle.remove()
+ pipelining.modules_hooks.clear()
+ pipelining.restore_args_for_training()
+
+ if hasattr(optimizer, 'chained_optimizers'):
+ for op in optimizer.chained_optimizers:
+ for key, value in op.optimizer.state.items():
+ key.detach()
+ key.grad = None
+ key.storage().resize_(0)
+ if "momentum_buffer" in value:
+ value["momentum_buffer"].detach()
+ value["momentum_buffer"].grad = None
+ value["momentum_buffer"].storage().resize_(0)
+ for ofg in op.param_groups:
+ if "params" in ofg:
+ for og in ofg["params"]:
+ og.detach()
+ og.grad = None
+ og.storage().resize_(0)
+ else:
+ for key, value in optimizer.optimizer.state.items():
+ key.detach()
+ key.grad = None
+ key.storage().resize_(0)
+ if "momentum_buffer" in value:
+ value["momentum_buffer"].detach()
+ value["momentum_buffer"].grad = None
+ value["momentum_buffer"].storage().resize_(0)
+ for ofg in optimizer.param_groups:
+ if "params" in ofg:
+ for og in ofg["params"]:
+ og.detach()
+ og.grad = None
+ og.storage().resize_(0)
+ for md in models:
+ for param in md.parameters():
+ param.detach()
+ param.grad = None
+ param.storage().resize_(0)
+ for param_tensor in md.state_dict():
+ if md.state_dict()[param_tensor] is not None:
+ md.state_dict()[param_tensor].detach()
+ md.state_dict()[param_tensor].grad = None
+ md.state_dict()[param_tensor].storage().resize_(0)
+
+ gc.collect()
+ torch_npu.npu.empty_cache()
+ return pipelining.context
diff --git a/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/data_samplers.py b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/data_samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..bebdbb8717350d12a9e47c1e5d65f9b91d793465
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/data_samplers.py
@@ -0,0 +1,71 @@
+import random
+from functools import wraps
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from megatron.training import get_args
+from megatron.core import mpu
+
+
+def build_pretraining_data_loader_decorator(build_pretraining_data_loader):
+ @wraps(build_pretraining_data_loader)
+ def wrapper(*args, **kwargs):
+ if args[0] is None:
+ return None
+ argument = get_args()
+ if argument.dataloader_type == 'single' and argument.automated_pipeline_perf and argument.optimized_mbs_list:
+ batch_sampler = DynamicMicroBatchPretrainingSampler(
+ total_samples=len(args[0]),
+ consumed_samples=args[1],
+ micro_batch_size=argument.micro_batch_size,
+ data_parallel_rank=mpu.get_data_parallel_rank(),
+ data_parallel_size=mpu.get_data_parallel_world_size())
+ return torch.utils.data.DataLoader(args[0],
+ batch_sampler=batch_sampler,
+ num_workers=argument.num_workers,
+ pin_memory=True)
+ else:
+ dataloader = build_pretraining_data_loader(*args, **kwargs)
+ return dataloader
+ return wrapper
+
+
+class DynamicMicroBatchPretrainingSampler:
+
+ def __init__(self, total_samples, consumed_samples, micro_batch_size,
+ data_parallel_rank, data_parallel_size, drop_last=True):
+
+ args = get_args()
+ self.total_samples = total_samples
+ self.consumed_samples = consumed_samples
+ self.micro_batch_size = micro_batch_size
+ self.data_parallel_rank = data_parallel_rank
+ self.drop_last = drop_last
+ self.dynamic_micro_batch_size = args.optimized_mbs_list
+ self.micro_batch_times_data_parallel_size = [
+ self.dynamic_micro_batch_size[i] * data_parallel_size \
+ for i in range(len(self.dynamic_micro_batch_size))
+ ]
+
+ def __len__(self):
+ return self.total_samples
+
+ def get_start_end_idx(self, n_mbs):
+ start_idx = self.data_parallel_rank * self.dynamic_micro_batch_size[n_mbs]
+ end_idx = start_idx + self.dynamic_micro_batch_size[n_mbs]
+ return start_idx, end_idx
+
+ def __iter__(self):
+ batch = []
+ n_mbs = 0
+ for idx in range(self.consumed_samples, self.total_samples):
+ batch.append(idx)
+ if len(batch) == self.micro_batch_times_data_parallel_size[n_mbs]:
+ start_idx, end_idx = self.get_start_end_idx(n_mbs)
+ yield batch[start_idx:end_idx]
+ batch = []
+ n_mbs = (n_mbs + 1) % len(self.micro_batch_times_data_parallel_size)
+
+ if len(batch) > 0 and not self.drop_last:
+ start_idx, end_idx = self.get_start_end_idx()
+ yield batch[start_idx:end_idx]
diff --git a/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/global_vars.py b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/global_vars.py
new file mode 100644
index 0000000000000000000000000000000000000000..f91a5703c4c060f861bc6b4544084b92600a081c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/global_vars.py
@@ -0,0 +1,16 @@
+from functools import wraps
+from megatron.training import get_args
+
+
+def get_num_microbatches_wrapper(get_num_microbatches):
+ @wraps(get_num_microbatches)
+ def wrapper(*args, **kwargs):
+ argument = get_args()
+ automated_pipeline_profile = argument.automated_pipeline_perf and not argument.optimized_mbs_list
+ if argument.automated_pipeline_perf and argument.optimized_mbs_list and argument.optimized_mbs_mode:
+ return len(argument.optimized_mbs_list)
+ elif automated_pipeline_profile:
+ return argument.global_batch_size // argument.data_parallel_size // argument.micro_batch_size
+ else:
+ return get_num_microbatches(*args, **kwargs)
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/optimpipeline_solver.py b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/optimpipeline_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..b643d85ea61cbeb0201ba466d48da634819e1a90
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/optimpipeline_solver.py
@@ -0,0 +1,304 @@
+import os
+import json
+import math
+import time
+from datetime import datetime
+from itertools import product
+import numpy as np
+import torch
+from megatron.training import get_args
+from megatron.training.arguments import parse_args
+from mindspeed.arguments import parse_args_wrapper
+from .autopipeline_perf import check_equal_model_configs
+
+
+class Parallel_Paras:
+ def __init__(self,
+ num_stages,
+ fwd_durations,
+ bwd_durations,
+ num_microbatch,
+ comm_matrix):
+ self.num_stages = num_stages
+ self.num_microbatch = num_microbatch
+ self.fwd_durations = fwd_durations
+ self.bwd_durations = bwd_durations
+ self.comm_matrix = comm_matrix
+
+
+def dynamic_mbs_1f1b(paras):
+ num_stages = paras.num_stages
+ num_microbatch = paras.num_microbatch
+ computation_placement = list(range(num_stages)) + list(range(num_stages - 1, -1, -1))
+ fwd_durations = paras.fwd_durations
+ bwd_durations = paras.bwd_durations
+ comm_matrix = paras.comm_matrix
+
+ fwd_bwd_order = ([f'F_{i}' for i in range(num_stages)] +
+ [f'B_{i}' for i in range(num_stages - 1, -1, -1)])
+ fwd_bwd_chunk_stage = dict(zip(fwd_bwd_order, computation_placement))
+
+ def get_stage_list(fwd_seq, bwd_seq, num_advanced):
+ stage_order = []
+ n = len(fwd_seq)
+ for idx in range(n):
+ if idx < num_advanced:
+ stage_order.append(fwd_seq[idx])
+ else:
+ stage_order.append(fwd_seq[idx])
+ stage_order.append(bwd_seq[idx - num_advanced])
+ if idx == n - 1:
+ for i in range(num_advanced):
+ stage_order.append(bwd_seq[i - num_advanced])
+
+ return stage_order
+
+ def get_stage_schedule(all_jobs_array, comp_placement):
+ stage_list = []
+ for s in range(num_stages):
+ stage_chunk_id = [index for index, element in enumerate(comp_placement) if element == s]
+ warmup = num_stages - s
+ stage_s_list = get_stage_list(all_jobs_array[stage_chunk_id[0]],
+ all_jobs_array[stage_chunk_id[1]],
+ warmup - 1)
+ stage_list.append(stage_s_list)
+
+ return stage_list
+
+ all_jobs = np.array([[s + f'-{i}' for i in range(num_microbatch)] for s in fwd_bwd_order])
+ stage_list = get_stage_schedule(all_jobs, computation_placement)
+
+ fwd_bwd_list = ([f"F_{j}-{i}" for i in range(num_microbatch) for j in range(num_stages)]
+ + [f"B_{j}-{i}" for i in range(num_microbatch) for j in range(num_stages)])
+ values = [0 for _ in range(num_stages * num_microbatch * 2)]
+ start_time = dict(zip(fwd_bwd_list, values))
+ fwd_bwd_durations = dict()
+ for j in range(num_stages):
+ for i in range(num_microbatch):
+ fwd_bwd_durations[f"F_{j}-{i}"] = fwd_durations[j, i]
+ fwd_bwd_durations[f"B_{j}-{i}"] = bwd_durations[j, i]
+
+ for n in range(num_stages - 1):
+ for s in range(n + 1):
+ start_time[f"F_{s}-{n - s + 1}"] = max(start_time[f"F_{s}-{n - s + 1}"],
+ start_time[f"F_{s}-{n - s}"] + fwd_durations[s, n - s] + comm_matrix[s][s + 1])
+ start_time[f"F_{s + 1}-{n - s}"] = max(start_time[f"F_{s + 1}-{n - s}"],
+ start_time[f"F_{s}-{n - s}"] + fwd_durations[s, n - s] + comm_matrix[s][s + 1])
+
+ def get_prev_job_time(comp_start_time, pp_list, pp_id, mb_idx,
+ comp_chunk_stage, comp_order, model_chunk_times,
+ comm_time_matrix):
+ current_job = pp_list[pp_id][mb_idx]
+ prev_job_stage = pp_list[pp_id][mb_idx - 1]
+ chunk_prev_job_stage, _ = prev_job_stage.split('-')
+ stage_id_prev_job = comp_chunk_stage[chunk_prev_job_stage]
+ chunk_position = comp_order.index(chunk_prev_job_stage)
+ if chunk_position < len(comp_order) - 1:
+ stage_id_next = comp_chunk_stage[comp_order[chunk_position + 1]]
+ comm_time = comm_time_matrix[stage_id_prev_job][stage_id_next]
+ else:
+ comm_time = 0
+ end_time_prev_job_stage = (comp_start_time[prev_job_stage] + model_chunk_times[prev_job_stage]
+ + comm_time)
+
+ cur_model_chunk, cur_mb = current_job.split('-')
+ chunk_position = comp_order.index(cur_model_chunk)
+ if chunk_position > 0:
+ prev_model_chunk = comp_order[chunk_position - 1]
+ prev_job_batch = prev_model_chunk + '-' + cur_mb
+ comm_time = comm_time_matrix[comp_chunk_stage[prev_model_chunk]][comp_chunk_stage[cur_model_chunk]]
+ end_time_prev_job_batch = comp_start_time[prev_job_batch] + model_chunk_times[prev_job_batch] + comm_time
+ completed_flag = comp_start_time[prev_job_stage] > 0 and comp_start_time[prev_job_batch] > 0
+ else:
+ end_time_prev_job_batch = 0
+ completed_flag = comp_start_time[prev_job_stage] > 0
+
+ return end_time_prev_job_stage, end_time_prev_job_batch, completed_flag
+
+ begin_up = [num_stages - s for s in range(num_stages)]
+ remaining = [num_microbatch * 2 - begin_up[p] for p in range(num_stages)]
+ remaining_flag = True
+ while remaining_flag:
+ ids_old = []
+ ids_new = []
+ for s in range(num_stages):
+ ids_old.append(remaining[s])
+ if remaining[s]:
+ idx = len(stage_list[0]) - remaining[s]
+ end_time_prev_stage, end_time_prev_batch, job_flag = get_prev_job_time(start_time, stage_list, s, idx,
+ fwd_bwd_chunk_stage,
+ fwd_bwd_order,
+ fwd_bwd_durations,
+ comm_matrix)
+
+ if job_flag:
+ start_time[stage_list[s][idx]] = max(end_time_prev_stage, end_time_prev_batch)
+ remaining[s] = remaining[s] - 1
+
+ ids_new.append(remaining[s])
+ if all(item == 0 for item in remaining):
+ remaining_flag = False
+ if ids_old == ids_new:
+ break
+
+ e2e_time = start_time[f'B_0-{num_microbatch-1}'] + bwd_durations[0, -1]
+ stage_start_time = [[start_time[job_name] for job_name in stage_list[s]] for s in range(num_stages)]
+ return e2e_time, stage_start_time, stage_list, start_time
+
+
+def find_integer_solutions(coefficients, global_batch_size):
+ n = len(coefficients)
+ mbs_max_value = (n + 1) // 2
+ solutions = []
+ all_comb = []
+ for i in range(n):
+ if i == mbs_max_value - 1:
+ batch_using = sum(coefficients[0:mbs_max_value - 1] * 4)
+ all_comb.append(list(range((global_batch_size - batch_using) // mbs_max_value,
+ global_batch_size // mbs_max_value + 1)))
+ else:
+ all_comb.append(list(range(4)))
+
+ for x in product(*all_comb):
+ if sum(coefficients[i] * x[i] for i in range(n)) == global_batch_size:
+ solutions.append(x)
+
+ return solutions
+
+
+def dynamic_mbs_search(num_stages, global_batch_size, fwd_mbs, bwd_mbs, comm_matrix):
+ comp_mbs_ratio = [value / (index + 1) for index, value in enumerate(fwd_mbs)]
+ fwd_mbs_selected = fwd_mbs[0:comp_mbs_ratio.index(min(comp_mbs_ratio)) + 1]
+ bwd_mbs_selected = bwd_mbs[0:comp_mbs_ratio.index(min(comp_mbs_ratio)) + 1]
+ mbs_max_value = len(fwd_mbs_selected)
+ bwd_mbs_stages = [fwd_mbs_selected] * num_stages
+ fwd_mbs_stages = [bwd_mbs_selected] * num_stages
+
+ coefficients = list(range(1, mbs_max_value + 1)) + list(range(mbs_max_value - 1, 0, -1))
+ solutions = find_integer_solutions(coefficients, global_batch_size)
+
+ mbs_list = sum([solutions[0][i] * [coefficients[i]] for i in range(len(solutions[0]))], [])
+ num_microbatch = len(mbs_list)
+ fwd_durations = np.zeros([num_stages, num_microbatch])
+ bwd_durations = np.zeros([num_stages, num_microbatch])
+ for j in range(num_microbatch):
+ for i in range(num_stages):
+ fwd_durations[i, j] = fwd_mbs_stages[i][mbs_list[j] - 1]
+ bwd_durations[i, j] = bwd_mbs_stages[i][mbs_list[j] - 1]
+
+ paras = Parallel_Paras(num_stages, fwd_durations, bwd_durations, num_microbatch, comm_matrix)
+ e2e_time = []
+ for sol in solutions:
+ mbs_list = sum([sol[i] * [coefficients[i]] for i in range(len(sol))], [])
+ num_microbatch = len(mbs_list)
+ fwd_durations = np.zeros([num_stages, num_microbatch])
+ bwd_durations = np.zeros([num_stages, num_microbatch])
+ for j in range(num_microbatch):
+ for i in range(num_stages):
+ fwd_durations[i, j] = fwd_mbs_stages[i][mbs_list[j] - 1]
+ bwd_durations[i, j] = bwd_mbs_stages[i][mbs_list[j] - 1]
+
+ paras.fwd_durations = fwd_durations
+ paras.bwd_durations = bwd_durations
+ paras.num_microbatch = num_microbatch
+
+ e2e_time0, stage_start_time0, stage_list0, start_time0 = dynamic_mbs_1f1b(paras)
+ e2e_time.append(e2e_time0)
+
+ e2e_time_array = np.array(e2e_time)
+ optimal_solution = solutions[e2e_time_array.argmin()]
+ return optimal_solution, e2e_time_array.min()
+
+
+def broadcast_oom_in_ranks(src_rank, policy):
+ is_oom = [True]
+ if torch.distributed.get_rank() == src_rank:
+ is_oom = [policy]
+ tmp_is_oom = torch.cuda.BoolTensor(is_oom)
+ torch.distributed.broadcast(tmp_is_oom, src=src_rank)
+ return tmp_is_oom.item()
+
+
+def broadcast_mbs_in_ranks(src_rank, optimal_solution):
+ args = get_args()
+ solution_length = [0]
+ if torch.distributed.get_rank() == src_rank:
+ solution_length = [len(optimal_solution)]
+ tmp_solution_length = torch.cuda.IntTensor(solution_length)
+ torch.distributed.broadcast(tmp_solution_length, src=src_rank)
+ solution_length = tmp_solution_length.item()
+
+ tmp_optimal_solution = [0] * solution_length
+ if torch.distributed.get_rank() == src_rank:
+ tmp_optimal_solution = optimal_solution
+ tmp_optimal_solution = torch.cuda.IntTensor(tmp_optimal_solution)
+ torch.distributed.broadcast(tmp_optimal_solution, src=src_rank)
+ tmp_optimal_solution = tmp_optimal_solution.tolist()
+ mbs_max_value = math.ceil(len(tmp_optimal_solution) / 2)
+ coefficients = list(range(1, mbs_max_value + 1)) + list(range(mbs_max_value - 1, 0, -1))
+ optimal_mbs_list = sum([tmp_optimal_solution[i] * [coefficients[i]] for i in range(len(tmp_optimal_solution))], [])
+ args.optimized_mbs_list = optimal_mbs_list
+ return optimal_mbs_list
+
+
+def get_profiling_data(policy, args):
+ instance = {"model_configs": {
+ "hidden_size": args.hidden_size,
+ "ffn_hidden_size": args.ffn_hidden_size,
+ "seq_length": args.seq_length,
+ "num_attention_heads": args.num_attention_heads
+ }, "optimpipeline_policy": [{
+ "num_layers": args.num_layers,
+ "pipeline_model_parallel_size": args.pipeline_model_parallel_size,
+ "tensor_model_parallel_size": args.tensor_model_parallel_size,
+ "micro_batch_size": args.micro_batch_size,
+ "global_batch_size": args.global_batch_size,
+ "enable_scheduler": policy[0],
+ "optimized_mbs_list": policy[1],
+ "pp_schedule_list": policy[2],
+ "optimal_layers": policy[3]
+ }]}
+ return instance
+
+
+def save_profiling_data(policy, config_file):
+ if torch.distributed.get_rank() % int(os.getenv('GPUS_PER_NODE', '8')) == 0:
+ new_parse_args = parse_args_wrapper(parse_args)
+ args = new_parse_args(None, False)
+ instance = get_profiling_data(policy, args)
+ if os.path.exists(config_file):
+ with open(config_file, "r") as config_json:
+ config_contents = config_json.read()
+ parsed_contents = json.loads(config_contents)
+ index = check_equal_model_configs(args, parsed_contents)
+ if index != -1:
+ if "optimpipeline_policy" in parsed_contents[index]:
+ parsed_contents[index]["optimpipeline_policy"].append(instance["optimpipeline_policy"][0])
+ else:
+ parsed_contents.append(instance)
+ with open(config_file, "w") as f:
+ json.dump(parsed_contents, f, ensure_ascii=False)
+ os.chmod(config_file, 0o644)
+ else:
+ with open(config_file, "w") as f:
+ json.dump([instance], f, ensure_ascii=False)
+ os.chmod(config_file, 0o644)
+
+
+def solve_optimpipeline(args, data_parallel_size, global_context):
+ mbs_max_value = len(global_context)
+ coefficients = list(range(1, mbs_max_value + 1)) + list(range(mbs_max_value - 1, 0, -1))
+ optimal_solution = [0] * len(coefficients)
+ optimal_time = 0
+ if torch.distributed.get_rank() == 0:
+ num_stages = args.pipeline_model_parallel_size
+ global_batch_size = args.global_batch_size // data_parallel_size
+ fwd_mbs = [item[0] for item in global_context]
+ bwd_mbs = [item[1] for item in global_context]
+ comm_matrix = [[0.05] * num_stages for _ in range(num_stages)]
+ for i in range(num_stages):
+ comm_matrix[i][i] = 0
+
+ optimal_solution, optimal_time = dynamic_mbs_search(num_stages, global_batch_size, fwd_mbs, bwd_mbs, comm_matrix)
+ torch.distributed.barrier()
+ return optimal_solution, optimal_time
diff --git a/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/schedulepipeline_solver.py b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/schedulepipeline_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..13bacd71df827f97eacc198de85485d6716f9178
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/schedulepipeline_solver.py
@@ -0,0 +1,445 @@
+import time
+import json
+import numpy as np
+import torch
+import torch_npu
+from megatron.training import get_args
+from megatron.training import print_rank_0
+
+
+class PipelineParallelParas:
+ def __init__(self,
+ num_stages,
+ fwd_durations,
+ bwd_durations,
+ num_microbatches,
+ comm_matrix,
+ num_layers):
+ self.num_stages = num_stages
+ self.num_microbatches = num_microbatches
+ self.fwd_durations = fwd_durations
+ self.bwd_durations = bwd_durations
+ self.comm_matrix = comm_matrix
+ self.num_layers = num_layers
+
+
+
+def time_model_1f1b(paras):
+ # obtain the E2E time for 1F1B scheme
+ num_stages = paras.num_stages
+ num_micro_batches = paras.num_microbatches
+ fwd_durations = paras.fwd_durations
+ bwd_durations = paras.bwd_durations
+ p2p_matrix = paras.comm_matrix
+ fwd_start = np.zeros([num_stages, num_micro_batches])
+ bwd_start = np.zeros([num_stages, num_micro_batches])
+
+ warmup = [num_stages - s for s in range(num_stages)]
+ remaining = [num_micro_batches - warmup[p] for p in range(num_stages)]
+ # warm_up stage-0
+ for m in range(num_stages):
+ fwd_start[0, m] = m * fwd_durations[0]
+ # warm_up stage
+ for s in range(1, num_stages, 1):
+ fwd_start[s, 0] = fwd_start[s - 1, 0] + fwd_durations[s - 1] + p2p_matrix[s - 1][s]
+ for m in range(1, num_stages - s, 1):
+ fwd_start[s, m] = max(fwd_start[s - 1, m] + fwd_durations[s - 1] + p2p_matrix[s - 1][s],
+ fwd_start[s, m - 1] + fwd_durations[s])
+
+ # 0 micro batch at last stage bwd start
+ bwd_start[num_stages - 1, 0] = fwd_start[num_stages - 1, 0] + fwd_durations[num_stages - 1]
+ for s in range(num_stages - 2, -1, -1):
+ bwd_start[s, 0] = bwd_start[s + 1, 0] + bwd_durations[s + 1] + p2p_matrix[s + 1][s]
+
+ # steady state
+ for m in range(1, num_micro_batches, 1):
+ # forward time
+ for s in range(num_stages):
+ if m > remaining[s]:
+ continue
+ if s == 0:
+ fwd_start[s, m + num_stages - 1] = bwd_start[s, m - 1] + bwd_durations[s]
+ else:
+ fwd_start[s, m + num_stages - s - 1] = max(
+ fwd_start[s - 1, m + num_stages - s - 1] + fwd_durations[s - 1] + p2p_matrix[s - 1][s],
+ bwd_start[s, m - 1] + bwd_durations[s])
+
+ # backward time
+ for s in range(num_stages - 1, -1, -1):
+ # cool down stage
+ if m + num_stages - s > num_micro_batches:
+ bwd_start[s, m] = bwd_start[s + 1, m] + bwd_durations[s + 1] + p2p_matrix[s + 1][s]
+ continue
+
+ if s == num_stages - 1:
+ bwd_start[s, m] = fwd_start[s, m] + fwd_durations[s]
+ else:
+ bwd_start[s, m] = max(bwd_start[s + 1, m] + bwd_durations[s + 1] + p2p_matrix[s + 1][s],
+ fwd_start[s, m + num_stages - s - 1] + fwd_durations[s])
+
+ e2e_time = bwd_start[0, -1] + bwd_durations[0]
+ return e2e_time, fwd_start, bwd_start
+
+
+def time_model_nfmb(paras, stage_schedule):
+ # 给定一个调度序列,计算端到端时间
+ num_stages = paras.num_stages
+ num_mb = paras.num_microbatches
+ comm_matrix = paras.comm_matrix
+ chunk_placement = list(range(num_stages)) + list(range(num_stages - 1, -1, -1))
+ # Fwd Bwd执行顺序
+ fwd_bwd_comp_order = ([f'F_{i}' for i in range(num_stages)] +
+ [f'B_{i}' for i in range(num_stages - 1, -1, -1)])
+ chunk_stage_map = dict(zip(fwd_bwd_comp_order, chunk_placement))
+
+ if isinstance(stage_schedule, dict):
+ stage_list = []
+ for s in range(num_stages):
+ fb_list = stage_schedule[f"stage{s}"]
+ stage_list.append([element[0]+f"_{s}-"+element[1:] for element in fb_list])
+ else:
+ stage_list = stage_schedule
+
+ # 初始化
+ fwd_bwd_list = ([f"F_{j}-{i}" for i in range(num_mb) for j in range(num_stages)]
+ + [f"B_{j}-{i}" for i in range(num_mb) for j in range(num_stages)])
+ values = [0 for _ in range(num_stages * num_mb * 2)]
+ start_time = dict(zip(fwd_bwd_list, values))
+ fwd_bwd_durations = dict()
+ fwd_durations = np.array(paras.fwd_durations * num_mb).reshape(num_mb, num_stages).transpose()
+ bwd_durations = np.array(paras.bwd_durations * num_mb).reshape(num_mb, num_stages).transpose()
+ for j in range(num_stages):
+ for i in range(num_mb):
+ fwd_bwd_durations[f"F_{j}-{i}"] = fwd_durations[j, i]
+ fwd_bwd_durations[f"B_{j}-{i}"] = bwd_durations[j, i]
+
+ start_time[f"F_{0}-{0}"] = 0.1
+ for s in range(num_stages - 1):
+ start_time[f"F_{s + 1}-{0}"] = start_time[f"F_{s}-{0}"] + fwd_durations[s, 0] + comm_matrix[s][s + 1]
+
+ # 获取当前任务的上一个任务以及依赖任务的结束时间
+ def get_prev_task_time(task_start_time, task_list, pp_stage_id, mb_idx,
+ chunk_stage_map, comp_order, model_chunk_times,
+ comm_time_matrix):
+ current_task = task_list[pp_stage_id][mb_idx]
+ prev_task_same_stage = task_list[pp_stage_id][mb_idx - 1]
+ chunk_id_prev_task_same_stage, _ = prev_task_same_stage.split('-')
+ stage_id_prev_task = chunk_stage_map[chunk_id_prev_task_same_stage]
+ chunk_position = comp_order.index(chunk_id_prev_task_same_stage)
+ # 前一个任务计算完成后的通信时间
+ if chunk_position < len(comp_order) - 1:
+ stage_id_next = chunk_stage_map[comp_order[chunk_position + 1]]
+ comm_time = comm_time_matrix[stage_id_prev_task][stage_id_next]
+ else:
+ comm_time = 0.01
+ # 同一个stage上,前一个任务完成时间
+ end_time_prev_task_stage = (task_start_time[prev_task_same_stage]
+ + model_chunk_times[prev_task_same_stage]
+ + comm_time)
+
+ # 相同micro batch id,上一个model chunk上的计算时间
+ cur_model_chunk, cur_mb = current_task.split('-')
+ chunk_position = comp_order.index(cur_model_chunk)
+ if chunk_position > 0:
+ prev_model_chunk = comp_order[chunk_position - 1]
+ prev_task_batch = prev_model_chunk + '-' + cur_mb
+ comm_time = comm_time_matrix[chunk_stage_map[prev_model_chunk]][chunk_stage_map[cur_model_chunk]]
+ end_time_dependent_task_batch = (task_start_time[prev_task_batch]
+ + model_chunk_times[prev_task_batch]
+ + comm_time)
+ completed_flag = task_start_time[prev_task_same_stage] > 0 and task_start_time[prev_task_batch] > 0
+ else:
+ end_time_dependent_task_batch = 0.1
+ completed_flag = task_start_time[prev_task_same_stage] > 0
+
+ return end_time_prev_task_stage, end_time_dependent_task_batch, completed_flag
+
+ # 更新计算时间
+ begin_up = [1] * num_stages
+ remaining = [num_mb * 2 - begin_up[p] for p in range(num_stages)]
+ remaining_flag = True
+ count = 0
+ while remaining_flag:
+ ids_old = []
+ ids_new = []
+ for s in range(num_stages):
+ ids_old.append(remaining[s])
+ if remaining[s]:
+ microbatch_idx = len(stage_list[0]) - remaining[s]
+ (end_time_prev_task_same_stage,
+ end_time_dependent_task_same_microbatch,
+ job_flag) = get_prev_task_time(start_time, stage_list, s, microbatch_idx, chunk_stage_map,
+ fwd_bwd_comp_order, fwd_bwd_durations, comm_matrix)
+
+ if job_flag:
+ start_time[stage_list[s][microbatch_idx]] = max(end_time_prev_task_same_stage,
+ end_time_dependent_task_same_microbatch)
+ remaining[s] = remaining[s] - 1
+
+ ids_new.append(remaining[s])
+
+ if all(item == 0 for item in remaining):
+ remaining_flag = False
+
+ if ids_old == ids_new:
+ count += 1
+ if count == 3:
+ start_time[f'B_0-{num_mb - 1}'] = 1e7
+ break
+
+ e2e_time = start_time[f'B_0-{num_mb - 1}'] + bwd_durations[0, -1]
+ stage_start_time = [[start_time[job_name] for job_name in stage_list[s]] for s in range(num_stages)]
+
+ return e2e_time, stage_start_time
+
+
+def get_schedule_1f1b(paras):
+ # generate 1f1b schedule list
+ num_stages = paras.num_stages
+ num_microbatches = paras.num_microbatches
+ computation_placement = list(range(num_stages)) + list(range(num_stages - 1, -1, -1))
+
+ # Fwd Bwd执行顺序
+ fwd_bwd_order = ([f'F_{i}' for i in range(num_stages)] +
+ [f'B_{i}' for i in range(num_stages - 1, -1, -1)])
+
+ # 根据1F1B策略生成每个stage上的调度顺序
+ def get_stage_list(fwd_seq, bwd_seq, num_advanced):
+ stage_order = []
+ n = len(fwd_seq)
+ for idx in range(n):
+ if idx < num_advanced:
+ stage_order.append(fwd_seq[idx])
+ else:
+ stage_order.append(fwd_seq[idx])
+ stage_order.append(bwd_seq[idx - num_advanced])
+ if idx == n - 1:
+ for i in range(num_advanced):
+ stage_order.append(bwd_seq[i - num_advanced])
+
+ return stage_order
+
+ def get_stage_schedule(all_jobs_array, comp_placement, num_stages):
+ stage_list = []
+ for s in range(num_stages):
+ stage_chunk_id = [index for index, element in enumerate(comp_placement) if element == s]
+ warmup = num_stages - s
+ stage_s_list = get_stage_list(all_jobs_array[stage_chunk_id[0]],
+ all_jobs_array[stage_chunk_id[1]],
+ warmup - 1)
+ stage_list.append(stage_s_list)
+ return stage_list
+
+ all_jobs = np.array([[s + f'-{i}' for i in range(num_microbatches)] for s in fwd_bwd_order])
+ stage_list = get_stage_schedule(all_jobs, computation_placement, num_stages)
+ stage_schedule_dict = dict()
+ for s in range(paras.num_stages):
+ stage_s_list = []
+ for element in stage_list[s]:
+ item1, item2 = element.split("-")
+ stage_s_list.append(item1[0] + item2)
+ stage_schedule_dict[f"stage{s}"] = stage_s_list
+ return stage_schedule_dict
+
+
+def get_schedule_eager1f1b(paras, num_forwards, layers_placement):
+ # generate 1f1b schedule list
+ num_stages = paras.num_stages
+ num_microbatches = paras.num_microbatches
+ # 将原始模型切分为多个model chunk,chunk在PP stage上的放置顺序
+ chunk_placement = list(range(num_stages)) + list(range(num_stages - 1, -1, -1))
+
+ # Fwd Bwd执行顺序
+ fwd_bwd_comp_order = ([f'F_{i}' for i in range(num_stages)] +
+ [f'B_{i}' for i in range(num_stages - 1, -1, -1)])
+
+ # 根据1F1B策略生成每个stage上的调度顺序
+ def get_stage_list(fwd_seq, bwd_seq, num_advanced):
+ stage_order = []
+ n = len(fwd_seq)
+ for idx in range(n):
+ if idx < num_advanced:
+ stage_order.append(fwd_seq[idx])
+ else:
+ stage_order.append(fwd_seq[idx])
+ stage_order.append(bwd_seq[idx - num_advanced])
+ if idx == n - 1:
+ for i in range(num_advanced):
+ stage_order.append(bwd_seq[i - num_advanced])
+
+ return stage_order
+
+ def get_stage_schedule(all_jobs_array, comp_placement, num_advanced, paras, layers_placement):
+ stage_list = []
+ activations_num = int(paras.num_layers // paras.num_stages) * (num_advanced + paras.num_stages)
+ nums_under_memory = [int(activations_num // layers_placement[i]) for i in range(paras.num_stages)]
+ warmups = [min(nums_under_memory[s] - s - 1,
+ 2 * paras.num_stages - 2 * s - 2) for s in range(paras.num_stages)]
+ for i in range(paras.num_stages - 1):
+ warmups[i + 1] = min(warmups[i] - 1, warmups[i + 1])
+ warmups[i + 1] = max(warmups[i + 1], 0)
+
+ for s in range(paras.num_stages):
+ stage_chunk_id = [index for index, element in enumerate(comp_placement) if element == s]
+ num = sum(np.array(paras.bwd_durations[s + 1:])
+ + np.array(paras.fwd_durations[s + 1:])) // np.array(paras.fwd_durations[s])
+ stage_s_list = get_stage_list(all_jobs_array[stage_chunk_id[0]],
+ all_jobs_array[stage_chunk_id[1]],
+ warmups[s])
+ stage_list.append(stage_s_list)
+ return stage_list
+
+ all_jobs = np.array([[s + f'-{i}' for i in range(num_microbatches)] for s in fwd_bwd_comp_order])
+ stage_list = get_stage_schedule(all_jobs, chunk_placement, num_forwards, paras, layers_placement)
+
+ # 转换为dictionary
+ stage_schedule_dict = dict()
+ for s in range(paras.num_stages):
+ stage_s_list = []
+ for element in stage_list[s]:
+ item1, item2 = element.split("-")
+ stage_s_list.append(item1[0] + item2)
+ stage_schedule_dict[f"stage{s}"] = stage_s_list
+
+ return stage_schedule_dict
+
+
+def schedule_layers(paras, num_mb_for_remaining_memory):
+ # 调整层分布,对比层分布改变后,1F1B建模时间
+ stage_layers = int(paras.num_layers // paras.num_stages)
+ if paras.num_stages > 2:
+ fwd_time_per_layer = sum(paras.fwd_durations[1:-1]) / (paras.num_stages - 2) / stage_layers
+ bwd_time_per_layer = sum(paras.bwd_durations[1:-1]) / (paras.num_stages - 2) / stage_layers
+ else:
+ fwd_time_per_layer = paras.fwd_durations[0] / stage_layers
+ bwd_time_per_layer = paras.bwd_durations[0] / stage_layers
+
+ # 1f1b as baseline
+ e2e_time = np.ones([2, paras.num_stages]) * 1e9
+ paras_all = []
+ layers_placement = []
+ schedule_1f1b = get_schedule_1f1b(paras)
+ e2e_time[0, 0], stage_start_time1 = time_model_nfmb(paras, schedule_1f1b)
+ paras_all.append(paras)
+ layers_p1 = [stage_layers] * paras.num_stages
+ layers_placement.append(layers_p1)
+ # 调度序列
+ schedule_eager_1f1b = get_schedule_eager1f1b(paras, num_mb_for_remaining_memory, layers_p1)
+ e2e_time[1, 0], stage_start_time2 = time_model_nfmb(paras, schedule_eager_1f1b)
+
+ if stage_layers >= 2:
+ for i in range(paras.num_stages - 1):
+ fwd_new = np.array(paras.fwd_durations)
+ fwd_new[i] += fwd_time_per_layer
+ fwd_new[-1] -= fwd_time_per_layer
+ bwd_new = np.array(paras.bwd_durations)
+ bwd_new[i] += bwd_time_per_layer
+ bwd_new[-1] -= bwd_time_per_layer
+ paras1 = PipelineParallelParas(paras.num_stages,
+ fwd_new.tolist(),
+ bwd_new.tolist(),
+ paras.num_microbatches,
+ paras.comm_matrix,
+ paras.num_layers)
+ e2e_time[0, i + 1], stage_start_time1 = time_model_nfmb(paras1, schedule_1f1b)
+ paras_all.append(paras1)
+ layers_p1 = [stage_layers] * paras.num_stages
+ layers_p1[i] += 1
+ layers_p1[-1] -= 1
+ layers_placement.append(layers_p1)
+ schedule_eager_1f1b = get_schedule_eager1f1b(paras1, num_mb_for_remaining_memory, layers_p1)
+ e2e_time[1, i + 1], stage_start_time2 = time_model_nfmb(paras1, schedule_eager_1f1b)
+
+ optimal_paras = paras_all[e2e_time[1, :].argmin()]
+ optimal_layer = layers_placement[e2e_time[1, :].argmin()]
+ schedule_scheme = get_schedule_eager1f1b(optimal_paras, num_mb_for_remaining_memory, optimal_layer)
+
+ return schedule_scheme, optimal_layer, e2e_time[1, :].min()
+
+
+def broadcast_enable_schedule_in_ranks(src_rank, policy):
+ enable_schedule = [False]
+ if torch.distributed.get_rank() == src_rank:
+ enable_schedule = [policy]
+ tmp_enable_schedule = torch.cuda.BoolTensor(enable_schedule)
+ torch.distributed.broadcast(tmp_enable_schedule, src=src_rank)
+ return tmp_enable_schedule.item()
+
+
+def broadcast_scheduler_in_ranks(src_rank, policy):
+ args = get_args()
+ policy_str = json.dumps(policy)
+ byte_tensor = torch.cuda.ByteTensor(list(policy_str.encode()))
+ torch.distributed.broadcast(byte_tensor, src_rank)
+ if torch.distributed.get_rank() != 0:
+ received_byte_tensor = torch.cuda.ByteTensor([0] * len(byte_tensor))
+ else:
+ received_byte_tensor = byte_tensor.clone()
+ torch.distributed.broadcast(received_byte_tensor, src_rank)
+ received_policy_str = ''.join([chr(byte) for byte in received_byte_tensor.tolist()])
+ received_policy_data = json.loads(received_policy_str)
+ args.pp_schedule_list = received_policy_data
+ return received_policy_data
+
+
+def broadcast_layer_in_ranks(src_rank, policy):
+ args = get_args()
+ num_layer_list = args.pipeline_model_parallel_size * [0]
+ if torch.distributed.get_rank() == 0:
+ num_layer_list = policy
+ tmp_layer_list = torch.cuda.IntTensor(num_layer_list)
+ torch.distributed.broadcast(tmp_layer_list, src=src_rank)
+ args.num_layer_list = tmp_layer_list.tolist()
+ return tmp_layer_list.tolist()
+
+
+def all_gather_time(args, gather_time):
+ recv_gather_time_list = []
+ world_size = torch.distributed.get_world_size()
+ gather_time = torch.cuda.FloatTensor([gather_time])
+ gathered_tensors = [torch.zeros_like(gather_time) for _ in range(world_size)]
+ torch.distributed.all_gather(gathered_tensors, gather_time)
+ for rank, tensor in enumerate(gathered_tensors):
+ pipeline_stage_rank = get_pipeline_stage_rank(world_size, args.pipeline_model_parallel_size, rank)
+ recv_gather_time_list.append((pipeline_stage_rank, tensor.item()))
+ return recv_gather_time_list
+
+
+def average_time_by_rank(time_list):
+ time_dict = {}
+ for item in time_list:
+ if item[0] not in time_dict:
+ time_dict[item[0]] = item[1]
+ else:
+ time_dict[item[0]] += item[1]
+ time_dict[item[0]] /= 2
+ return time_dict
+
+
+def get_pipeline_stage_rank(world_size, num_stages, global_rank):
+ assert world_size % num_stages == 0, "World size must be divisible by the number of stages"
+ assert global_rank < world_size, "Global rank must be less than world size"
+
+ stage_size = world_size // num_stages
+ return global_rank // stage_size
+
+
+def solve_pipelineschedule(args, data_parallel_size, num_forwards_first_stage, forward_time_dict, backward_time_dict):
+ pipeline_stages = args.pipeline_model_parallel_size
+ forward_time_each_stage = [forward_time_dict[rank] for rank in forward_time_dict]
+ backward_time_each_stage = [backward_time_dict[rank] for rank in backward_time_dict]
+ comm_matrix = [[0.05] * pipeline_stages for _ in range(pipeline_stages)]
+ num_micro_batches = args.global_batch_size // data_parallel_size // args.micro_batch_size
+ num_layers = args.num_layers
+
+ pp_paras = PipelineParallelParas(pipeline_stages,
+ forward_time_each_stage,
+ backward_time_each_stage,
+ num_micro_batches,
+ comm_matrix,
+ num_layers)
+ # FB schedule
+ start_time = time.time()
+ schedule_list, optimal_layers, schedule_time = schedule_layers(pp_paras, num_forwards_first_stage)
+ end_time = time.time()
+ return schedule_list, schedule_time, optimal_layers
diff --git a/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/schedules.py b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/schedules.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c009216e7ec0966d85a9347beafbb8b2da36201
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/schedules.py
@@ -0,0 +1,274 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+from functools import wraps
+import contextlib
+from typing import Callable, Iterator, List, Optional, Union
+import torch
+from torch.autograd.variable import Variable
+from megatron.core import parallel_state
+from megatron.core.enums import ModelType
+from megatron.core.pipeline_parallel import p2p_communication
+from megatron.core.utils import get_model_config, get_model_type
+from megatron.training import get_args
+from megatron.core.pipeline_parallel.schedules import forward_step, backward_step, deallocate_output_tensor, check_first_val_step
+from mindspeed.core.performance.auto_pipeline_perf.autopipeline_perf import profile_context
+import mindspeed.core.training as training
+
+
+def get_forward_backward_func_decorator(get_forward_backward_func):
+ @wraps(get_forward_backward_func)
+ def wrapper(*args, **kwargs):
+ argument = get_args()
+ pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
+ if pipeline_model_parallel_size > 1 and argument.automated_pipeline_perf and argument.optimized_mbs_list:
+ forward_backward_func = optimized_forward_backward_pipelining
+ else:
+ forward_backward_func = get_forward_backward_func(*args, **kwargs)
+ return forward_backward_func
+ return wrapper
+
+
+def forward_step_decorator(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ argument = get_args()
+ if argument.automated_pipeline_perf and not (argument.optimized_mbs_list or argument.pp_schedule_list):
+ torch.cuda.synchronize()
+ start_time = time.time()
+ output_tensor = fn(*args, **kwargs)
+ torch.cuda.synchronize()
+ profile_context["fwd_time"].append((time.time() - start_time) * 1000)
+ else:
+ output_tensor = fn(*args, **kwargs)
+ return output_tensor
+
+ return wrapper
+
+
+def backward_step_decorator(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ argument = get_args()
+ if argument.automated_pipeline_perf and not (argument.optimized_mbs_list or argument.pp_schedule_list):
+ torch.cuda.synchronize()
+ start_time = time.time()
+ input_tensor_grad = fn(*args, **kwargs)
+ torch.cuda.synchronize()
+ profile_context["bwd_time"].append((time.time() - start_time) * 1000)
+ else:
+ input_tensor_grad = fn(*args, **kwargs)
+ return input_tensor_grad
+ return wrapper
+
+
+def get_tensor_shapes():
+ args = get_args()
+ tensor_shapes = []
+ mbs = args.optimized_mbs_list
+ for m in mbs:
+ tensor_shapes.append((args.seq_length // parallel_state.get_context_parallel_world_size() // parallel_state.get_tensor_model_parallel_world_size(), m, args.hidden_size))
+ return tensor_shapes
+
+
+def optimized_forward_backward_pipelining(
+ *,
+ forward_step_func,
+ data_iterator: Union[Iterator, List[Iterator]],
+ model: Union[torch.nn.Module, List[torch.nn.Module]],
+ num_microbatches: int,
+ seq_length: int,
+ micro_batch_size: int,
+ decoder_seq_length: int = None,
+ forward_only: bool = False,
+ collect_non_loss_data: bool = False,
+ first_val_step: bool = None,
+):
+ """Run non-interleaved 1F1B schedule, with reduced pipeline bubble.
+ Returns dictionary with losses if the last stage, empty dict otherwise.
+ """
+ if isinstance(model, list):
+ model = model[0]
+ if isinstance(data_iterator, list):
+ data_iterator = data_iterator[0]
+ argument = get_args()
+ config = get_model_config(model)
+ model_type = get_model_type(model)
+ tensor_shapes = get_tensor_shapes()
+ cnt_fwd, cnt_bwd = 0, 0
+ argument.mbs_idx = cnt_fwd
+ argument.optimized_mbs_mode = True
+ num_microbatches = len(argument.optimized_mbs_list)
+ if config.overlap_p2p_comm:
+ raise ValueError(
+ "Optimized pipeline parallelism does not support overlapping p2p communication"
+ )
+
+ # Disable async grad reductions
+ no_sync_func = config.no_sync_func
+ if no_sync_func is None:
+ no_sync_func = contextlib.nullcontext
+ no_sync_context = None
+
+ def disable_grad_sync():
+ """Disable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is None:
+ no_sync_context = no_sync_func()
+ no_sync_context.__enter__()
+
+ def enable_grad_sync():
+ """Enable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is not None:
+ no_sync_context.__exit__(None, None, None)
+ no_sync_context = None
+
+ disable_grad_sync()
+
+ # Compute number of warmup microbatches.
+ num_warmup_microbatches = \
+ (parallel_state.get_pipeline_model_parallel_world_size() -
+ parallel_state.get_pipeline_model_parallel_rank() - 1)
+ num_warmup_microbatches = min(
+ num_warmup_microbatches,
+ num_microbatches)
+ num_microbatches_remaining = \
+ num_microbatches - num_warmup_microbatches
+
+ input_tensors = []
+ output_tensors = []
+ forward_data_store = []
+ rank = parallel_state.get_pipeline_model_parallel_rank()
+
+ # Run warmup forward passes.
+ for i in range(num_warmup_microbatches):
+ input_tensor = p2p_communication.recv_forward(config=config,
+ tensor_shape=tensor_shapes[cnt_fwd])
+ argument.micro_batch_size = argument.optimized_mbs_list[cnt_fwd]
+ output_tensor = forward_step(
+ forward_step_func,
+ data_iterator,
+ model,
+ num_microbatches,
+ input_tensor,
+ forward_data_store,
+ config,
+ collect_non_loss_data,
+ None,
+ check_first_val_step(first_val_step, forward_only, i == 0),
+ )
+ p2p_communication.send_forward(output_tensor, config=config)
+ cnt_fwd += 1
+ input_tensors.append(input_tensor)
+ output_tensors.append(output_tensor)
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ # Before running 1F1B, need to receive first forward tensor.
+ # If all microbatches are run in warmup / cooldown phase, then no need to
+ # receive this tensor here.
+ if num_microbatches_remaining > 0:
+ input_tensor = p2p_communication.recv_forward(config=config,
+ tensor_shape=tensor_shapes[cnt_fwd])
+
+ # Run 1F1B in steady state.
+ for i in range(num_microbatches_remaining):
+ last_iteration = (i == (num_microbatches_remaining - 1))
+ argument.micro_batch_size = argument.optimized_mbs_list[cnt_fwd]
+ output_tensor = forward_step(
+ forward_step_func,
+ data_iterator,
+ model,
+ num_microbatches,
+ input_tensor,
+ forward_data_store,
+ config,
+ collect_non_loss_data,
+ None,
+ check_first_val_step(
+ first_val_step, forward_only, (i == 0) and (num_warmup_microbatches == 0)
+ ),
+ )
+ if forward_only:
+ p2p_communication.send_forward(output_tensor, config=config)
+ if not last_iteration:
+ input_tensor = p2p_communication.recv_forward(tensor_shapes=tensor_shapes[cnt_fwd], config=config)
+ else:
+ output_tensor_grad = \
+ p2p_communication.send_forward_recv_backward(output_tensor,
+ tensor_shape=tensor_shapes[cnt_bwd], config=config)
+
+ cnt_fwd += 1
+ # Add input_tensor and output_tensor to end of list, then pop from the
+ # start of the list for backward pass.
+ input_tensors.append(input_tensor)
+ output_tensors.append(output_tensor)
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ if forward_only:
+ if not last_iteration:
+ input_tensor = p2p_communication.recv_forward(config=config,
+ tensor_shape=tensor_shapes[cnt_fwd])
+ else:
+ input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
+ if num_warmup_microbatches == 0 and last_iteration:
+ if config.grad_sync_func is None or rank == 0:
+ enable_grad_sync()
+
+ input_tensor_grad = \
+ backward_step(input_tensor, output_tensor,
+ output_tensor_grad, model_type, config)
+
+ if last_iteration:
+ input_tensor = None
+ p2p_communication.send_backward(input_tensor_grad, config=config)
+ else:
+ input_tensor = \
+ p2p_communication.send_backward_recv_forward(
+ input_tensor_grad, tensor_shape=tensor_shapes[cnt_fwd], config=config)
+ cnt_bwd += 1
+
+ # Run cooldown backward passes.
+ if not forward_only:
+ for _ in range(num_warmup_microbatches):
+ if i == num_warmup_microbatches - 1:
+ if config.grad_sync_func is None or rank == 0:
+ enable_grad_sync()
+
+ input_tensor = input_tensors.pop(0)
+ output_tensor = output_tensors.pop(0)
+ output_tensor_grad = p2p_communication.recv_backward(
+ tensor_shape=tensor_shapes[cnt_bwd], config=config)
+
+ input_tensor_grad = \
+ backward_step(input_tensor, output_tensor,
+ output_tensor_grad, model_type, config)
+ p2p_communication.send_backward(input_tensor_grad, config)
+ cnt_bwd += 1
+
+ # Launch any remaining grad reductions.
+ if no_sync_context is not None:
+ enable_grad_sync()
+ if config.grad_sync_func is not None:
+ config.grad_sync_func(model.parameters())
+
+ if config.finalize_model_grads_func is not None and not forward_only:
+ # Finalize model grads (perform full grad all-reduce / reduce-scatter for
+ # data parallelism, layernorm all-reduce for sequence parallelism, and
+ # embedding all-reduce for pipeline parallelism).
+ config.finalize_model_grads_func([model])
+ argument.optimized_mbs_mode = False
+ argument.micro_batch_size = training.ORIGIN_MBS
+ return forward_data_store
diff --git a/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/transformer.py b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..56a2f94e8ebfdead0edf18651a80c7bfb96a7296
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/performance/auto_pipeline_perf/transformer.py
@@ -0,0 +1,17 @@
+from functools import wraps
+from megatron.training import get_args
+from mindspeed.model.transformer import generate_attention_mask
+import mindspeed.model.transformer
+
+
+def get_attention_mask_wrapper(get_attention_mask):
+ @wraps(get_attention_mask)
+ def wrapper(*args, **kwargs):
+ argument = get_args()
+ automated_pipeline_perf = argument.automated_pipeline_perf and argument.optimized_mbs_list
+ if automated_pipeline_perf:
+ generate_attention_mask()
+ else:
+ get_attention_mask(*args, **kwargs)
+ return mindspeed.model.transformer._GLOBAL_ATTN_MASK
+ return wrapper
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/pipeline_parallel/__init__.py b/model/train/yoco_moe/mindspeed/core/pipeline_parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/pipeline_parallel/flexible_schedules.py b/model/train/yoco_moe/mindspeed/core/pipeline_parallel/flexible_schedules.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a6b901e9de874d6c862fd4eb96b9b796bf951b3
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/pipeline_parallel/flexible_schedules.py
@@ -0,0 +1,1771 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+import contextlib
+from functools import wraps
+from typing import Callable, Iterator, List, Optional, Union, Tuple
+import torch
+import torch.distributed as dist
+from megatron import core
+from megatron.core import ModelParallelConfig, parallel_state
+from megatron.core.enums import ModelType
+from megatron.core.utils import get_model_config, get_model_type
+from megatron.core.pipeline_parallel.schedules import (
+ get_tensor_shapes,
+ forward_step,
+ backward_step,
+ deallocate_output_tensor,
+ check_first_val_step,
+ clear_embedding_activation_buffer,
+ finish_embedding_wgrad_compute
+)
+from megatron.core.pipeline_parallel import p2p_communication
+from megatron.core.pipeline_parallel.p2p_communication import (
+ Shape,
+ _communicate_shapes,
+ _communicate,
+ _batched_p2p_ops,
+ _p2p_ops
+)
+from megatron.core.parallel_state import get_pipeline_model_parallel_group
+from megatron.training import get_args
+from mindspeed.core.parallel_state import get_pipeline_parallel_group_for_new_stream
+from mindspeed.core.weight_grad_store import WeightGradStore
+
+
+forward_comm_stream = None
+backward_comm_stream = None
+default_stream = None
+scheduler_plan = None
+
+
+def recv_forward(tensor_shapes, config, group):
+ input_tensors = []
+ wait_handles = []
+ for tensor_shape in tensor_shapes:
+ if tensor_shape is None or core.parallel_state.is_pipeline_first_stage():
+ input_tensor = None
+ wait_handle = None
+ else:
+ input_tensor, _, wait_handle = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=None,
+ recv_prev=True,
+ recv_next=False,
+ tensor_shape=tensor_shape,
+ config=config,
+ group=group,
+ wait_on_reqs=False
+ )
+ input_tensors.append(input_tensor)
+ wait_handles.append(wait_handle)
+ return input_tensors, wait_handles
+
+
+def recv_backward(tensor_shapes, config, group):
+ output_tensor_grads = []
+ wait_handlers = []
+ for tensor_shape in tensor_shapes:
+ if tensor_shape is None or core.parallel_state.is_pipeline_last_stage():
+ output_tensor_grad = None
+ wait_handle = None
+ else:
+ _, output_tensor_grad, wait_handle = _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=None,
+ recv_prev=False,
+ recv_next=True,
+ tensor_shape=tensor_shape,
+ config=config,
+ group=group,
+ wait_on_reqs=False
+ )
+ output_tensor_grads.append(output_tensor_grad)
+ wait_handlers.append(wait_handle)
+ return output_tensor_grads, wait_handlers
+
+
+def send_forward(output_tensors, tensor_shapes, config, group):
+ if not isinstance(output_tensors, list):
+ output_tensors = [output_tensors]
+ for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
+ if tensor_shape is None or core.parallel_state.is_pipeline_last_stage():
+ continue
+
+ _communicate(
+ tensor_send_next=output_tensor,
+ tensor_send_prev=None,
+ recv_prev=False,
+ recv_next=False,
+ tensor_shape=None,
+ config=config,
+ group=group,
+ wait_on_reqs=False
+ )
+
+
+def send_backward(input_tensor_grads, tensor_shapes, config, group):
+ if not isinstance(input_tensor_grads, list):
+ input_tensor_grads = [input_tensor_grads]
+
+ for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
+ if tensor_shape is None or core.parallel_state.is_pipeline_first_stage():
+ continue
+
+ _communicate(
+ tensor_send_next=None,
+ tensor_send_prev=input_tensor_grad,
+ recv_prev=False,
+ recv_next=False,
+ tensor_shape=None,
+ config=config,
+ group=group,
+ wait_on_reqs=False
+ )
+
+
+def _communicate(
+ *,
+ tensor_send_next: Optional[torch.Tensor],
+ tensor_send_prev: Optional[torch.Tensor],
+ recv_prev: bool,
+ recv_next: bool,
+ tensor_shape: Shape,
+ config: ModelParallelConfig,
+ wait_on_reqs: bool = True,
+ group: dist.ProcessGroup = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Communicate tensors between stages. Used as helper method in other
+ communication methods that are used in megatron/schedules.py.
+
+ Args:
+ tensor_send_next (torch.Tensor, optional):
+ Tensor to send to next rank (no tensor sent if None)
+
+ tensor_send_prev (torch.Tensor, optional):
+ Tensor to send to prev rank (no tensor sent if None)
+
+ recv_prev (boolean, required):
+ whether tensor should be received from previous rank.
+
+ recv_next (boolean, required):
+ whether tensor should be received from next rank.
+
+ tensor_shape (List[int] or torch.Size, required):
+ shape of tensor to receive (this method assumes that all
+ tensors sent and received in a single function call are
+ the same shape).
+
+ wait_on_reqs (boolean, optional, default=False):
+ For non-batched p2p communication, wait on each request
+ before returning.
+
+ Returns:
+ tuple containing
+
+ - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
+ - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.
+
+ """
+ # Create placeholder tensors for receive in forward and backward directions
+ # if needed.
+ tensor_recv_prev = None
+ tensor_recv_next = None
+
+ if not config.variable_seq_lengths:
+ recv_prev_shape = tensor_shape
+ recv_next_shape = tensor_shape
+ else:
+ recv_prev_shape, recv_next_shape = _communicate_shapes(
+ tensor_send_next, tensor_send_prev, recv_prev, recv_next, config
+ )
+
+ if recv_prev:
+ if config.pipeline_dtype is None:
+ raise RuntimeError("pipeline_dtype must be provided if recv_prev is True")
+ if tensor_shape is None:
+ raise RuntimeError(
+ "tensor_shape must be specified if recv_prev is True. "
+ "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
+ )
+ tensor_recv_prev = torch.empty(
+ recv_prev_shape,
+ requires_grad=True,
+ device=torch.cuda.current_device(),
+ dtype=config.pipeline_dtype,
+ )
+ if recv_next:
+ if config.pipeline_dtype is None:
+ raise RuntimeError("dtype must be provided if recv_next is True")
+ if tensor_shape is None:
+ raise RuntimeError(
+ "tensor_shape must be specified if recv_next is True. "
+ "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
+ )
+ tensor_recv_next = torch.empty(
+ recv_next_shape,
+ requires_grad=True,
+ device=torch.cuda.current_device(),
+ dtype=config.pipeline_dtype,
+ )
+
+ # Send tensors in both the forward and backward directions as appropriate.
+ if config.use_ring_exchange_p2p:
+
+ def _ring_exchange_wrapper(**kwargs):
+ torch.distributed.ring_exchange(**kwargs)
+ return []
+
+ p2p_func = _ring_exchange_wrapper
+ elif config.batch_p2p_comm:
+ if not wait_on_reqs:
+ raise AssertionError('wait_on_reqs must be True when use batch_p2p_comm')
+ p2p_func = _batched_p2p_ops
+ else:
+ p2p_func = _p2p_ops
+
+ reqs = p2p_func(
+ tensor_send_prev=tensor_send_prev,
+ tensor_recv_prev=tensor_recv_prev,
+ tensor_send_next=tensor_send_next,
+ tensor_recv_next=tensor_recv_next,
+ group=group
+ )
+
+ if wait_on_reqs and len(reqs) > 0:
+ for req in reqs:
+ req.wait()
+ reqs = None
+
+ if config.batch_p2p_comm and config.batch_p2p_sync:
+ # To protect against race condition when using batch_isend_irecv().
+ # User should assert that we have a modern enough PyTorch to not need this
+ torch.cuda.synchronize()
+
+ return tensor_recv_prev, tensor_recv_next, reqs
+
+
+def generate_1f1b_scheduler_plan(pp_size, num_micro_batch):
+ scheduler_plan_all_stages = {}
+
+ num_warmup_microbatch = [pp_size - r - 1 for r in range(pp_size)]
+ num_cooldown_microbatch = num_warmup_microbatch
+ num_stable_microbatch = [(num_micro_batch * 2 - num_warmup_microbatch[r] - num_cooldown_microbatch[r]) // 2
+ for r in range(pp_size)]
+
+ forward_count = [1 for _ in range(pp_size)]
+ backward_count = [1 for _ in range(pp_size)]
+
+ # warmup
+ for pp_rank in range(pp_size):
+ key = 'stage{}'.format(pp_rank)
+ scheduler_plan_all_stages[key] = []
+ for i in range(num_warmup_microbatch[pp_rank]):
+ value = 'F{}'.format(forward_count[pp_rank])
+ scheduler_plan_all_stages[key].append(value)
+ forward_count[pp_rank] += 1
+
+ # stable
+ for pp_rank in range(pp_size):
+ key = 'stage{}'.format(pp_rank)
+ for i in range(num_stable_microbatch[pp_rank]):
+ value = 'F{}'.format(forward_count[pp_rank])
+ scheduler_plan_all_stages[key].append(value)
+ forward_count[pp_rank] += 1
+
+ value = 'B{}'.format(backward_count[pp_rank])
+ scheduler_plan_all_stages[key].append(value)
+ backward_count[pp_rank] += 1
+
+ # cooldown
+ for pp_rank in range(pp_size):
+ key = 'stage{}'.format(pp_rank)
+ for i in range(num_cooldown_microbatch[pp_rank]):
+ value = 'B{}'.format(backward_count[pp_rank])
+ scheduler_plan_all_stages[key].append(value)
+ backward_count[pp_rank] += 1
+
+ return scheduler_plan_all_stages
+
+
+def forward_backward_pipelining_without_interleaving(
+ *,
+ forward_step_func,
+ data_iterator: Union[Iterator, List[Iterator]],
+ model: Union[torch.nn.Module, List[torch.nn.Module]],
+ num_microbatches: int,
+ seq_length: int,
+ micro_batch_size: int,
+ decoder_seq_length: int = None,
+ forward_only: bool = False,
+ collect_non_loss_data: bool = False,
+ first_val_step: bool = None
+):
+ """Run non-interleaved 1F1B schedule, with communication between pipeline
+ stages.
+
+ Returns dictionary with losses if the last stage, empty dict otherwise.
+
+ """
+
+ if isinstance(model, list):
+ if not len(model) == 1:
+ raise AssertionError("non-interleaved pipeline parallelism does not support model chunking")
+ model = model[0]
+ if isinstance(data_iterator, list):
+ if not len(data_iterator) == 1:
+ raise AssertionError("non-pipeline-parallel schedule does not support model chunking")
+ data_iterator = data_iterator[0]
+
+ config = get_model_config(model)
+ if config.timers is not None:
+ config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
+
+ # Disable async grad reductions
+ no_sync_func = config.no_sync_func
+ if no_sync_func is None:
+ no_sync_func = contextlib.nullcontext
+ no_sync_context = None
+
+ def disable_grad_sync():
+ """Disable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is None:
+ no_sync_context = no_sync_func()
+ no_sync_context.__enter__()
+
+ def enable_grad_sync():
+ """Enable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is not None:
+ no_sync_context.__exit__(None, None, None)
+ no_sync_context = None
+
+ disable_grad_sync()
+
+ # Compute number of warmup microbatches.
+ num_warmup_microbatches = (
+ parallel_state.get_pipeline_model_parallel_world_size()
+ - parallel_state.get_pipeline_model_parallel_rank()
+ - 1
+ )
+ num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
+
+ # Checkpoint the activations of partial Transformer layers in a number of micro-batches
+ # within the maximum outstanding micro-batch backpropagations.
+ # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
+ # checkpoint partial Transformer layers (or skip checkpointing) and
+ # the rest of micro-batches within a window of micro-batches checkpoint
+ # all Transformer layers. The window of micro-batches is set by the maximum
+ # outstanding backpropagations and becomes smaller at later pipeline stages.
+ # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
+ max_outstanding_backprops = None
+ if config.num_microbatches_with_partial_activation_checkpoints is not None:
+ max_outstanding_backprops = num_warmup_microbatches + 1
+
+ model_type = get_model_type(model)
+
+ rank = parallel_state.get_pipeline_model_parallel_rank()
+ recv_tensor_shapes = get_tensor_shapes(
+ rank=rank - 1,
+ model_type=model_type,
+ seq_length=seq_length,
+ micro_batch_size=micro_batch_size,
+ decoder_seq_length=decoder_seq_length,
+ config=config,
+ )
+ send_tensor_shapes = get_tensor_shapes(
+ rank=rank,
+ model_type=model_type,
+ seq_length=seq_length,
+ micro_batch_size=micro_batch_size,
+ decoder_seq_length=decoder_seq_length,
+ config=config,
+ )
+
+ # Input, output tensors only need to be saved when doing backward passes
+ input_tensors = None
+ output_tensors = None
+ if not forward_only:
+ input_tensors = []
+ output_tensors = []
+ forward_data_store = []
+
+ def wait_helper(wait_handlers):
+ for reqs in wait_handlers:
+ if reqs is not None:
+ for req in reqs:
+ req.wait()
+
+ global forward_comm_stream
+ if forward_comm_stream is None:
+ forward_comm_stream = torch.cuda.Stream()
+
+ global backward_comm_stream
+ if backward_comm_stream is None:
+ backward_comm_stream = torch.cuda.Stream()
+
+ global default_stream
+ if default_stream is None:
+ default_stream = torch.cuda.default_stream()
+
+ global scheduler_plan
+ arguments = get_args()
+ key = 'stage{}'.format(parallel_state.get_pipeline_model_parallel_rank())
+ if scheduler_plan is None and arguments.pp_schedule_list:
+ scheduler_plan = arguments.pp_schedule_list.get(key)
+ elif scheduler_plan is None and arguments.pp_schedule_list is None:
+ scheduler_plan = generate_1f1b_scheduler_plan(parallel_state.get_pipeline_model_parallel_world_size(),
+ num_microbatches)
+ scheduler_plan = scheduler_plan.get(key)
+
+ config.batch_p2p_comm = False
+ fwd_wait_handles, bwd_wait_handles = None, None
+ current_tag_id = -1
+ for tag in scheduler_plan:
+ current_tag_id += 1
+ if tag.startswith('F'):
+ # Decide to checkpoint all layers' activations of the current micro-batch
+ if max_outstanding_backprops is not None:
+ checkpoint_activations_microbatch = (
+ current_tag_id % max_outstanding_backprops >= config.num_microbatches_with_partial_activation_checkpoints
+ )
+ else:
+ checkpoint_activations_microbatch = None
+
+ with torch.cuda.stream(forward_comm_stream):
+ input_tensor, fwd_wait_handles = recv_forward(
+ recv_tensor_shapes, config, get_pipeline_model_parallel_group()
+ )
+
+ wait_helper(fwd_wait_handles)
+ output_tensor, _ = forward_step(
+ forward_step_func,
+ data_iterator,
+ model,
+ num_microbatches,
+ input_tensor,
+ forward_data_store,
+ config,
+ collect_non_loss_data,
+ checkpoint_activations_microbatch,
+ check_first_val_step(first_val_step, forward_only, current_tag_id == 0)
+ )
+
+ with torch.cuda.stream(forward_comm_stream):
+ forward_comm_stream.wait_stream(default_stream)
+ send_forward(
+ output_tensor,
+ send_tensor_shapes,
+ config,
+ get_pipeline_model_parallel_group()
+ )
+ for tensor in output_tensor:
+ if tensor is not None:
+ tensor.record_stream(forward_comm_stream)
+
+
+ if not forward_only:
+ input_tensors.append(input_tensor)
+ output_tensors.append(output_tensor)
+ deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
+
+ else:
+ if forward_only:
+ continue
+
+ if current_tag_id == len(scheduler_plan) - 1:
+ if config.grad_sync_func is None or rank == 0:
+ enable_grad_sync()
+
+ with torch.cuda.stream(backward_comm_stream):
+ output_tensor_grads, bwd_wait_handles = recv_backward(
+ send_tensor_shapes, config, get_pipeline_parallel_group_for_new_stream()
+ )
+
+ input_tensor = input_tensors.pop(0)
+ output_tensor = output_tensors.pop(0)
+
+ wait_helper(bwd_wait_handles)
+ input_tensor_grad = backward_step(
+ input_tensor,
+ output_tensor,
+ output_tensor_grads,
+ model_type,
+ config
+ )
+
+ with torch.cuda.stream(backward_comm_stream):
+ backward_comm_stream.wait_stream(default_stream)
+ send_backward(
+ input_tensor_grad,
+ recv_tensor_shapes,
+ config,
+ get_pipeline_parallel_group_for_new_stream()
+ )
+ for tensor in input_tensor_grad:
+ if tensor is not None:
+ tensor.record_stream(backward_comm_stream)
+
+ if not forward_only:
+ if no_sync_context is not None:
+ enable_grad_sync()
+ if config.grad_sync_func is not None:
+ config.grad_sync_func(model.parameters())
+
+ if config.timers is not None:
+ config.timers('forward-backward').stop()
+
+ if config.finalize_model_grads_func is not None and not forward_only:
+ # Finalize model grads (perform full grad all-reduce / reduce-scatter for
+ # data parallelism, layernorm all-reduce for sequence parallelism, and
+ # embedding all-reduce for pipeline parallelism).
+ config.finalize_model_grads_func([model])
+
+ return forward_data_store
+
+
+def forward_backward_pipelining_with_interleaving_nano_pipe(
+ *,
+ forward_step_func,
+ data_iterator: Union[Iterator, List[Iterator]],
+ model: Union[torch.nn.Module, List[torch.nn.Module]],
+ num_microbatches: int,
+ seq_length: int,
+ micro_batch_size: int,
+ decoder_seq_length: int = None,
+ forward_only: bool = False,
+ collect_non_loss_data: bool = False,
+ first_val_step: bool = None,
+):
+ """Run interleaved 1F1B-nanopipe schedule (model split into model chunks), with
+ communication between pipeline stages as needed.
+
+ Returns dictionary with losses if the last stage, empty dict otherwise.
+ """
+ if not isinstance(model, list):
+ raise AssertionError("interleaved pipeline parallelism expected model chunking")
+ if not all(isinstance(chunk, torch.nn.Module) for chunk in model):
+ raise AssertionError("invalid model chunking")
+ if not isinstance(data_iterator, list):
+ raise AssertionError("interleaved pipeline parallelism expected each model chunk to have a data iterator")
+ args = get_args()
+ config = get_model_config(model[0])
+ if config.overlap_p2p_comm and config.batch_p2p_comm:
+ raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")
+
+ if config.timers is not None:
+ config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
+
+ # Disable async grad reductions
+ no_sync_func = config.no_sync_func
+ if isinstance(no_sync_func, list):
+
+ def multi_no_sync():
+ stack = contextlib.ExitStack()
+ for model_chunk_no_sync_func in config.no_sync_func:
+ stack.enter_context(model_chunk_no_sync_func())
+ return stack
+
+ no_sync_func = multi_no_sync
+ if no_sync_func is None:
+ no_sync_func = contextlib.nullcontext
+ no_sync_context = None
+
+ if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
+ config.grad_sync_func = [config.grad_sync_func for _ in model]
+
+ if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
+ config.param_sync_func = [config.param_sync_func for _ in model]
+
+ def disable_grad_sync():
+ """Disable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is None:
+ no_sync_context = no_sync_func()
+ no_sync_context.__enter__()
+
+ def enable_grad_sync():
+ """Enable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is not None:
+ no_sync_context.__exit__(None, None, None)
+ no_sync_context = None
+
+ disable_grad_sync()
+
+ # Model chunk IDs with synchronized grads
+ synchronized_model_chunks = set()
+
+ input_tensors = [[] for _ in range(len(model))]
+ output_tensors = [[] for _ in range(len(model))]
+ forward_data_store = []
+ if not forward_only:
+ output_tensor_grads = [[] for _ in range(len(model))]
+
+ pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
+ pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
+
+ if num_microbatches % pipeline_parallel_size != 0:
+ msg = f'number of microbatches ({num_microbatches}) is not divisible by '
+ msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
+ msg += 'when using interleaved schedule'
+ raise RuntimeError(msg)
+
+ model_type = get_model_type(model[0])
+ if model_type == ModelType.encoder_and_decoder:
+ raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")
+
+ if decoder_seq_length is not None and decoder_seq_length != seq_length:
+ raise RuntimeError(
+ "Interleaving is not supported with a different decoder sequence length."
+ )
+
+ tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
+ tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
+ if config.sequence_parallel:
+ tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
+ tensor_shape[0] = tensor_shape[0] // args.tp_x
+ tensor_shape[-1] = tensor_shape[-1] // args.tp_y
+ # Compute number of warmup and remaining microbatches.
+ num_model_chunks = len(model)
+ total_num_microbatches = num_microbatches * num_model_chunks
+ all_warmup_microbatches = False
+ if forward_only:
+ num_warmup_microbatches = total_num_microbatches
+ else:
+ # Run all forward passes and then all backward passes if number of
+ # microbatches is just the number of pipeline stages.
+ # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
+ # all workers, followed by more microbatches after depending on
+ # stage ID (more forward passes for earlier stages, later stages can
+ # immediately start with 1F1B).
+ if num_microbatches == pipeline_parallel_size:
+ num_warmup_microbatches = total_num_microbatches
+ all_warmup_microbatches = True
+ else:
+ num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
+ num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
+ num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
+
+ num_fwd = min((pipeline_parallel_size - 1) * 2 + (num_model_chunks - 1) * pipeline_parallel_size, total_num_microbatches)
+ num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
+ num_dx = num_fwd - num_warmup_microbatches
+ overlap_chunks_num = (num_dx + pipeline_parallel_size - 1) // pipeline_parallel_size
+ nano_flag = [True] * len(model)
+ for i in range(overlap_chunks_num):
+ nano_flag[-i - 1] = False
+
+ # Checkpoint the activations of partial Transformer layers in a number of micro-batches
+ # within the maximum outstanding micro-batch backpropagations.
+ # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
+ # checkpoint partial Transformer layers (or skip checkpointing) and
+ # the rest of micro-batches within a window of micro-batches checkpoint
+ # all Transformer layers. The window of micro-batches is set by the maximum
+ # outstanding backpropagations and becomes smaller at later pipeline stages.
+ # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
+ max_outstanding_backprops = None
+ if config.num_microbatches_with_partial_activation_checkpoints is not None:
+ max_outstanding_backprops = num_warmup_microbatches + 1
+
+ # Synchronize params for first two model chunks
+ if config.param_sync_func is not None:
+ config.param_sync_func[0](model[0].parameters())
+ config.param_sync_func[1](model[1].parameters())
+
+ def get_model_chunk_id(microbatch_id, forward):
+ """Helper method to get the model chunk ID given the iteration number."""
+ microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
+ model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
+ if not forward:
+ model_chunk_id = num_model_chunks - model_chunk_id - 1
+ return model_chunk_id
+
+ def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
+ """Check if an iteration is the first for a model chunk."""
+ microbatch_group_size = pipeline_parallel_size * num_model_chunks
+ num_microbatch_groups = total_num_microbatches // microbatch_group_size
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ if microbatch_group_id == 0:
+ return microbatch_id_in_group % pipeline_parallel_size == 0
+ else:
+ return False
+
+ def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
+ """Check if an iteration is the last for a model chunk."""
+ microbatch_group_size = pipeline_parallel_size * num_model_chunks
+ num_microbatch_groups = total_num_microbatches // microbatch_group_size
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ if microbatch_group_id == num_microbatch_groups - 1:
+ return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
+ else:
+ return False
+
+ def forward_step_helper(microbatch_id, checkpoint_activations_microbatch):
+ """Helper method to run forward step with model split into chunks
+ (run set_virtual_pipeline_model_parallel_rank() before calling
+ forward_step())."""
+ model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
+
+ # launch param synchronization for next model chunk
+ # Note: Asynchronous communication tends to slow down compute.
+ # To reduce idling from mismatched microbatch times, we launch
+ # asynchronous communication at the same time across the
+ # pipeline-parallel group.
+ if config.param_sync_func is not None:
+ param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
+ if (
+ param_sync_microbatch_id < total_num_microbatches
+ and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
+ ):
+ param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
+ if 1 < param_sync_chunk_id < num_model_chunks:
+ config.param_sync_func[param_sync_chunk_id](
+ model[param_sync_chunk_id].parameters()
+ )
+
+ # forward step
+ if parallel_state.is_pipeline_first_stage():
+ if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
+ input_tensors[model_chunk_id].append(None)
+ input_tensor = input_tensors[model_chunk_id][-1]
+ output_tensor, _ = forward_step(
+ forward_step_func,
+ data_iterator[model_chunk_id],
+ model[model_chunk_id],
+ num_microbatches,
+ input_tensor,
+ forward_data_store,
+ config,
+ collect_non_loss_data,
+ checkpoint_activations_microbatch,
+ check_first_val_step(
+ first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id),
+ ),
+ )
+ output_tensors[model_chunk_id].append(output_tensor)
+
+ # if forward-only, no need to save tensors for a backward pass
+ if forward_only:
+ input_tensors[model_chunk_id].pop()
+ output_tensors[model_chunk_id].pop()
+
+ return output_tensor
+
+ def backward_step_helper(microbatch_id):
+ """Helper method to run backward step with model split into chunks
+ (run set_virtual_pipeline_model_parallel_rank() before calling
+ backward_step())."""
+ model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
+
+ # launch grad synchronization (default)
+ if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id) and nano_flag[model_chunk_id]:
+ enable_grad_sync()
+ synchronized_model_chunks.add(model_chunk_id)
+
+ if parallel_state.is_pipeline_last_stage():
+ if len(output_tensor_grads[model_chunk_id]) == 0:
+ output_tensor_grads[model_chunk_id].append(None)
+ input_tensor = input_tensors[model_chunk_id].pop(0)
+ output_tensor = output_tensors[model_chunk_id].pop(0)
+ output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
+ input_tensor_grad = backward_step(
+ input_tensor, output_tensor, output_tensor_grad, model_type, config
+ )
+
+ # launch grad synchronization (custom grad sync)
+ # Note: Asynchronous communication tends to slow down compute.
+ # To reduce idling from mismatched microbatch times, we launch
+ # asynchronous communication at the same time across the
+ # pipeline-parallel group.
+ if config.grad_sync_func is not None:
+ grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
+ if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
+ grad_sync_microbatch_id
+ ):
+ grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
+ if nano_flag[grad_sync_chunk_id]:
+ enable_grad_sync()
+ config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
+ synchronized_model_chunks.add(grad_sync_chunk_id)
+ disable_grad_sync()
+
+ return input_tensor_grad
+
+ # Run warmup forward passes.
+ parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+ input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
+
+ fwd_wait_handles = None
+ bwd_wait_handles = None
+
+ for k in range(num_warmup_microbatches):
+
+ if fwd_wait_handles is not None:
+ for req in fwd_wait_handles:
+ req.wait()
+
+ # Decide to checkpoint all layers' activations of the current micro-batch
+ if max_outstanding_backprops is not None:
+ checkpoint_activations_microbatch = (
+ k % max_outstanding_backprops
+ >= config.num_microbatches_with_partial_activation_checkpoints
+ )
+ else:
+ checkpoint_activations_microbatch = None
+
+ output_tensor = forward_step_helper(k, checkpoint_activations_microbatch)
+
+ # Determine if tensor should be received from previous stage.
+ next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
+ recv_prev = True
+ if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+ if next_forward_model_chunk_id == 0:
+ recv_prev = False
+ if k == (total_num_microbatches - 1):
+ recv_prev = False
+
+ # Don't send tensor downstream if on last stage.
+ if parallel_state.is_pipeline_last_stage():
+ output_tensor = None
+
+ # Send and receive tensors as appropriate (send tensors computed
+ # in this iteration; receive tensors for next iteration).
+ if not config.overlap_p2p_comm:
+ if (
+ k == (num_warmup_microbatches - 1)
+ and not forward_only
+ and not all_warmup_microbatches
+ ):
+ input_tensor_grad = None
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ recv_next = False
+ (
+ input_tensor,
+ output_tensor_grad,
+ ) = p2p_communication.send_forward_backward_recv_forward_backward(
+ output_tensor,
+ input_tensor_grad,
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ config=config,
+ )
+ output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
+ else:
+ input_tensor = p2p_communication.send_forward_recv_forward(
+ output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
+ )
+ input_tensors[next_forward_model_chunk_id].append(input_tensor)
+ else:
+ input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
+ output_tensor,
+ recv_prev=recv_prev,
+ tensor_shape=tensor_shape,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+
+ if (
+ k == (num_warmup_microbatches - 1)
+ and not forward_only
+ and not all_warmup_microbatches
+ ):
+ input_tensor_grad = None
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ recv_next = False
+
+ (
+ output_tensor_grad,
+ bwd_wait_handles,
+ ) = p2p_communication.send_backward_recv_backward(
+ input_tensor_grad,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+
+ output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
+ input_tensors[next_forward_model_chunk_id].append(input_tensor)
+
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ output_tensor = None
+ # Run 1F1B in steady state.
+ for k in range(num_microbatches_remaining):
+ # Forward pass.
+ forward_k = k + num_warmup_microbatches
+ # Decide to checkpoint all layers' activations of the current micro-batch
+ if max_outstanding_backprops is not None:
+ checkpoint_activations_microbatch = (
+ forward_k % max_outstanding_backprops
+ >= config.num_microbatches_with_partial_activation_checkpoints
+ )
+ else:
+ checkpoint_activations_microbatch = None
+
+ if config.overlap_p2p_comm:
+ if fwd_wait_handles is not None:
+ for req in fwd_wait_handles:
+ req.wait()
+
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)
+
+ # Determine if current stage has anything to send in either direction,
+ # otherwise set tensor to None.
+ forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
+
+ # Last virtual stage no activation tensor to send
+ if parallel_state.is_pipeline_last_stage():
+ output_tensor = None
+
+ # Determine if peers are sending, and where in data structure to put
+ # received tensors.
+ recv_prev = True
+ if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+ # First stage is ahead of last stage by (pipeline_parallel_size - 1).
+ next_forward_model_chunk_id = get_model_chunk_id(
+ forward_k - (pipeline_parallel_size - 1), forward=True
+ )
+ if next_forward_model_chunk_id == (num_model_chunks - 1):
+ recv_prev = False
+ next_forward_model_chunk_id += 1
+ else:
+ next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
+
+ # If last iteration, don't receive; we already received one extra
+ # before the start of the for loop.
+ if k == (num_microbatches_remaining - 1):
+ recv_prev = False
+
+ # Send activation tensor to the next stage and receive activation tensor from the
+ # previous stage
+ input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
+ output_tensor,
+ recv_prev=recv_prev,
+ tensor_shape=tensor_shape,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+ # assert fwd_wait_handles is not None
+
+ if bwd_wait_handles is not None:
+ for req in bwd_wait_handles:
+ req.wait()
+
+ # Backward pass.
+ backward_k = k
+ if k < num_dx:
+ WeightGradStore.start_decouple()
+
+ if args.use_nanopipe:
+ WeightGradStore.resize_ori_storage(args.use_nanopipe_swap)
+
+ input_tensor_grad = backward_step_helper(backward_k)
+ if WeightGradStore.is_decoupleBlock:
+ WeightGradStore.flush()
+ if k == num_dx - 1:
+ WeightGradStore.end_decouple()
+ backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
+
+ parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
+
+
+ # First virtual stage no activation gradient tensor to send
+ if parallel_state.is_pipeline_first_stage():
+ input_tensor_grad = None
+
+ # Determine if the current virtual stage has an activation gradient tensor to receive
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
+ next_backward_model_chunk_id = get_model_chunk_id(
+ backward_k - (pipeline_parallel_size - 1), forward=False
+ )
+ if next_backward_model_chunk_id == 0:
+ recv_next = False
+ next_backward_model_chunk_id -= 1
+ else:
+ next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
+
+ output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
+ input_tensor_grad,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+ else: # no p2p overlap
+ output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)
+
+ # Backward pass.
+ backward_k = k
+ if k < num_dx:
+ WeightGradStore.start_decouple()
+
+ if args.use_nanopipe:
+ WeightGradStore.resize_ori_storage(args.use_nanopipe_swap)
+
+ input_tensor_grad = backward_step_helper(backward_k)
+ if WeightGradStore.is_decoupleBlock:
+ WeightGradStore.flush()
+ if k == num_dx - 1:
+ WeightGradStore.end_decouple()
+
+ # Send output_tensor and input_tensor_grad, receive input_tensor
+ # and output_tensor_grad.
+
+ # Determine if current stage has anything to send in either direction,
+ # otherwise set tensor to None.
+ forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
+ if parallel_state.is_pipeline_last_stage():
+ output_tensor = None
+
+ backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
+ if parallel_state.is_pipeline_first_stage():
+ input_tensor_grad = None
+
+ # Determine if peers are sending, and where in data structure to put
+ # received tensors.
+ recv_prev = True
+ if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+ # First stage is ahead of last stage by (pipeline_parallel_size - 1).
+ next_forward_model_chunk_id = get_model_chunk_id(
+ forward_k - (pipeline_parallel_size - 1), forward=True
+ )
+ if next_forward_model_chunk_id == (num_model_chunks - 1):
+ recv_prev = False
+ next_forward_model_chunk_id += 1
+ else:
+ next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
+
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
+ next_backward_model_chunk_id = get_model_chunk_id(
+ backward_k - (pipeline_parallel_size - 1), forward=False
+ )
+ if next_backward_model_chunk_id == 0:
+ recv_next = False
+ next_backward_model_chunk_id -= 1
+ else:
+ next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
+
+ # If last iteration, don't receive; we already received one extra
+ # before the start of the for loop.
+ if k == (num_microbatches_remaining - 1):
+ recv_prev = False
+
+ # Communicate tensors.
+ (
+ input_tensor,
+ output_tensor_grad,
+ ) = p2p_communication.send_forward_backward_recv_forward_backward(
+ output_tensor,
+ input_tensor_grad,
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ config=config,
+ )
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ # Put input_tensor and output_tensor_grad in data structures in the
+ # right location.
+ if recv_prev:
+ input_tensors[next_forward_model_chunk_id].append(input_tensor)
+ if recv_next:
+ output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
+
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+ # Run cooldown backward passes (flush out pipeline).
+ if not forward_only:
+ if config.overlap_p2p_comm and bwd_wait_handles is not None:
+ for wait_handle in bwd_wait_handles:
+ wait_handle.wait()
+
+ if all_warmup_microbatches:
+ output_tensor_grads[num_model_chunks - 1].append(
+ p2p_communication.recv_backward(tensor_shape, config=config)
+ )
+ for k in range(num_microbatches_remaining, total_num_microbatches):
+ input_tensor_grad = backward_step_helper(k)
+ next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ if next_backward_model_chunk_id == (num_model_chunks - 1):
+ recv_next = False
+ if k == (total_num_microbatches - 1):
+ recv_next = False
+ output_tensor_grads[next_backward_model_chunk_id].append(
+ p2p_communication.send_backward_recv_backward(
+ input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config
+ )
+ )
+ if args.use_nanopipe_swap and k == max(num_microbatches_remaining + 1, (total_num_microbatches + num_microbatches_remaining) // 2):
+ WeightGradStore.swap_tensors()
+ if nano_flag[0] and 0 not in synchronized_model_chunks:
+ config.grad_sync_func[0](model[0].parameters())
+ synchronized_model_chunks.add(0)
+ overlap_arg = [pipeline_parallel_size, nano_flag, synchronized_model_chunks, config.grad_sync_func, model]
+ WeightGradStore.pop(overlap_arg)
+ # Launch any remaining grad reductions.
+ enable_grad_sync()
+ if config.grad_sync_func is not None:
+ for model_chunk_id in range(num_model_chunks):
+ if model_chunk_id not in synchronized_model_chunks:
+ config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
+ synchronized_model_chunks.add(model_chunk_id)
+
+ if config.timers is not None:
+ config.timers('forward-backward').stop()
+
+ if config.finalize_model_grads_func is not None and not forward_only:
+ # Finalize model grads (perform full grad all-reduce / reduce-scatter for
+ # data parallelism, layernorm all-reduce for sequence parallelism, and
+ # embedding all-reduce for pipeline parallelism).
+ config.finalize_model_grads_func(model)
+
+ return forward_data_store
+
+
+def forward_backward_pipelining_with_interleaving_patch(
+ *,
+ forward_step_func,
+ data_iterator: Union[Iterator, List[Iterator]],
+ model: Union[torch.nn.Module, List[torch.nn.Module]],
+ num_microbatches: int,
+ seq_length: int,
+ micro_batch_size: int,
+ decoder_seq_length: int = None,
+ forward_only: bool = False,
+ collect_non_loss_data: bool = False,
+ first_val_step: bool = None,
+):
+ """Run interleaved 1F1B schedule (model split into model chunks), with
+ communication between pipeline stages as needed.
+
+ Returns dictionary with losses if the last stage, empty dict otherwise."""
+ if not isinstance(model, list):
+ raise AssertionError("interleaved pipeline parallelism expected model chunking")
+ if not all(isinstance(chunk, torch.nn.Module) for chunk in model):
+ raise AssertionError("invalid model chunking")
+ if not isinstance(data_iterator, list):
+ raise AssertionError("interleaved pipeline parallelism expected each model chunk to have a data iterator")
+ config = get_model_config(model[0])
+ if config.overlap_p2p_comm and config.batch_p2p_comm:
+ raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")
+
+ # Needed only when gradients are finalized in M-Core
+ if config.finalize_model_grads_func is not None and not forward_only:
+ embedding_module = clear_embedding_activation_buffer(config, model)
+
+ if config.timers is not None:
+ config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
+
+ # Disable async grad reductions
+ no_sync_func = config.no_sync_func
+ if isinstance(no_sync_func, list):
+
+ def multi_no_sync():
+ stack = contextlib.ExitStack()
+ for model_chunk_no_sync_func in config.no_sync_func:
+ stack.enter_context(model_chunk_no_sync_func())
+ return stack
+
+ no_sync_func = multi_no_sync
+ if no_sync_func is None:
+ no_sync_func = contextlib.nullcontext
+ no_sync_context = None
+
+ if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
+ config.grad_sync_func = [config.grad_sync_func for _ in model]
+
+ if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
+ config.param_sync_func = [config.param_sync_func for _ in model]
+
+ def disable_grad_sync():
+ """Disable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is None:
+ no_sync_context = no_sync_func()
+ no_sync_context.__enter__()
+
+ def enable_grad_sync():
+ """Enable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is not None:
+ no_sync_context.__exit__(None, None, None)
+ no_sync_context = None
+
+ disable_grad_sync()
+
+ # Model chunk IDs with synchronized grads
+ synchronized_model_chunks = set()
+
+ input_tensors = [[] for _ in range(len(model))]
+ output_tensors = [[] for _ in range(len(model))]
+ total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
+
+ forward_data_store = []
+ if not forward_only:
+ output_tensor_grads = [[] for _ in range(len(model))]
+
+ pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
+ pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
+
+ if num_microbatches % pipeline_parallel_size != 0:
+ msg = f'number of microbatches ({num_microbatches}) is not divisible by '
+ msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
+ msg += 'when using interleaved schedule'
+ raise RuntimeError(msg)
+
+ model_type = get_model_type(model[0])
+ if model_type == ModelType.encoder_and_decoder:
+ raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")
+
+ if decoder_seq_length is not None and decoder_seq_length != seq_length:
+ raise RuntimeError(
+ "Interleaving is not supported with a different decoder sequence length."
+ )
+
+ tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
+ tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
+ if config.sequence_parallel:
+ tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
+ tensor_shape[0] = tensor_shape[0] // get_args().tp_x
+ tensor_shape[-1] = tensor_shape[-1] // get_args().tp_y
+ # Compute number of warmup and remaining microbatches.
+ num_model_chunks = len(model)
+ total_num_microbatches = num_microbatches * num_model_chunks
+ all_warmup_microbatches = False
+ if forward_only:
+ num_warmup_microbatches = total_num_microbatches
+ else:
+ # Run all forward passes and then all backward passes if number of
+ # microbatches is just the number of pipeline stages.
+ # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
+ # all workers, followed by more microbatches after depending on
+ # stage ID (more forward passes for earlier stages, later stages can
+ # immediately start with 1F1B).
+ if num_microbatches == pipeline_parallel_size:
+ num_warmup_microbatches = total_num_microbatches
+ all_warmup_microbatches = True
+ else:
+ num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
+ num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
+ num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
+ num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
+
+ # Checkpoint the activations of partial Transformer layers in a number of micro-batches
+ # within the maximum outstanding micro-batch backpropagations.
+ # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
+ # checkpoint partial Transformer layers (or skip checkpointing) and
+ # the rest of micro-batches within a window of micro-batches checkpoint
+ # all Transformer layers. The window of micro-batches is set by the maximum
+ # outstanding backpropagations and becomes smaller at later pipeline stages.
+ # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
+ max_outstanding_backprops = None
+ if config.num_microbatches_with_partial_activation_checkpoints is not None:
+ max_outstanding_backprops = num_warmup_microbatches + 1
+
+ # Synchronize params for first two model chunks
+ if config.param_sync_func is not None:
+ config.param_sync_func[0](model[0].parameters())
+ config.param_sync_func[1](model[1].parameters())
+
+ def get_model_chunk_id(microbatch_id, forward):
+ """Helper method to get the model chunk ID given the iteration number."""
+ microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
+ model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
+ if not forward:
+ model_chunk_id = num_model_chunks - model_chunk_id - 1
+ return model_chunk_id
+
+ def get_microbatch_id_in_model_chunk(iteration_id, forward):
+ """Helper method to get the microbatch_id within model chunk given the iteration number."""
+ assert forward
+ iteration_group_id = iteration_id // (pipeline_parallel_size * num_model_chunks)
+ microbatch_id_in_model_chunk = (iteration_group_id * pipeline_parallel_size) + (
+ iteration_id % pipeline_parallel_size
+ )
+ return microbatch_id_in_model_chunk
+
+ def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
+ """Check if an iteration is the first for a model chunk."""
+ microbatch_group_size = pipeline_parallel_size * num_model_chunks
+ num_microbatch_groups = total_num_microbatches // microbatch_group_size
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ if microbatch_group_id == 0:
+ return microbatch_id_in_group % pipeline_parallel_size == 0
+ else:
+ return False
+
+ def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
+ """Check if an iteration is the last for a model chunk."""
+ microbatch_group_size = pipeline_parallel_size * num_model_chunks
+ num_microbatch_groups = total_num_microbatches // microbatch_group_size
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ if microbatch_group_id == num_microbatch_groups - 1:
+ return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
+ else:
+ return False
+
+ def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activations_microbatch):
+ """Helper method to run forward step with model split into chunks
+ (run set_virtual_pipeline_model_parallel_rank() before calling
+ forward_step())."""
+ model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
+
+ # launch param synchronization for next model chunk
+ # Note: Asynchronous communication tends to slow down compute.
+ # To reduce idling from mismatched microbatch times, we launch
+ # asynchronous communication at the same time across the
+ # pipeline-parallel group.
+ if config.param_sync_func is not None:
+ param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
+ if (
+ param_sync_microbatch_id < total_num_microbatches
+ and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
+ ):
+ param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
+ if 1 < param_sync_chunk_id < num_model_chunks:
+ config.param_sync_func[param_sync_chunk_id](
+ model[param_sync_chunk_id].parameters()
+ )
+
+ # forward step
+ if parallel_state.is_pipeline_first_stage():
+ if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
+ input_tensors[model_chunk_id].append(None)
+ input_tensor = input_tensors[model_chunk_id][-1]
+
+ output_tensor, num_tokens = forward_step(
+ forward_step_func,
+ data_iterator[model_chunk_id],
+ model[model_chunk_id],
+ num_microbatches,
+ input_tensor,
+ forward_data_store,
+ config,
+ collect_non_loss_data,
+ checkpoint_activations_microbatch,
+ check_first_val_step(
+ first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id),
+ ),
+ current_microbatch=current_microbatch,
+ )
+ output_tensors[model_chunk_id].append(output_tensor)
+
+ nonlocal total_num_tokens
+ total_num_tokens += num_tokens.item()
+
+ # if forward-only, no need to save tensors for a backward pass
+ if forward_only:
+ input_tensors[model_chunk_id].pop()
+ output_tensors[model_chunk_id].pop()
+
+ return output_tensor
+
+ def backward_step_helper(microbatch_id):
+ """Helper method to run backward step with model split into chunks
+ (run set_virtual_pipeline_model_parallel_rank() before calling
+ backward_step())."""
+ model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
+
+ # launch grad synchronization (default)
+ if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id):
+ enable_grad_sync()
+ synchronized_model_chunks.add(model_chunk_id)
+
+ if parallel_state.is_pipeline_last_stage():
+ if len(output_tensor_grads[model_chunk_id]) == 0:
+ output_tensor_grads[model_chunk_id].append(None)
+ input_tensor = input_tensors[model_chunk_id].pop(0)
+ output_tensor = output_tensors[model_chunk_id].pop(0)
+ output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
+ input_tensor_grad = backward_step(
+ input_tensor, output_tensor, output_tensor_grad, model_type, config
+ )
+
+ # launch grad synchronization (custom grad sync)
+ # Note: Asynchronous communication tends to slow down compute.
+ # To reduce idling from mismatched microbatch times, we launch
+ # asynchronous communication at the same time across the
+ # pipeline-parallel group.
+ if config.grad_sync_func is not None:
+ grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
+ if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
+ grad_sync_microbatch_id
+ ):
+ grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
+ enable_grad_sync()
+ config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
+ synchronized_model_chunks.add(grad_sync_chunk_id)
+ disable_grad_sync()
+
+ return input_tensor_grad
+
+ # Run warmup forward passes.
+ parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+ input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
+
+ fwd_wait_handles = None
+ bwd_wait_handles = None
+
+ for k in range(num_warmup_microbatches):
+
+ if fwd_wait_handles is not None:
+ for req in fwd_wait_handles:
+ req.wait()
+
+ cur_model_chunk_id = get_model_chunk_id(k, forward=True)
+ # Decide to checkpoint all layers' activations of the current micro-batch
+ if max_outstanding_backprops is not None:
+ checkpoint_activations_microbatch = (
+ k % max_outstanding_backprops
+ >= config.num_microbatches_with_partial_activation_checkpoints
+ )
+ else:
+ checkpoint_activations_microbatch = None
+
+ current_microbatch = get_microbatch_id_in_model_chunk(k, forward=True)
+ output_tensor = forward_step_helper(
+ k, current_microbatch, checkpoint_activations_microbatch
+ )
+
+ # Determine if tensor should be received from previous stage.
+ next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
+ recv_prev = True
+ if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+ if next_forward_model_chunk_id == 0:
+ recv_prev = False
+ if k == (total_num_microbatches - 1):
+ recv_prev = False
+
+ # Don't send tensor downstream if on last stage.
+ if parallel_state.is_pipeline_last_stage():
+ output_tensor = None
+
+ # Send and receive tensors as appropriate (send tensors computed
+ # in this iteration; receive tensors for next iteration).
+ if not config.overlap_p2p_comm:
+ if (
+ k == (num_warmup_microbatches - 1)
+ and not forward_only
+ and not all_warmup_microbatches
+ ):
+ input_tensor_grad = None
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ recv_next = False
+ (
+ input_tensor,
+ output_tensor_grad,
+ ) = p2p_communication.send_forward_backward_recv_forward_backward(
+ output_tensor,
+ input_tensor_grad,
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ config=config,
+ )
+ output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
+ else:
+ input_tensor = p2p_communication.send_forward_recv_forward(
+ output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
+ )
+ input_tensors[next_forward_model_chunk_id].append(input_tensor)
+ else:
+ input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
+ output_tensor,
+ recv_prev=recv_prev,
+ tensor_shape=tensor_shape,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+
+ if (
+ k == (num_warmup_microbatches - 1)
+ and not forward_only
+ and not all_warmup_microbatches
+ ):
+ input_tensor_grad = None
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ recv_next = False
+
+ (
+ output_tensor_grad,
+ bwd_wait_handles,
+ ) = p2p_communication.send_backward_recv_backward(
+ input_tensor_grad,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+
+ output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
+ input_tensors[next_forward_model_chunk_id].append(input_tensor)
+
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ # Run 1F1B in steady state.
+ for k in range(num_microbatches_remaining):
+ # Forward pass.
+ forward_k = k + num_warmup_microbatches
+
+ # Decide to checkpoint all layers' activations of the current micro-batch
+ if max_outstanding_backprops is not None:
+ checkpoint_activations_microbatch = (
+ forward_k % max_outstanding_backprops
+ >= config.num_microbatches_with_partial_activation_checkpoints
+ )
+ else:
+ checkpoint_activations_microbatch = None
+
+ cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
+ current_microbatch = get_microbatch_id_in_model_chunk(forward_k, forward=True)
+ if config.overlap_p2p_comm:
+ if fwd_wait_handles is not None:
+ for req in fwd_wait_handles:
+ req.wait()
+
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ output_tensor = forward_step_helper(
+ forward_k, current_microbatch, checkpoint_activations_microbatch
+ )
+
+ # Determine if current stage has anything to send in either direction,
+ # otherwise set tensor to None.
+ forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
+
+ # Last virtual stage no activation tensor to send
+ if parallel_state.is_pipeline_last_stage():
+ output_tensor = None
+
+ # Determine if peers are sending, and where in data structure to put
+ # received tensors.
+ recv_prev = True
+ if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+ # First stage is ahead of last stage by (pipeline_parallel_size - 1).
+ next_forward_model_chunk_id = get_model_chunk_id(
+ forward_k - (pipeline_parallel_size - 1), forward=True
+ )
+ if next_forward_model_chunk_id == (num_model_chunks - 1):
+ recv_prev = False
+ next_forward_model_chunk_id += 1
+ else:
+ next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
+
+ # If last iteration, don't receive; we already received one extra
+ # before the start of the for loop.
+ if k == (num_microbatches_remaining - 1):
+ recv_prev = False
+
+ # Send activation tensor to the next stage and receive activation tensor from the
+ # previous stage
+ input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
+ output_tensor,
+ recv_prev=recv_prev,
+ tensor_shape=tensor_shape,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+ # assert fwd_wait_handles is not None
+
+ if bwd_wait_handles is not None:
+ for req in bwd_wait_handles:
+ req.wait()
+
+ # Backward pass.
+ backward_k = k
+ input_tensor_grad = backward_step_helper(backward_k)
+
+ backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
+
+ # First virtual stage no activation gradient tensor to send
+ if parallel_state.is_pipeline_first_stage():
+ input_tensor_grad = None
+
+ # Determine if the current virtual stage has an activation gradient tensor to receive
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
+ next_backward_model_chunk_id = get_model_chunk_id(
+ backward_k - (pipeline_parallel_size - 1), forward=False
+ )
+ if next_backward_model_chunk_id == 0:
+ recv_next = False
+ next_backward_model_chunk_id -= 1
+ else:
+ next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
+
+ output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
+ input_tensor_grad,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+
+ else: # no p2p overlap
+ output_tensor = forward_step_helper(
+ forward_k, current_microbatch, checkpoint_activations_microbatch
+ )
+
+ # Backward pass.
+ backward_k = k
+ input_tensor_grad = backward_step_helper(backward_k)
+
+ # Send output_tensor and input_tensor_grad, receive input_tensor
+ # and output_tensor_grad.
+
+ # Determine if current stage has anything to send in either direction,
+ # otherwise set tensor to None.
+ forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
+ if parallel_state.is_pipeline_last_stage():
+ output_tensor = None
+
+ backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
+ if parallel_state.is_pipeline_first_stage():
+ input_tensor_grad = None
+
+ # Determine if peers are sending, and where in data structure to put
+ # received tensors.
+ recv_prev = True
+ if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+ # First stage is ahead of last stage by (pipeline_parallel_size - 1).
+ next_forward_model_chunk_id = get_model_chunk_id(
+ forward_k - (pipeline_parallel_size - 1), forward=True
+ )
+ if next_forward_model_chunk_id == (num_model_chunks - 1):
+ recv_prev = False
+ next_forward_model_chunk_id += 1
+ else:
+ next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
+
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
+ next_backward_model_chunk_id = get_model_chunk_id(
+ backward_k - (pipeline_parallel_size - 1), forward=False
+ )
+ if next_backward_model_chunk_id == 0:
+ recv_next = False
+ next_backward_model_chunk_id -= 1
+ else:
+ next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
+
+ # If last iteration, don't receive; we already received one extra
+ # before the start of the for loop.
+ if k == (num_microbatches_remaining - 1):
+ recv_prev = False
+
+ # Communicate tensors.
+ (
+ input_tensor,
+ output_tensor_grad,
+ ) = p2p_communication.send_forward_backward_recv_forward_backward(
+ output_tensor,
+ input_tensor_grad,
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ config=config,
+ )
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ # Put input_tensor and output_tensor_grad in data structures in the
+ # right location.
+ if recv_prev:
+ input_tensors[next_forward_model_chunk_id].append(input_tensor)
+ if recv_next:
+ output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
+
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ # Run cooldown backward passes (flush out pipeline).
+ if not forward_only:
+ if config.overlap_p2p_comm and bwd_wait_handles is not None:
+ for wait_handle in bwd_wait_handles:
+ wait_handle.wait()
+
+ if all_warmup_microbatches:
+ output_tensor_grads[num_model_chunks - 1].append(
+ p2p_communication.recv_backward(tensor_shape, config=config)
+ )
+ for k in range(num_microbatches_remaining, total_num_microbatches):
+ input_tensor_grad = backward_step_helper(k)
+ next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ if next_backward_model_chunk_id == (num_model_chunks - 1):
+ recv_next = False
+ if k == (total_num_microbatches - 1):
+ recv_next = False
+ output_tensor_grads[next_backward_model_chunk_id].append(
+ p2p_communication.send_backward_recv_backward(
+ input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config
+ )
+ )
+
+ # Launch any remaining grad reductions.
+ enable_grad_sync()
+ if config.grad_sync_func is not None:
+ for model_chunk_id in range(num_model_chunks):
+ if model_chunk_id not in synchronized_model_chunks:
+ config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
+ synchronized_model_chunks.add(model_chunk_id)
+
+ if config.finalize_model_grads_func is not None and not forward_only:
+ # If defer_embedding_wgrad_compute is enabled we need to do the
+ # weight gradient GEMM's here.
+ finish_embedding_wgrad_compute(config, embedding_module)
+
+ # Finalize model grads (perform full grad all-reduce / reduce-scatter for
+ # data parallelism, layernorm all-reduce for sequence parallelism, and
+ # embedding all-reduce for pipeline parallelism).
+ config.finalize_model_grads_func(
+ model, total_num_tokens if config.calculate_per_token_loss else None
+ )
+
+ if config.timers is not None:
+ config.timers('forward-backward').stop()
+
+ return forward_data_store
diff --git a/model/train/yoco_moe/mindspeed/core/pipeline_parallel/multiparameter_schedules.py b/model/train/yoco_moe/mindspeed/core/pipeline_parallel/multiparameter_schedules.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1d1f5edd4dc426152000272e9c59b8057a99ffd
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/pipeline_parallel/multiparameter_schedules.py
@@ -0,0 +1,965 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import contextlib
+from functools import wraps
+from typing import Iterator, List, Union
+
+import torch
+
+from megatron.training import get_args
+from megatron.core import parallel_state
+from megatron.core.enums import ModelType
+from megatron.core.pipeline_parallel import p2p_communication
+from megatron.core.utils import get_attr_wrapped_model, get_model_config, get_model_type
+
+from megatron.core.pipeline_parallel.schedules import (
+ forward_step,
+ deallocate_output_tensor,
+ check_first_val_step
+)
+
+
+def forward_step_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*arg, **kwargs):
+ output_tensor, num_tokens = fn(*arg, **kwargs)
+ if len(output_tensor) > 0 and isinstance(output_tensor[0], list):
+ return output_tensor[0], num_tokens
+ else:
+ return output_tensor, num_tokens
+ return wrapper
+
+
+def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
+ if config.timers is not None:
+ config.timers('backward-compute', log_level=2).start()
+
+ # Retain the grad on the input_tensor.
+ unwrap_input_tensor_grad = False
+ if not isinstance(input_tensor, list):
+ input_tensor = [input_tensor]
+ unwrap_input_tensor_grad = True
+ for x in input_tensor:
+ if x is not None and x.requires_grad:
+ x.retain_grad()
+
+ if not isinstance(output_tensor, list):
+ output_tensor = [output_tensor]
+ if not isinstance(output_tensor_grad, list):
+ output_tensor_grad = [output_tensor_grad]
+
+ # Backward pass.
+ if output_tensor_grad[0] is None and config.grad_scale_func is not None:
+ output_tensor[0] = config.grad_scale_func(output_tensor[0])
+
+ output_tensors = []
+ output_grad_tensors = []
+ if output_tensor_grad[0] is None:
+ # The last stage have no input gradients and only one loss is used to backward
+ torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
+ else:
+ for output, grad in zip(output_tensor, output_tensor_grad):
+ if output.requires_grad:
+ output_tensors.append(output)
+ output_grad_tensors.append(grad)
+ torch.autograd.backward(output_tensors, grad_tensors=output_grad_tensors)
+
+ # Collect the grad of the input_tensor.
+ input_tensor_grad = [None]
+ if input_tensor is not None:
+ input_tensor_grad = []
+ for x in input_tensor:
+ if x is None:
+ input_tensor_grad.append(None)
+ else:
+ if x.grad is None:
+ input_tensor_grad.append(torch.zeros_like(x, device=torch.cuda.current_device()))
+ else:
+ input_tensor_grad.append(x.grad)
+
+ # Handle single skip connection if it exists (encoder_hidden_state in
+ # model with encoder and decoder).
+ if (
+ parallel_state.get_pipeline_model_parallel_world_size() > 1
+ and parallel_state.is_pipeline_stage_after_split()
+ and model_type == ModelType.encoder_and_decoder
+ ):
+ if output_tensor_grad[1] is not None:
+ input_tensor_grad[-1].add_(output_tensor_grad[1])
+ if unwrap_input_tensor_grad:
+ input_tensor_grad = input_tensor_grad[0]
+
+ if config.timers is not None:
+ config.timers('backward-compute').stop()
+
+ return input_tensor_grad
+
+
+def backward_step_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*arg, **kwargs):
+ return backward_step(*arg, **kwargs)
+ return wrapper
+
+
+def get_tensor_shapes_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*arg, **kwargs):
+ args = get_args()
+ return args.pipeline_tensor_shapes
+ return wrapper
+
+
+def forward_backward_pipelining_with_interleaving(
+ *,
+ forward_step_func,
+ data_iterator: Union[Iterator, List[Iterator]],
+ model: Union[torch.nn.Module, List[torch.nn.Module]],
+ num_microbatches: int,
+ seq_length: int,
+ micro_batch_size: int,
+ decoder_seq_length: int = None,
+ forward_only: bool = False,
+ collect_non_loss_data: bool = False,
+ first_val_step: bool = None,
+):
+ assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking"
+ assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking"
+ assert isinstance(
+ data_iterator, list
+ ), "interleaved pipeline parallelism expected each model chunk to have a data iterator"
+
+ config = get_model_config(model[0])
+ if config.overlap_p2p_comm and config.batch_p2p_comm:
+ raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")
+
+ if config.timers is not None:
+ config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
+
+ # Disable async grad reductions
+ no_sync_func = config.no_sync_func
+ if isinstance(no_sync_func, list):
+
+ def multi_no_sync():
+ stack = contextlib.ExitStack()
+ for model_chunk_no_sync_func in config.no_sync_func:
+ stack.enter_context(model_chunk_no_sync_func())
+ return stack
+
+ no_sync_func = multi_no_sync
+ if no_sync_func is None:
+ no_sync_func = contextlib.nullcontext
+ no_sync_context = None
+
+ if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
+ config.grad_sync_func = [config.grad_sync_func for _ in model]
+
+ if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
+ config.param_sync_func = [config.param_sync_func for _ in model]
+
+ def disable_grad_sync():
+ """Disable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is None:
+ no_sync_context = no_sync_func()
+ no_sync_context.__enter__()
+
+ def enable_grad_sync():
+ """Enable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is not None:
+ no_sync_context.__exit__(None, None, None)
+ no_sync_context = None
+
+ disable_grad_sync()
+
+ # Model chunk IDs with synchronized grads
+ synchronized_model_chunks = set()
+
+ input_tensors = [[] for _ in range(len(model))]
+ output_tensors = [[] for _ in range(len(model))]
+ total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
+
+ forward_data_store = []
+ if not forward_only:
+ output_tensor_grads = [[] for _ in range(len(model))]
+
+ pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
+ pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
+
+ if num_microbatches % pipeline_parallel_size != 0:
+ msg = f'number of microbatches ({num_microbatches}) is not divisible by '
+ msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
+ msg += 'when using interleaved schedule'
+ raise RuntimeError(msg)
+
+ model_type = get_model_type(model[0])
+ if model_type == ModelType.encoder_and_decoder:
+ raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")
+
+ if decoder_seq_length is not None and decoder_seq_length != seq_length:
+ raise RuntimeError(
+ "Interleaving is not supported with a different decoder sequence length."
+ )
+
+ tensor_shape = get_args().pipeline_tensor_shapes
+
+ # Compute number of warmup and remaining microbatches.
+ num_model_chunks = len(model)
+ total_num_microbatches = num_microbatches * num_model_chunks
+ all_warmup_microbatches = False
+ if forward_only:
+ num_warmup_microbatches = total_num_microbatches
+ else:
+ # Run all forward passes and then all backward passes if number of
+ # microbatches is just the number of pipeline stages.
+ # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
+ # all workers, followed by more microbatches after depending on
+ # stage ID (more forward passes for earlier stages, later stages can
+ # immediately start with 1F1B).
+ if num_microbatches == pipeline_parallel_size:
+ num_warmup_microbatches = total_num_microbatches
+ all_warmup_microbatches = True
+ else:
+ num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
+ num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
+ num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
+ num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
+
+ # Checkpoint the activations of partial Transformer layers in a number of micro-batches
+ # within the maximum outstanding micro-batch backpropagations.
+ # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
+ # checkpoint partial Transformer layers (or skip checkpointing) and
+ # the rest of micro-batches within a window of micro-batches checkpoint
+ # all Transformer layers. The window of micro-batches is set by the maximum
+ # outstanding backpropagations and becomes smaller at later pipeline stages.
+ # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
+ max_outstanding_backprops = None
+ if config.num_microbatches_with_partial_activation_checkpoints is not None:
+ max_outstanding_backprops = num_warmup_microbatches + 1
+
+ # Synchronize params for first two model chunks
+ if config.param_sync_func is not None:
+ config.param_sync_func[0](model[0].parameters())
+ config.param_sync_func[1](model[1].parameters())
+
+ def get_model_chunk_id(microbatch_id, forward):
+ """Helper method to get the model chunk ID given the iteration number."""
+ microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
+ model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
+ if not forward:
+ model_chunk_id = num_model_chunks - model_chunk_id - 1
+ return model_chunk_id
+
+ def get_microbatch_id_in_model_chunk(iteration_id, forward):
+ """Helper method to get the microbatch_id within model chunk given the iteration number."""
+ assert forward
+ iteration_group_id = iteration_id // (pipeline_parallel_size * num_model_chunks)
+ microbatch_id_in_model_chunk = (iteration_group_id * pipeline_parallel_size) + (
+ iteration_id % pipeline_parallel_size
+ )
+ return microbatch_id_in_model_chunk
+
+ def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
+ """Check if an iteration is the first for a model chunk."""
+ microbatch_group_size = pipeline_parallel_size * num_model_chunks
+ num_microbatch_groups = total_num_microbatches // microbatch_group_size
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ if microbatch_group_id == 0:
+ return microbatch_id_in_group % pipeline_parallel_size == 0
+ else:
+ return False
+
+ def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
+ """Check if an iteration is the last for a model chunk."""
+ microbatch_group_size = pipeline_parallel_size * num_model_chunks
+ num_microbatch_groups = total_num_microbatches // microbatch_group_size
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ if microbatch_group_id == num_microbatch_groups - 1:
+ return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
+ else:
+ return False
+
+ def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activations_microbatch):
+ """Helper method to run forward step with model split into chunks
+ (run set_virtual_pipeline_model_parallel_rank() before calling
+ forward_step())."""
+ model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
+
+ # launch param synchronization for next model chunk
+ # Note: Asynchronous communication tends to slow down compute.
+ # To reduce idling from mismatched microbatch times, we launch
+ # asynchronous communication at the same time across the
+ # pipeline-parallel group.
+ if config.param_sync_func is not None:
+ param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
+ if (
+ param_sync_microbatch_id < total_num_microbatches
+ and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
+ ):
+ param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
+ if 1 < param_sync_chunk_id < num_model_chunks:
+ config.param_sync_func[param_sync_chunk_id](
+ model[param_sync_chunk_id].parameters()
+ )
+
+ # forward step
+ if parallel_state.is_pipeline_first_stage():
+ if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
+ input_tensors[model_chunk_id].append(None)
+ input_tensor = input_tensors[model_chunk_id][-1]
+
+ output_tensor, num_tokens = forward_step(
+ forward_step_func,
+ data_iterator[model_chunk_id],
+ model[model_chunk_id],
+ num_microbatches,
+ input_tensor,
+ forward_data_store,
+ config,
+ collect_non_loss_data,
+ checkpoint_activations_microbatch,
+ check_first_val_step(
+ first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id),
+ ),
+ current_microbatch=current_microbatch,
+ )
+ output_tensors[model_chunk_id].append(output_tensor)
+
+ nonlocal total_num_tokens
+ total_num_tokens += num_tokens.item()
+
+ # if forward-only, no need to save tensors for a backward pass
+ if forward_only:
+ input_tensors[model_chunk_id].pop()
+ output_tensors[model_chunk_id].pop()
+
+ return output_tensor
+
+ def backward_step_helper(microbatch_id):
+ """Helper method to run backward step with model split into chunks
+ (run set_virtual_pipeline_model_parallel_rank() before calling
+ backward_step())."""
+ model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
+
+ # launch grad synchronization (default)
+ if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id):
+ enable_grad_sync()
+ synchronized_model_chunks.add(model_chunk_id)
+
+ if parallel_state.is_pipeline_last_stage():
+ if len(output_tensor_grads[model_chunk_id]) == 0:
+ output_tensor_grads[model_chunk_id].append(None)
+ input_tensor = input_tensors[model_chunk_id].pop(0)
+ output_tensor = output_tensors[model_chunk_id].pop(0)
+ output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
+ input_tensor_grad = backward_step(
+ input_tensor, output_tensor, output_tensor_grad, model_type, config
+ )
+
+ # launch grad synchronization (custom grad sync)
+ # Note: Asynchronous communication tends to slow down compute.
+ # To reduce idling from mismatched microbatch times, we launch
+ # asynchronous communication at the same time across the
+ # pipeline-parallel group.
+ if config.grad_sync_func is not None:
+ grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
+ if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
+ grad_sync_microbatch_id
+ ):
+ grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
+ enable_grad_sync()
+ config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
+ synchronized_model_chunks.add(grad_sync_chunk_id)
+ disable_grad_sync()
+
+ return input_tensor_grad
+
+ # Run warmup forward passes.
+ parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+ input_tensors[0].append(recv_forward(tensor_shape, config))
+
+ fwd_wait_handles = None
+ bwd_wait_handles = None
+
+ for k in range(num_warmup_microbatches):
+
+ if fwd_wait_handles is not None:
+ for req in fwd_wait_handles:
+ req.wait()
+
+ cur_model_chunk_id = get_model_chunk_id(k, forward=True)
+ # Decide to checkpoint all layers' activations of the current micro-batch
+ if max_outstanding_backprops is not None:
+ checkpoint_activations_microbatch = (
+ k % max_outstanding_backprops
+ >= config.num_microbatches_with_partial_activation_checkpoints
+ )
+ else:
+ checkpoint_activations_microbatch = None
+
+ current_microbatch = get_microbatch_id_in_model_chunk(k, forward=True)
+ output_tensor = forward_step_helper(
+ k, current_microbatch, checkpoint_activations_microbatch
+ )
+
+ # Determine if tensor should be received from previous stage.
+ next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
+ recv_prev = True
+ if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+ if next_forward_model_chunk_id == 0:
+ recv_prev = False
+ if k == (total_num_microbatches - 1):
+ recv_prev = False
+
+ # Don't send tensor downstream if on last stage.
+ if parallel_state.is_pipeline_last_stage():
+ output_tensor = None
+
+ # Send and receive tensors as appropriate (send tensors computed
+ # in this iteration; receive tensors for next iteration).
+ if not config.overlap_p2p_comm:
+ if (
+ k == (num_warmup_microbatches - 1)
+ and not forward_only
+ and not all_warmup_microbatches
+ ):
+ input_tensor_grad = None
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ recv_next = False
+ (
+ input_tensor,
+ output_tensor_grad,
+ ) = send_forward_backward_recv_forward_backward(
+ output_tensor,
+ input_tensor_grad,
+ tensor_shape,
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ config=config,
+ )
+ output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
+ else:
+ input_tensor = send_forward_recv_forward(
+ output_tensor, tensor_shape, recv_prev=recv_prev, config=config
+ )
+ input_tensors[next_forward_model_chunk_id].append(input_tensor)
+ else:
+ input_tensor, fwd_wait_handles = send_forward_recv_forward(
+ output_tensor,
+ tensor_shape,
+ recv_prev=recv_prev,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+
+ if (
+ k == (num_warmup_microbatches - 1)
+ and not forward_only
+ and not all_warmup_microbatches
+ ):
+ input_tensor_grad = None
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ recv_next = False
+
+ (
+ output_tensor_grad,
+ bwd_wait_handles,
+ ) = send_backward_recv_backward(
+ input_tensor_grad,
+ tensor_shape,
+ recv_next=recv_next,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+
+ output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
+ input_tensors[next_forward_model_chunk_id].append(input_tensor)
+
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ # Run 1F1B in steady state.
+ for k in range(num_microbatches_remaining):
+ # Forward pass.
+ forward_k = k + num_warmup_microbatches
+
+ # Decide to checkpoint all layers' activations of the current micro-batch
+ if max_outstanding_backprops is not None:
+ checkpoint_activations_microbatch = (
+ forward_k % max_outstanding_backprops
+ >= config.num_microbatches_with_partial_activation_checkpoints
+ )
+ else:
+ checkpoint_activations_microbatch = None
+
+ cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
+ current_microbatch = get_microbatch_id_in_model_chunk(forward_k, forward=True)
+ if config.overlap_p2p_comm:
+ if fwd_wait_handles is not None:
+ for req in fwd_wait_handles:
+ req.wait()
+
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ output_tensor = forward_step_helper(
+ forward_k, current_microbatch, checkpoint_activations_microbatch
+ )
+
+ # Determine if current stage has anything to send in either direction,
+ # otherwise set tensor to None.
+ forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
+
+ # Last virtual stage no activation tensor to send
+ if parallel_state.is_pipeline_last_stage():
+ output_tensor = None
+
+ # Determine if peers are sending, and where in data structure to put
+ # received tensors.
+ recv_prev = True
+ if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+ # First stage is ahead of last stage by (pipeline_parallel_size - 1).
+ next_forward_model_chunk_id = get_model_chunk_id(
+ forward_k - (pipeline_parallel_size - 1), forward=True
+ )
+ if next_forward_model_chunk_id == (num_model_chunks - 1):
+ recv_prev = False
+ next_forward_model_chunk_id += 1
+ else:
+ next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
+
+ # If last iteration, don't receive; we already received one extra
+ # before the start of the for loop.
+ if k == (num_microbatches_remaining - 1):
+ recv_prev = False
+
+ # Send activation tensor to the next stage and receive activation tensor from the
+ # previous stage
+ input_tensor, fwd_wait_handles = send_forward_recv_forward(
+ output_tensor,
+ tensor_shape,
+ recv_prev=recv_prev,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+ # assert fwd_wait_handles is not None
+
+ if bwd_wait_handles is not None:
+ for req in bwd_wait_handles:
+ req.wait()
+
+ # Backward pass.
+ backward_k = k
+ input_tensor_grad = backward_step_helper(backward_k)
+
+ backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
+
+ # First virtual stage no activation gradient tensor to send
+ if parallel_state.is_pipeline_first_stage():
+ input_tensor_grad = None
+
+ # Determine if the current virtual stage has an activation gradient tensor to receive
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
+ next_backward_model_chunk_id = get_model_chunk_id(
+ backward_k - (pipeline_parallel_size - 1), forward=False
+ )
+ if next_backward_model_chunk_id == 0:
+ recv_next = False
+ next_backward_model_chunk_id -= 1
+ else:
+ next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
+
+ output_tensor_grad, bwd_wait_handles = send_backward_recv_backward(
+ input_tensor_grad,
+ tensor_shape,
+ recv_next=recv_next,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+
+ else: # no p2p overlap
+ output_tensor = forward_step_helper(
+ forward_k, current_microbatch, checkpoint_activations_microbatch
+ )
+
+ # Backward pass.
+ backward_k = k
+ input_tensor_grad = backward_step_helper(backward_k)
+
+ # Send output_tensor and input_tensor_grad, receive input_tensor
+ # and output_tensor_grad.
+
+ # Determine if current stage has anything to send in either direction,
+ # otherwise set tensor to None.
+ forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
+ if parallel_state.is_pipeline_last_stage():
+ output_tensor = None
+
+ backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
+ if parallel_state.is_pipeline_first_stage():
+ input_tensor_grad = None
+
+ # Determine if peers are sending, and where in data structure to put
+ # received tensors.
+ recv_prev = True
+ if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+ # First stage is ahead of last stage by (pipeline_parallel_size - 1).
+ next_forward_model_chunk_id = get_model_chunk_id(
+ forward_k - (pipeline_parallel_size - 1), forward=True
+ )
+ if next_forward_model_chunk_id == (num_model_chunks - 1):
+ recv_prev = False
+ next_forward_model_chunk_id += 1
+ else:
+ next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
+
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
+ next_backward_model_chunk_id = get_model_chunk_id(
+ backward_k - (pipeline_parallel_size - 1), forward=False
+ )
+ if next_backward_model_chunk_id == 0:
+ recv_next = False
+ next_backward_model_chunk_id -= 1
+ else:
+ next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
+
+ # If last iteration, don't receive; we already received one extra
+ # before the start of the for loop.
+ if k == (num_microbatches_remaining - 1):
+ recv_prev = False
+
+ # Communicate tensors.
+ (
+ input_tensor,
+ output_tensor_grad,
+ ) = send_forward_backward_recv_forward_backward(
+ output_tensor,
+ input_tensor_grad,
+ tensor_shape,
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ config=config,
+ )
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ # Put input_tensor and output_tensor_grad in data structures in the
+ # right location.
+ if recv_prev:
+ input_tensors[next_forward_model_chunk_id].append(input_tensor)
+ if recv_next:
+ output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
+
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ # Run cooldown backward passes (flush out pipeline).
+ if not forward_only:
+ if config.overlap_p2p_comm and bwd_wait_handles is not None:
+ for wait_handle in bwd_wait_handles:
+ wait_handle.wait()
+
+ if all_warmup_microbatches:
+ output_tensor_grads[num_model_chunks - 1].append(
+ recv_backward(tensor_shape, config=config)
+ )
+ for k in range(num_microbatches_remaining, total_num_microbatches):
+ input_tensor_grad = backward_step_helper(k)
+ next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ if next_backward_model_chunk_id == (num_model_chunks - 1):
+ recv_next = False
+ if k == (total_num_microbatches - 1):
+ recv_next = False
+ output_tensor_grads[next_backward_model_chunk_id].append(
+ send_backward_recv_backward(
+ input_tensor_grad, tensor_shape, recv_next=recv_next, config=config
+ )
+ )
+
+ # Launch any remaining grad reductions.
+ enable_grad_sync()
+ if config.grad_sync_func is not None:
+ for model_chunk_id in range(num_model_chunks):
+ if model_chunk_id not in synchronized_model_chunks:
+ config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
+ synchronized_model_chunks.add(model_chunk_id)
+
+ if config.finalize_model_grads_func is not None and not forward_only:
+ # Finalize model grads (perform full grad all-reduce / reduce-scatter for
+ # data parallelism, layernorm all-reduce for sequence parallelism, and
+ # embedding all-reduce for pipeline parallelism).
+ config.finalize_model_grads_func(
+ model, total_num_tokens if config.calculate_per_token_loss else None
+ )
+
+ if config.timers is not None:
+ config.timers('forward-backward').stop()
+
+ return forward_data_store
+
+
+def recv_forward(tensor_shapes, config):
+ input_tensors = []
+ for tensor_shape in tensor_shapes:
+ if tensor_shape is None:
+ input_tensors.append(None)
+ else:
+ config.pipeline_dtype = tensor_shape['dtype']
+ input_tensors.append(p2p_communication.recv_forward(tensor_shape['shape'], config))
+ return input_tensors
+
+
+def recv_forward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*arg, **kwargs):
+ return recv_forward(*arg, **kwargs)
+ return wrapper
+
+
+def recv_backward(tensor_shapes, config):
+ output_tensor_grads = []
+ for tensor_shape in tensor_shapes:
+ if tensor_shape is None:
+ output_tensor_grads.append(None)
+ else:
+ config.pipeline_dtype = tensor_shape['dtype']
+ output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape['shape'], config))
+ return output_tensor_grads
+
+
+def recv_backward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*arg, **kwargs):
+ return recv_backward(*arg, **kwargs)
+ return wrapper
+
+
+def send_forward(output_tensors, tensor_shapes, config):
+ if output_tensors is None:
+ output_tensors = [None] * len(tensor_shapes)
+ if not isinstance(output_tensors, list):
+ output_tensors = [output_tensors]
+ for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
+ if tensor_shape is None:
+ continue
+ config.pipeline_dtype = tensor_shape['dtype']
+ p2p_communication.send_forward(output_tensor, config)
+
+
+def send_forward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*arg, **kwargs):
+ return send_forward(*arg, **kwargs)
+ return wrapper
+
+
+def send_backward(input_tensor_grads, tensor_shapes, config):
+ if input_tensor_grads is None:
+ input_tensor_grads = [None] * len(tensor_shapes)
+ if not isinstance(input_tensor_grads, list):
+ input_tensor_grads = [input_tensor_grads]
+ for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
+ if tensor_shape is None:
+ continue
+ config.pipeline_dtype = tensor_shape['dtype']
+ p2p_communication.send_backward(input_tensor_grad, config)
+
+
+def send_backward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*arg, **kwargs):
+ return send_backward(*arg, **kwargs)
+ return wrapper
+
+
+def send_forward_recv_backward(output_tensors, tensor_shapes, config):
+ if not isinstance(output_tensors, list):
+ output_tensors = [None] * len(tensor_shapes)
+ output_tensor_grads = []
+ for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
+ if tensor_shape is None:
+ output_tensor_grads.append(None)
+ continue
+ config.pipeline_dtype = tensor_shape['dtype']
+ output_tensor_grad = p2p_communication.send_forward_recv_backward(
+ output_tensor, tensor_shape['shape'], config
+ )
+ output_tensor_grads.append(output_tensor_grad)
+ return output_tensor_grads
+
+
+def send_forward_recv_backward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*arg, **kwargs):
+ return send_forward_recv_backward(*arg, **kwargs)
+ return wrapper
+
+
+def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config):
+ if not isinstance(input_tensor_grads, list):
+ input_tensor_grads = [input_tensor_grads]
+ input_tensors = []
+ for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
+ if tensor_shape is None:
+ input_tensors.append(None)
+ continue
+ config.pipeline_dtype = tensor_shape['dtype']
+ input_tensor = p2p_communication.send_backward_recv_forward(
+ input_tensor_grad, tensor_shape['shape'], config
+ )
+ input_tensors.append(input_tensor)
+ return input_tensors
+
+
+def send_backward_recv_forward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*arg, **kwargs):
+ return send_backward_recv_forward(*arg, **kwargs)
+ return wrapper
+
+
+def send_forward_recv_forward(output_tensors, tensor_shapes, recv_prev, config, overlap_p2p_comm=False):
+ # overlap_p2p_comm
+ if output_tensors is None:
+ output_tensors = [None] * len(tensor_shapes)
+
+ if not isinstance(output_tensors, list):
+ output_tensors = [output_tensors]
+ input_tensors = []
+ all_fwd_wait_handles = []
+
+ if overlap_p2p_comm:
+ for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
+ if tensor_shape is None:
+ input_tensors.append(None)
+ continue
+ config.pipeline_dtype = tensor_shape['dtype']
+ input_tensor, wait_handles = p2p_communication.send_forward_recv_forward(
+ output_tensor,
+ recv_prev=recv_prev,
+ tensor_shape=tensor_shape['shape'],
+ config=config,
+ overlap_p2p_comm=overlap_p2p_comm,
+ )
+ input_tensors.append(input_tensor)
+ all_fwd_wait_handles.extend(wait_handles)
+ return input_tensors, all_fwd_wait_handles
+
+ else:
+ for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
+ if tensor_shape is None:
+ input_tensors.append(None)
+ continue
+ config.pipeline_dtype = tensor_shape['dtype']
+ input_tensor = p2p_communication.send_forward_recv_forward(
+ output_tensor,
+ recv_prev=recv_prev,
+ tensor_shape=tensor_shape['shape'],
+ config=config,
+ overlap_p2p_comm=overlap_p2p_comm,
+ )
+ input_tensors.append(input_tensor)
+ return input_tensors
+
+
+def send_backward_recv_backward(input_tensor_grads, tensor_shapes, recv_next, config, overlap_p2p_comm=False):
+ if input_tensor_grads is None:
+ input_tensor_grads = [None] * len(tensor_shapes)
+
+ if not isinstance(input_tensor_grads, list):
+ input_tensor_grads = [input_tensor_grads]
+ output_tensor_grads = []
+ all_fwd_wait_handles = []
+
+ if overlap_p2p_comm:
+ for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
+ if tensor_shape is None:
+ output_tensor_grads.append(None)
+ continue
+ config.pipeline_dtype = tensor_shape['dtype']
+ output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
+ input_tensor_grad,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape['shape'],
+ config=config,
+ overlap_p2p_comm=overlap_p2p_comm,
+ )
+ output_tensor_grads.append(output_tensor_grad)
+ all_fwd_wait_handles.extend(bwd_wait_handles)
+ return output_tensor_grads, all_fwd_wait_handles
+
+ else:
+ for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
+ if tensor_shape is None:
+ output_tensor_grads.append(None)
+ continue
+ config.pipeline_dtype = tensor_shape['dtype']
+ output_tensor_grad = p2p_communication.send_backward_recv_backward(
+ input_tensor_grad,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape['shape'],
+ config=config,
+ overlap_p2p_comm=overlap_p2p_comm,
+ )
+ output_tensor_grads.append(output_tensor_grad)
+ return output_tensor_grads
+
+
+def send_forward_backward_recv_forward_backward(output_tensors, input_tensor_grads, tensor_shapes, recv_prev, recv_next,
+ config):
+ if output_tensors is None:
+ output_tensors = [None] * len(tensor_shapes)
+ if input_tensor_grads is None:
+ input_tensor_grads = [None] * len(tensor_shapes)
+
+ if not isinstance(input_tensor_grads, list):
+ input_tensor_grads = [input_tensor_grads]
+ input_tensors = []
+ output_tensor_grads = []
+
+ for (output_tensor, input_tensor_grad, tensor_shape) in zip(output_tensors, input_tensor_grads, tensor_shapes):
+ config.pipeline_dtype = tensor_shape['dtype']
+ if tensor_shape is None:
+ input_tensors.append(None)
+ output_tensor_grads.append(None)
+ continue
+ (
+ input_tensor,
+ output_tensor_grad,
+ ) = p2p_communication.send_forward_backward_recv_forward_backward(
+ output_tensor,
+ input_tensor_grad,
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape['shape'],
+ config=config,
+ )
+ input_tensors.append(input_tensor)
+ output_tensor_grads.append(output_tensor_grad)
+ return input_tensors, output_tensor_grads
diff --git a/model/train/yoco_moe/mindspeed/core/pipeline_parallel/p2p_communication.py b/model/train/yoco_moe/mindspeed/core/pipeline_parallel/p2p_communication.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a72dc45a3449133147c5c3364cf8f66ef339a3f
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/pipeline_parallel/p2p_communication.py
@@ -0,0 +1,471 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from megatron.core.parallel_state import (
+ get_pipeline_model_parallel_group,
+ get_pipeline_model_parallel_next_rank,
+ get_pipeline_model_parallel_prev_rank,
+ get_pipeline_model_parallel_rank,
+)
+from megatron.core.pipeline_parallel.p2p_communication import _batched_p2p_ops, _p2p_ops
+from megatron.core import ModelParallelConfig
+from megatron.training import get_args
+from mindspeed.utils import get_actual_seq_len, set_actual_seq_len, get_position_ids, set_position_ids
+# Types
+Shape = Union[List[int], torch.Size]
+
+
+def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next, config, tensor_dim: int = 3):
+ """Communicate tensor shapes between stages. Used to communicate
+ tensor shapes before the actual tensor communication happens.
+ This is required when the sequence lengths across micro batches
+ are not uniform.
+
+ Args:
+ tensor_send_next: tensor to send to next rank (no tensor sent if
+ set to None).
+ tensor_send_prev: tensor to send to prev rank (no tensor sent if
+ set to None).
+ recv_prev: boolean for whether tensor should be received from
+ previous rank.
+ recv_next: boolean for whether tensor should be received from
+ next rank.
+ Returns:
+ (recv_prev_shape, recv_next_shape)
+ """
+
+ recv_prev_shape_tensor = None
+ recv_next_shape_tensor = None
+ send_prev_shape_tensor = None
+ send_next_shape_tensor = None
+ if recv_prev:
+ recv_prev_shape_tensor = torch.empty(
+ (tensor_dim), device=torch.cuda.current_device(), dtype=torch.int64
+ )
+ if recv_next:
+ recv_next_shape_tensor = torch.empty(
+ (tensor_dim), device=torch.cuda.current_device(), dtype=torch.int64
+ )
+ if tensor_send_prev is not None:
+ send_prev_shape_tensor = torch.tensor(
+ tensor_send_prev.size(), device=torch.cuda.current_device(), dtype=torch.int64
+ )
+ if tensor_send_next is not None:
+ send_next_shape_tensor = torch.tensor(
+ tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64
+ )
+
+ if config.use_ring_exchange_p2p:
+ torch.distributed.ring_exchange(
+ tensor_send_prev=send_prev_shape_tensor,
+ tensor_recv_prev=recv_prev_shape_tensor,
+ tensor_send_next=send_next_shape_tensor,
+ tensor_recv_next=recv_next_shape_tensor,
+ group=get_pipeline_model_parallel_group(),
+ )
+
+ # Send tensors in both the forward and backward directions as appropriate.
+ if config.use_ring_exchange_p2p:
+
+ def _ring_exchange_wrapper(**kwargs):
+ torch.distributed.ring_exchange(**kwargs)
+ return []
+
+ p2p_func = _ring_exchange_wrapper
+ elif config.batch_p2p_comm:
+ p2p_func = _batched_p2p_ops
+ else:
+ p2p_func = _p2p_ops
+
+ reqs = p2p_func(
+ tensor_send_prev=send_prev_shape_tensor,
+ tensor_recv_prev=recv_prev_shape_tensor,
+ tensor_send_next=send_next_shape_tensor,
+ tensor_recv_next=recv_next_shape_tensor,
+ group=get_pipeline_model_parallel_group(),
+ )
+
+ if len(reqs) > 0:
+ for req in reqs:
+ req.wait()
+ reqs = None
+
+ if config.batch_p2p_comm and config.batch_p2p_sync:
+ # To protect against race condition when using batch_isend_irecv().
+ # User should assert that we have a modern enough PyTorch to not need this
+ torch.cuda.synchronize()
+
+ recv_prev_shape = [0, 0, 0]
+ if recv_prev_shape_tensor is not None:
+ recv_prev_shape = recv_prev_shape_tensor.tolist()
+
+ recv_next_shape = [0, 0, 0]
+ if recv_next_shape_tensor is not None:
+ recv_next_shape = recv_next_shape_tensor.tolist()
+
+ return recv_prev_shape, recv_next_shape
+
+
+def _communicate(
+ *,
+ tensor_send_next: Optional[torch.Tensor],
+ tensor_send_prev: Optional[torch.Tensor],
+ recv_prev: bool,
+ recv_next: bool,
+ tensor_shape: Shape,
+ config: ModelParallelConfig,
+ wait_on_reqs: bool = True
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Communicate tensors between stages. Used as helper method in other
+ communication methods that are used in megatron/schedules.py.
+
+ Args:
+ tensor_send_next (torch.Tensor, optional):
+ Tensor to send to next rank (no tensor sent if None)
+
+ tensor_send_prev (torch.Tensor, optional):
+ Tensor to send to prev rank (no tensor sent if None)
+
+ recv_prev (boolean, required):
+ whether tensor should be received from previous rank.
+
+ recv_next (boolean, required):
+ whether tensor should be received from next rank.
+
+ tensor_shape (List[int] or torch.Size, required):
+ shape of tensor to receive (this method assumes that all
+ tensors sent and received in a single function call are
+ the same shape).
+
+ wait_on_reqs (boolean, optional, default=False):
+ For non-batched p2p communication, wait on each request
+ before returning.
+
+ Returns:
+ tuple containing
+
+ - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
+ - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.
+
+ """
+
+ # Create placeholder tensors for receive in forward and backward directions
+ # if needed.
+ tensor_recv_prev = None
+ tensor_recv_next = None
+
+ if not config.variable_seq_lengths:
+ recv_prev_shape = tensor_shape
+ recv_next_shape = tensor_shape
+ else:
+ tensor_dim = len(tensor_shape) if tensor_shape is not None else 3
+ recv_prev_shape, recv_next_shape = _communicate_shapes(
+ tensor_send_next, tensor_send_prev, recv_prev, recv_next, config, tensor_dim,
+ )
+
+ if recv_prev:
+ if config.pipeline_dtype is None:
+ raise RuntimeError("pipeline_dtype must be provided if recv_prev is True")
+ if tensor_shape is None:
+ raise RuntimeError(
+ "tensor_shape must be specified if recv_prev is True. "
+ "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
+ )
+ tensor_recv_prev = torch.empty(
+ recv_prev_shape,
+ requires_grad=True,
+ device=torch.cuda.current_device(),
+ dtype=config.pipeline_dtype,
+ )
+ if recv_next:
+ if config.pipeline_dtype is None:
+ raise RuntimeError("dtype must be provided if recv_next is True")
+ if tensor_shape is None:
+ raise RuntimeError(
+ "tensor_shape must be specified if recv_next is True. "
+ "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
+ )
+ tensor_recv_next = torch.empty(
+ recv_next_shape,
+ requires_grad=True,
+ device=torch.cuda.current_device(),
+ dtype=config.pipeline_dtype,
+ )
+
+ # Send tensors in both the forward and backward directions as appropriate.
+ if config.use_ring_exchange_p2p:
+
+ def _ring_exchange_wrapper(**kwargs):
+ torch.distributed.ring_exchange(**kwargs)
+ return []
+
+ p2p_func = _ring_exchange_wrapper
+ elif config.batch_p2p_comm:
+ assert wait_on_reqs
+ p2p_func = _batched_p2p_ops
+ else:
+ p2p_func = _p2p_ops
+
+ reqs = p2p_func(
+ tensor_send_prev=tensor_send_prev,
+ tensor_recv_prev=tensor_recv_prev,
+ tensor_send_next=tensor_send_next,
+ tensor_recv_next=tensor_recv_next,
+ group=get_pipeline_model_parallel_group(),
+ )
+
+ if wait_on_reqs and len(reqs) > 0:
+ for req in reqs:
+ req.wait()
+ reqs = None
+
+ if config.batch_p2p_comm and config.batch_p2p_sync:
+ # To protect against race condition when using batch_isend_irecv().
+ # User should assert that we have a modern enough PyTorch to not need this
+ torch.cuda.synchronize()
+
+ return tensor_recv_prev, tensor_recv_next, reqs
+
+
+def _p2p_ops_eod(
+ *,
+ tensor_send_prev: Optional[torch.Tensor],
+ tensor_recv_prev: Optional[torch.Tensor],
+ tensor_send_next: Optional[torch.Tensor],
+ tensor_recv_next: Optional[torch.Tensor],
+ group: torch.distributed.ProcessGroup,
+):
+ reqs = []
+ rank = get_pipeline_model_parallel_rank()
+ prev_actual_seq_len = get_actual_seq_len()
+ prev_position_ids = get_position_ids()
+
+ tensor_length = None
+ length_buffer = None
+ args = get_args()
+ bsz = args.micro_batch_size
+ block_size = args.seq_length // args.context_parallel_size
+
+ if tensor_send_next is not None:
+ tensor_length = torch.tensor(prev_actual_seq_len.numel()).npu()
+
+ if tensor_recv_prev is not None:
+ length_buffer = torch.empty((), dtype=torch.int64, device=torch.cuda.current_device())
+
+ if rank % 2 == 0:
+ if tensor_length is not None:
+ send_next_req = torch.distributed.isend(
+ tensor=tensor_length, dst=get_pipeline_model_parallel_next_rank(), group=group,
+ )
+ reqs.append(send_next_req)
+
+ if length_buffer is not None:
+ recv_prev_req = torch.distributed.irecv(
+ tensor=length_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ reqs.append(recv_prev_req)
+ else:
+ if length_buffer is not None:
+ recv_prev_req = torch.distributed.irecv(
+ tensor=length_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ reqs.append(recv_prev_req)
+
+ if tensor_length is not None:
+ send_next_req = torch.distributed.isend(
+ tensor=tensor_length, dst=get_pipeline_model_parallel_next_rank(), group=group,
+ )
+ reqs.append(send_next_req)
+
+ for req in reqs:
+ req.wait()
+
+ reqs = []
+
+ if get_pipeline_model_parallel_rank() % 2 == 0:
+ if tensor_send_next is not None:
+ req = torch.distributed.isend(
+ tensor=prev_actual_seq_len, dst=get_pipeline_model_parallel_next_rank(), group=get_pipeline_model_parallel_group(),
+ )
+ reqs.append(req)
+
+ req = torch.distributed.isend(
+ tensor=prev_position_ids, dst=get_pipeline_model_parallel_next_rank(), group=get_pipeline_model_parallel_group(),
+ )
+ reqs.append(req)
+
+ send_next_req = torch.distributed.isend(
+ tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
+ )
+ reqs.append(send_next_req)
+
+ if tensor_recv_prev is not None:
+ actual_seq_len_buffer = torch.empty([length_buffer.item()], dtype=torch.int64, device=torch.cuda.current_device())
+
+ req = torch.distributed.irecv(
+ tensor=actual_seq_len_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ reqs.append(req)
+ set_actual_seq_len(actual_seq_len_buffer)
+
+ position_ids_buffer = torch.empty((block_size, bsz), dtype=torch.int64, device=torch.cuda.current_device())
+ req = torch.distributed.irecv(
+ tensor=position_ids_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ set_position_ids(position_ids_buffer)
+ reqs.append(req)
+
+ recv_prev_req = torch.distributed.irecv(
+ tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ reqs.append(recv_prev_req)
+
+ if tensor_send_prev is not None:
+ send_prev_req = torch.distributed.isend(
+ tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ reqs.append(send_prev_req)
+
+ if tensor_recv_next is not None:
+ recv_next_req = torch.distributed.irecv(
+ tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
+ )
+ reqs.append(recv_next_req)
+
+ else:
+ if tensor_recv_prev is not None:
+ actual_seq_len_buffer = torch.empty([length_buffer.item()], dtype=torch.int64, device=torch.cuda.current_device())
+
+ req = torch.distributed.irecv(
+ tensor=actual_seq_len_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ reqs.append(req)
+ set_actual_seq_len(actual_seq_len_buffer)
+
+ position_ids_buffer = torch.empty((block_size, bsz), dtype=torch.int64, device=torch.cuda.current_device())
+ req = torch.distributed.irecv(
+ tensor=position_ids_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ set_position_ids(position_ids_buffer)
+ reqs.append(req)
+
+ recv_prev_req = torch.distributed.irecv(
+ tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ reqs.append(recv_prev_req)
+
+ if tensor_send_next is not None:
+ req = torch.distributed.isend(
+ tensor=prev_actual_seq_len, dst=get_pipeline_model_parallel_next_rank(), group=get_pipeline_model_parallel_group(),
+ )
+ reqs.append(req)
+
+ req = torch.distributed.isend(
+ tensor=prev_position_ids, dst=get_pipeline_model_parallel_next_rank(), group=get_pipeline_model_parallel_group(),
+ )
+ reqs.append(req)
+
+ send_next_req = torch.distributed.isend(
+ tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
+ )
+ reqs.append(send_next_req)
+
+ if tensor_recv_next is not None:
+ recv_next_req = torch.distributed.irecv(
+ tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
+ )
+ reqs.append(recv_next_req)
+
+ if tensor_send_prev is not None:
+ send_prev_req = torch.distributed.isend(
+ tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ reqs.append(send_prev_req)
+ return reqs
+
+
+def _p2p_ops_send_recv_overlap(
+ *,
+ tensor_send_prev: Optional[torch.Tensor],
+ tensor_recv_prev: Optional[torch.Tensor],
+ tensor_send_next: Optional[torch.Tensor],
+ tensor_recv_next: Optional[torch.Tensor],
+ group: torch.distributed.ProcessGroup
+):
+ ops = []
+ if get_pipeline_model_parallel_rank() % 2 == 0:
+ if tensor_send_prev is not None:
+ send_prev_op = torch.distributed.P2POp(
+ torch.distributed.isend,
+ tensor_send_prev,
+ get_pipeline_model_parallel_prev_rank(),
+ group,
+ )
+ ops.append(send_prev_op)
+ if tensor_recv_prev is not None:
+ recv_prev_op = torch.distributed.P2POp(
+ torch.distributed.irecv,
+ tensor_recv_prev,
+ get_pipeline_model_parallel_prev_rank(),
+ group,
+ )
+ ops.append(recv_prev_op)
+ if tensor_send_next is not None:
+ send_next_op = torch.distributed.P2POp(
+ torch.distributed.isend,
+ tensor_send_next,
+ get_pipeline_model_parallel_next_rank(),
+ group,
+ )
+ ops.append(send_next_op)
+ if tensor_recv_next is not None:
+ recv_next_op = torch.distributed.P2POp(
+ torch.distributed.irecv,
+ tensor_recv_next,
+ get_pipeline_model_parallel_next_rank(),
+ group,
+ )
+ ops.append(recv_next_op)
+ else:
+ if tensor_recv_next is not None:
+ recv_next_op = torch.distributed.P2POp(
+ torch.distributed.irecv,
+ tensor_recv_next,
+ get_pipeline_model_parallel_next_rank(),
+ group,
+ )
+ ops.append(recv_next_op)
+ if tensor_send_next is not None:
+ send_next_op = torch.distributed.P2POp(
+ torch.distributed.isend,
+ tensor_send_next,
+ get_pipeline_model_parallel_next_rank(),
+ group,
+ )
+ ops.append(send_next_op)
+ if tensor_recv_prev is not None:
+ recv_prev_op = torch.distributed.P2POp(
+ torch.distributed.irecv,
+ tensor_recv_prev,
+ get_pipeline_model_parallel_prev_rank(),
+ group,
+ )
+ ops.append(recv_prev_op)
+ if tensor_send_prev is not None:
+ send_prev_op = torch.distributed.P2POp(
+ torch.distributed.isend,
+ tensor_send_prev,
+ get_pipeline_model_parallel_prev_rank(),
+ group,
+ )
+ ops.append(send_prev_op)
+
+ if len(ops) > 0:
+ reqs = torch.distributed.batch_isend_irecv(ops)
+ else:
+ reqs = []
+ return reqs
diff --git a/model/train/yoco_moe/mindspeed/core/pipeline_parallel/ripipe_schedules.py b/model/train/yoco_moe/mindspeed/core/pipeline_parallel/ripipe_schedules.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0c8a27f72e3e5e62a8bc6d43dbcc49a8a4a2701
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/pipeline_parallel/ripipe_schedules.py
@@ -0,0 +1,784 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections
+from typing import Iterator, List, Union
+import contextlib
+
+import torch
+from megatron.training import get_args
+from megatron.core import parallel_state
+from megatron.core.pipeline_parallel.schedules import deallocate_output_tensor, forward_step, backward_step, \
+ check_first_val_step
+from megatron.core.pipeline_parallel import p2p_communication
+from megatron.core.utils import get_model_config, get_model_type
+from megatron.core.enums import ModelType
+
+from mindspeed.core.tensor_parallel.checkpoint_manager import get_pipeline_checkpoint_manager
+from mindspeed.core.weight_grad_store import WeightGradStore
+
+
+def forward_backward_ripipe_pipelining(
+ *,
+ forward_step_func,
+ data_iterator: Union[Iterator, List[Iterator]],
+ model: Union[torch.nn.Module, List[torch.nn.Module]],
+ num_microbatches: int,
+ seq_length: int,
+ micro_batch_size: int,
+ decoder_seq_length: int = None,
+ forward_only: bool = False,
+ collect_non_loss_data: bool = False,
+ first_val_step: bool = None,
+):
+ """Almost directly copied from megatron's forward_backward_pipelining_with_interleaving
+ function, all modifications are annotated with 'ripipe related' or 'nanopipe related' """
+ # ripipe related, setup checkpoint manager.
+ pipeline_checkpoint_manager = get_pipeline_checkpoint_manager(
+ num_of_chunks=parallel_state.get_virtual_pipeline_model_parallel_world_size())
+ args = get_args()
+ if args.recompute_in_bubble or args.recompute_in_advance:
+ pipeline_checkpoint_manager.open_ri_pipe = True
+ pipeline_checkpoint_manager.do_pre_recompute = True
+
+
+ """Run interleaved 1F1B schedule (model split into model chunks), with
+ communication between pipeline stages as needed.
+
+ Returns dictionary with losses if the last stage, empty dict otherwise."""
+ assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking"
+ assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking"
+ assert isinstance(
+ data_iterator, list
+ ), "interleaved pipeline parallelism expected each model chunk to have a data iterator"
+
+ config = get_model_config(model[0])
+ if config.overlap_p2p_comm and config.batch_p2p_comm:
+ raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")
+
+ if config.timers is not None:
+ config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
+
+ # Disable async grad reductions
+ no_sync_func = config.no_sync_func
+ if isinstance(no_sync_func, list):
+
+ def multi_no_sync():
+ stack = contextlib.ExitStack()
+ for model_chunk_no_sync_func in config.no_sync_func:
+ stack.enter_context(model_chunk_no_sync_func())
+ return stack
+
+ no_sync_func = multi_no_sync
+ if no_sync_func is None:
+ no_sync_func = contextlib.nullcontext
+ no_sync_context = None
+
+ if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
+ config.grad_sync_func = [config.grad_sync_func for _ in model]
+
+ if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
+ config.param_sync_func = [config.param_sync_func for _ in model]
+
+ def disable_grad_sync():
+ """Disable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is None:
+ no_sync_context = no_sync_func()
+ no_sync_context.__enter__()
+
+ def enable_grad_sync():
+ """Enable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is not None:
+ no_sync_context.__exit__(None, None, None)
+ no_sync_context = None
+
+ disable_grad_sync()
+
+ # Model chunk IDs with synchronized grads
+ synchronized_model_chunks = set()
+
+ input_tensors = [[] for _ in range(len(model))]
+ output_tensors = [[] for _ in range(len(model))]
+ forward_data_store = []
+ if not forward_only:
+ output_tensor_grads = [[] for _ in range(len(model))]
+
+ pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
+ pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
+
+ if num_microbatches % pipeline_parallel_size != 0:
+ msg = f'number of microbatches ({num_microbatches}) is not divisible by '
+ msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
+ msg += 'when using interleaved schedule'
+ raise RuntimeError(msg)
+
+ model_type = get_model_type(model[0])
+ if model_type == ModelType.encoder_and_decoder:
+ raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")
+
+ if decoder_seq_length is not None and decoder_seq_length != seq_length:
+ raise RuntimeError(
+ "Interleaving is not supported with a different decoder sequence length."
+ )
+
+ tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
+ tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
+ if config.sequence_parallel:
+ tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
+ tensor_shape[0] = tensor_shape[0] // args.tp_x
+ tensor_shape[-1] = tensor_shape[-1] // args.tp_y
+ # Compute number of warmup and remaining microbatches.
+ num_model_chunks = len(model)
+ total_num_microbatches = num_microbatches * num_model_chunks
+ all_warmup_microbatches = False
+ if forward_only:
+ num_warmup_microbatches = total_num_microbatches
+ else:
+ # ripipe related, no special handling of 'num_warmup_microbatches' when 'num_microbatches == pipeline_parallel_size'
+ num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
+ num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
+ num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
+ num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
+
+ num_fwd = min((pipeline_parallel_size - 1) * 2 + (num_model_chunks - 1) * pipeline_parallel_size, total_num_microbatches)
+ num_dx = num_fwd - num_warmup_microbatches
+ overlap_chunks_num = (num_dx + pipeline_parallel_size - 1) // pipeline_parallel_size
+ nano_flag = [True] * len(model)
+ for i in range(overlap_chunks_num):
+ nano_flag[-i - 1] = False
+ # ripipe related, calculate the variables needed by the recompute_in_bubble function
+ num_microbatches_recompute, num_microbatches_recompute_forward, num_microbatches_recompute_steady_groups, \
+ num_microbatches_recompute_tail = get_ripipe_recompute_count_params(num_microbatches,
+ num_model_chunks,
+ num_warmup_microbatches)
+
+ # Checkpoint the activations of partial Transformer layers in a number of micro-batches
+ # within the maximum outstanding micro-batch backpropagations.
+ # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
+ # checkpoint partial Transformer layers (or skip checkpointing) and
+ # the rest of micro-batches within a window of micro-batches checkpoint
+ # all Transformer layers. The window of micro-batches is set by the maximum
+ # outstanding backpropagations and becomes smaller at later pipeline stages.
+ # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
+ max_outstanding_backprops = None
+ if config.num_microbatches_with_partial_activation_checkpoints is not None:
+ max_outstanding_backprops = num_warmup_microbatches + 1
+
+ # Synchronize params for first two model chunks
+ if config.param_sync_func is not None:
+ config.param_sync_func[0](model[0].parameters())
+ config.param_sync_func[1](model[1].parameters())
+
+ def get_chunk_batch_id(microbatch_id, forward):
+ """ripipe related, needed by recompute_in_bubble function."""
+ microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
+ model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
+ if not forward:
+ model_chunk_id = num_model_chunks - model_chunk_id - 1
+ group_id = microbatch_id // (pipeline_parallel_size * num_model_chunks)
+ intra_chunk_batch_id = (microbatch_id_in_group % pipeline_parallel_size)
+ return group_id, intra_chunk_batch_id, model_chunk_id
+
+ def should_recompute(fk):
+ """ripipe related, needed by recompute_in_bubble function, used to determine
+ whether a mircobatch needs to be recomputed in the 1f1b stage."""
+ gid, intro_group_bid, chunk_id = get_chunk_batch_id(fk, forward=True)
+ if chunk_id == 0:
+ if gid < 2:
+ return False
+ elif gid < 2 + num_microbatches_recompute_steady_groups:
+ if intro_group_bid >= (1 + 2 * pipeline_parallel_rank):
+ return True
+ else:
+ if intro_group_bid >= pipeline_parallel_size - num_microbatches_recompute_tail:
+ return True
+ return False
+
+ def get_model_chunk_id(microbatch_id, forward):
+ """Helper method to get the model chunk ID given the iteration number."""
+ microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
+ model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
+ if not forward:
+ model_chunk_id = num_model_chunks - model_chunk_id - 1
+ return model_chunk_id
+
+ def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
+ """Check if an iteration is the first for a model chunk."""
+ microbatch_group_size = pipeline_parallel_size * num_model_chunks
+ num_microbatch_groups = total_num_microbatches // microbatch_group_size
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ if microbatch_group_id == 0:
+ return microbatch_id_in_group % pipeline_parallel_size == 0
+ else:
+ return False
+
+ def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
+ """Check if an iteration is the last for a model chunk."""
+ microbatch_group_size = pipeline_parallel_size * num_model_chunks
+ num_microbatch_groups = total_num_microbatches // microbatch_group_size
+ microbatch_group_id = microbatch_id // microbatch_group_size
+ microbatch_id_in_group = microbatch_id % microbatch_group_size
+ if microbatch_group_id == num_microbatch_groups - 1:
+ return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
+ else:
+ return False
+
+ def forward_step_helper(microbatch_id, checkpoint_activations_microbatch):
+ """Helper method to run forward step with model split into chunks
+ (run set_virtual_pipeline_model_parallel_rank() before calling
+ forward_step())."""
+ model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
+
+ # launch param synchronization for next model chunk
+ # Note: Asynchronous communication tends to slow down compute.
+ # To reduce idling from mismatched microbatch times, we launch
+ # asynchronous communication at the same time across the
+ # pipeline-parallel group.
+ if config.param_sync_func is not None:
+ param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
+ if (
+ param_sync_microbatch_id < total_num_microbatches
+ and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
+ ):
+ param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
+ if 1 < param_sync_chunk_id < num_model_chunks:
+ config.param_sync_func[param_sync_chunk_id](
+ model[param_sync_chunk_id].parameters()
+ )
+
+ # forward step
+ if parallel_state.is_pipeline_first_stage():
+ if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
+ input_tensors[model_chunk_id].append(None)
+ input_tensor = input_tensors[model_chunk_id][-1]
+ output_tensor, _ = forward_step(
+ forward_step_func,
+ data_iterator[model_chunk_id],
+ model[model_chunk_id],
+ num_microbatches,
+ input_tensor,
+ forward_data_store,
+ config,
+ collect_non_loss_data,
+ checkpoint_activations_microbatch,
+ check_first_val_step(
+ first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id),
+ ),
+ )
+ output_tensors[model_chunk_id].append(output_tensor)
+
+ # if forward-only, no need to save tensors for a backward pass
+ if forward_only:
+ input_tensors[model_chunk_id].pop()
+ output_tensors[model_chunk_id].pop()
+
+ # ripipe related, when a microbatch finish its forward pass, save needed recomputation
+ # functions for this microbatch.
+ if args.recompute_in_bubble or args.recompute_in_advance:
+ pipeline_checkpoint_manager.batch_fin(model_chunk_id)
+
+ return output_tensor
+
+ def backward_step_helper(microbatch_id):
+ """Helper method to run backward step with model split into chunks
+ (run set_virtual_pipeline_model_parallel_rank() before calling
+ backward_step())."""
+ model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
+
+ # launch grad synchronization (default)
+ if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id) and nano_flag[model_chunk_id]:
+ enable_grad_sync()
+ synchronized_model_chunks.add(model_chunk_id)
+
+ if parallel_state.is_pipeline_last_stage():
+ if len(output_tensor_grads[model_chunk_id]) == 0:
+ output_tensor_grads[model_chunk_id].append(None)
+ input_tensor = input_tensors[model_chunk_id].pop(0)
+ output_tensor = output_tensors[model_chunk_id].pop(0)
+ output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
+ input_tensor_grad = backward_step(
+ input_tensor, output_tensor, output_tensor_grad, model_type, config
+ )
+
+ # launch grad synchronization (custom grad sync)
+ # Note: Asynchronous communication tends to slow down compute.
+ # To reduce idling from mismatched microbatch times, we launch
+ # asynchronous communication at the same time across the
+ # pipeline-parallel group.
+ if config.grad_sync_func is not None:
+ grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
+ if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
+ grad_sync_microbatch_id
+ ):
+ grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
+ if nano_flag[grad_sync_chunk_id]:
+ enable_grad_sync()
+ config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
+ synchronized_model_chunks.add(grad_sync_chunk_id)
+ disable_grad_sync()
+
+ return input_tensor_grad
+
+ # Run warmup forward passes.
+ parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+ input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
+
+ fwd_wait_handles = None
+ bwd_wait_handles = None
+
+ for k in range(num_warmup_microbatches):
+
+ if fwd_wait_handles is not None:
+ for req in fwd_wait_handles:
+ req.wait()
+
+ # Decide to checkpoint all layers' activations of the current micro-batch
+ if max_outstanding_backprops is not None:
+ checkpoint_activations_microbatch = (
+ k % max_outstanding_backprops
+ >= config.num_microbatches_with_partial_activation_checkpoints
+ )
+ else:
+ checkpoint_activations_microbatch = None
+
+ # ripipe related, when use recompute_in_bubble function, do not do recompute
+ # for the first pp * vp microbatches.
+ if args.recompute_in_bubble:
+ if k < pipeline_parallel_size * num_model_chunks:
+ pipeline_checkpoint_manager.disable_recompute()
+ else:
+ num_microbatches_recompute_forward -= 1
+ output_tensor = forward_step_helper(k, checkpoint_activations_microbatch)
+ if args.recompute_in_bubble or args.recompute_in_advance:
+ pipeline_checkpoint_manager.enable_recompute()
+
+ # Determine if tensor should be received from previous stage.
+ next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
+ recv_prev = True
+ if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+ if next_forward_model_chunk_id == 0:
+ recv_prev = False
+ if k == (total_num_microbatches - 1):
+ recv_prev = False
+
+ # Don't send tensor downstream if on last stage.
+ if parallel_state.is_pipeline_last_stage():
+ output_tensor = None
+
+ # Send and receive tensors as appropriate (send tensors computed
+ # in this iteration; receive tensors for next iteration).
+ if not config.overlap_p2p_comm:
+ if (
+ k == (num_warmup_microbatches - 1)
+ and not forward_only
+ and not all_warmup_microbatches
+ ):
+ input_tensor_grad = None
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ recv_next = False
+ (
+ input_tensor,
+ output_tensor_grad,
+ ) = p2p_communication.send_forward_backward_recv_forward_backward(
+ output_tensor,
+ input_tensor_grad,
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ config=config,
+ )
+ output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
+ else:
+ input_tensor = p2p_communication.send_forward_recv_forward(
+ output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
+ )
+ input_tensors[next_forward_model_chunk_id].append(input_tensor)
+ else:
+ input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
+ output_tensor,
+ recv_prev=recv_prev,
+ tensor_shape=tensor_shape,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+
+ if (
+ k == (num_warmup_microbatches - 1)
+ and not forward_only
+ and not all_warmup_microbatches
+ ):
+ input_tensor_grad = None
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ recv_next = False
+
+ (
+ output_tensor_grad,
+ bwd_wait_handles,
+ ) = p2p_communication.send_backward_recv_backward(
+ input_tensor_grad,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+
+ output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
+ input_tensors[next_forward_model_chunk_id].append(input_tensor)
+
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ # Run 1F1B in steady state.
+ for k in range(num_microbatches_remaining):
+ # Forward pass.
+ forward_k = k + num_warmup_microbatches
+
+ # Decide to checkpoint all layers' activations of the current micro-batch
+ if max_outstanding_backprops is not None:
+ checkpoint_activations_microbatch = (
+ forward_k % max_outstanding_backprops
+ >= config.num_microbatches_with_partial_activation_checkpoints
+ )
+ else:
+ checkpoint_activations_microbatch = None
+
+ if config.overlap_p2p_comm:
+ if fwd_wait_handles is not None:
+ for req in fwd_wait_handles:
+ req.wait()
+
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ # ripipe related, determine whether this microbatch should be recomputed
+ # when using recompute_in_bubble function.
+ if args.recompute_in_bubble:
+ if num_microbatches_recompute_forward > 0:
+ num_microbatches_recompute_forward -= 1
+ elif num_microbatches_recompute > 0 and should_recompute(forward_k):
+ pass
+ else:
+ pipeline_checkpoint_manager.disable_recompute()
+ output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)
+ if args.recompute_in_bubble or args.recompute_in_advance:
+ pipeline_checkpoint_manager.enable_recompute()
+ # Determine if current stage has anything to send in either direction,
+ # otherwise set tensor to None.
+ forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
+
+ # Last virtual stage no activation tensor to send
+ if parallel_state.is_pipeline_last_stage():
+ output_tensor = None
+
+ # Determine if peers are sending, and where in data structure to put
+ # received tensors.
+ recv_prev = True
+ if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+ # First stage is ahead of last stage by (pipeline_parallel_size - 1).
+ next_forward_model_chunk_id = get_model_chunk_id(
+ forward_k - (pipeline_parallel_size - 1), forward=True
+ )
+ if next_forward_model_chunk_id == (num_model_chunks - 1):
+ recv_prev = False
+ next_forward_model_chunk_id += 1
+ else:
+ next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
+
+ # If last iteration, don't receive; we already received one extra
+ # before the start of the for loop.
+ if k == (num_microbatches_remaining - 1):
+ recv_prev = False
+
+ # Send activation tensor to the next stage and receive activation tensor from the
+ # previous stage
+ input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
+ output_tensor,
+ recv_prev=recv_prev,
+ tensor_shape=tensor_shape,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+ # assert fwd_wait_handles is not None
+
+ # ripipe related, actually do the recomputation.
+ if args.recompute_in_advance or args.recompute_in_bubble:
+ vpp_rank = get_model_chunk_id(k, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(vpp_rank)
+ if not parallel_state.is_pipeline_last_stage() or args.recompute_in_bubble:
+ pipeline_checkpoint_manager.recompute_next(vpp_rank)
+
+ if bwd_wait_handles is not None:
+ for req in bwd_wait_handles:
+ req.wait()
+
+ # Backward pass.
+ backward_k = k
+ if k < num_dx and args.use_nanopipe:
+ WeightGradStore.start_decouple()
+
+ if args.use_nanopipe:
+ WeightGradStore.resize_ori_storage(args.use_nanopipe_swap)
+
+ input_tensor_grad = backward_step_helper(backward_k)
+ if args.use_nanopipe:
+ if WeightGradStore.is_decoupleBlock:
+ WeightGradStore.flush()
+ if k == num_dx - 1:
+ WeightGradStore.end_decouple()
+
+ backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
+
+ # First virtual stage no activation gradient tensor to send
+ if parallel_state.is_pipeline_first_stage():
+ input_tensor_grad = None
+
+ # Determine if the current virtual stage has an activation gradient tensor to receive
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
+ next_backward_model_chunk_id = get_model_chunk_id(
+ backward_k - (pipeline_parallel_size - 1), forward=False
+ )
+ if next_backward_model_chunk_id == 0:
+ recv_next = False
+ next_backward_model_chunk_id -= 1
+ else:
+ next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
+
+ output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
+ input_tensor_grad,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ config=config,
+ overlap_p2p_comm=True,
+ )
+
+ else: # no p2p overlap
+ output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)
+
+ # Backward pass.
+ backward_k = k
+ if k < num_dx and args.use_nanopipe:
+ WeightGradStore.start_decouple()
+
+ if args.use_nanopipe:
+ WeightGradStore.resize_ori_storage(args.use_nanopipe_swap)
+
+ input_tensor_grad = backward_step_helper(backward_k)
+ if k == num_dx - 1 and args.use_nanopipe:
+ WeightGradStore.end_decouple()
+
+ # Send output_tensor and input_tensor_grad, receive input_tensor
+ # and output_tensor_grad.
+
+ # Determine if current stage has anything to send in either direction,
+ # otherwise set tensor to None.
+ forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
+ if parallel_state.is_pipeline_last_stage():
+ output_tensor = None
+
+ backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
+ if parallel_state.is_pipeline_first_stage():
+ input_tensor_grad = None
+
+ # Determine if peers are sending, and where in data structure to put
+ # received tensors.
+ recv_prev = True
+ if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+ # First stage is ahead of last stage by (pipeline_parallel_size - 1).
+ next_forward_model_chunk_id = get_model_chunk_id(
+ forward_k - (pipeline_parallel_size - 1), forward=True
+ )
+ if next_forward_model_chunk_id == (num_model_chunks - 1):
+ recv_prev = False
+ next_forward_model_chunk_id += 1
+ else:
+ next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
+
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
+ next_backward_model_chunk_id = get_model_chunk_id(
+ backward_k - (pipeline_parallel_size - 1), forward=False
+ )
+ if next_backward_model_chunk_id == 0:
+ recv_next = False
+ next_backward_model_chunk_id -= 1
+ else:
+ next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
+
+ # If last iteration, don't receive; we already received one extra
+ # before the start of the for loop.
+ if k == (num_microbatches_remaining - 1):
+ recv_prev = False
+
+ # Communicate tensors.
+ (
+ input_tensor,
+ output_tensor_grad,
+ ) = p2p_communication.send_forward_backward_recv_forward_backward(
+ output_tensor,
+ input_tensor_grad,
+ recv_prev=recv_prev,
+ recv_next=recv_next,
+ tensor_shape=tensor_shape,
+ config=config,
+ )
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ # Put input_tensor and output_tensor_grad in data structures in the
+ # right location.
+ if recv_prev:
+ input_tensors[next_forward_model_chunk_id].append(input_tensor)
+ if recv_next:
+ output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
+
+ deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+ # Run cooldown backward passes (flush out pipeline).
+ if not forward_only:
+ # ripipe related, actually do the recomputation.
+ if args.recompute_in_advance:
+ vpp_rank = get_model_chunk_id(num_microbatches_remaining, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(vpp_rank)
+ if not parallel_state.is_pipeline_last_stage():
+ pipeline_checkpoint_manager.recompute_next(vpp_rank)
+ if args.recompute_in_bubble and num_microbatches_recompute > 0:
+ old_vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
+ parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+ pipeline_checkpoint_manager.recompute_next_force(0)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(old_vpp_rank)
+ if config.overlap_p2p_comm and bwd_wait_handles is not None:
+ for wait_handle in bwd_wait_handles:
+ wait_handle.wait()
+
+ if all_warmup_microbatches:
+ output_tensor_grads[num_model_chunks - 1].append(
+ p2p_communication.recv_backward(tensor_shape, config=config)
+ )
+
+ # ripipe related
+ if args.recompute_in_bubble:
+ num_microbatches_recompute_forward = 1
+ for k in range(num_microbatches_remaining, total_num_microbatches):
+ input_tensor_grad = backward_step_helper(k)
+ next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
+ recv_next = True
+ if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+ if next_backward_model_chunk_id == (num_model_chunks - 1):
+ recv_next = False
+ if k == (total_num_microbatches - 1):
+ recv_next = False
+
+ # ripipe related, use async communication
+ out_tensor, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
+ input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config, overlap_p2p_comm=True
+ )
+ output_tensor_grads[next_backward_model_chunk_id].append(
+ out_tensor
+ )
+
+ if args.use_nanopipe and args.use_nanopipe_swap and k == max(num_microbatches_remaining + 1, (total_num_microbatches + num_microbatches_remaining) // 2):
+ WeightGradStore.swap_tensors()
+
+ # ripipe related, actually do the recomputation
+ if args.recompute_in_bubble and num_microbatches_recompute > 0 and \
+ num_microbatches_recompute_forward < num_microbatches_recompute:
+ old_vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
+ parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+ pipeline_checkpoint_manager.recompute_next_force(0)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(old_vpp_rank)
+ num_microbatches_recompute_forward += 1
+ if args.recompute_in_advance and k != (total_num_microbatches - 1):
+ vpp_rank = get_model_chunk_id(k + 1, forward=False)
+ parallel_state.set_virtual_pipeline_model_parallel_rank(vpp_rank)
+ if not parallel_state.is_pipeline_last_stage():
+ pipeline_checkpoint_manager.recompute_next(vpp_rank)
+ # ripipe related, use async communication
+ if config.overlap_p2p_comm and bwd_wait_handles is not None:
+ for wait_handle in bwd_wait_handles:
+ wait_handle.wait()
+
+ # nanopipe related
+ if args.use_nanopipe:
+ if nano_flag[0] and 0 not in synchronized_model_chunks:
+ config.grad_sync_func[0](model[0].parameters())
+ synchronized_model_chunks.add(0)
+ overlap_arg = [pipeline_parallel_size, nano_flag, synchronized_model_chunks, config.grad_sync_func, model]
+ WeightGradStore.pop(overlap_arg)
+
+ # Launch any remaining grad reductions.
+ enable_grad_sync()
+ if config.grad_sync_func is not None:
+ for model_chunk_id in range(num_model_chunks):
+ if model_chunk_id not in synchronized_model_chunks:
+ config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
+ synchronized_model_chunks.add(model_chunk_id)
+
+ if config.timers is not None:
+ config.timers('forward-backward').stop()
+
+ if config.finalize_model_grads_func is not None and not forward_only:
+ # Finalize model grads (perform full grad all-reduce / reduce-scatter for
+ # data parallelism, layernorm all-reduce for sequence parallelism, and
+ # embedding all-reduce for pipeline parallelism).
+ config.finalize_model_grads_func(model)
+
+ # ripipe related, check all the needed recomputation is done.
+ if args.recompute_in_bubble or args.recompute_in_advance:
+ pipeline_checkpoint_manager.iter_fin()
+
+ return forward_data_store
+
+
+def get_ripipe_recompute_count_params(num_microbatches, num_model_chunks, num_warmup_microbatches):
+ """ripipe related, calculate the variables needed by the recompute_in_bubble function"""
+ args = get_args()
+ pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
+ pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
+ num_microbatches_recompute_steady_groups = 0
+ num_microbatches_recompute_tail = 0
+ num_microbatches_recompute = 0
+ num_microbatches_recompute_forward = 0
+ if args.recompute_in_bubble and num_microbatches // pipeline_parallel_size > 1:
+ num_microbatches_recompute = num_warmup_microbatches + 1 - num_model_chunks * pipeline_parallel_size
+ if num_microbatches_recompute < 0:
+ num_microbatches_recompute = 0
+
+ num_microbatches_recompute_forward = num_microbatches_recompute
+ if num_microbatches_recompute > 0 and num_microbatches // pipeline_parallel_size >= 3:
+ num_microbatches_recompute_steady_groups = (num_microbatches // pipeline_parallel_size) - 3
+ num_microbatches_recompute_tail = 2 + 2 * pipeline_parallel_rank
+ if num_microbatches_recompute_steady_groups == 0:
+ if num_microbatches_recompute_tail >= pipeline_parallel_size - 1 - 2 * pipeline_parallel_rank:
+ num_microbatches_recompute_tail = 0
+ num_microbatches_recompute_steady_groups = 1
+ else:
+ num_microbatches_recompute_tail = 1
+
+ params = collections.namedtuple('RecomputeCountParams',
+ ['num_microbatches_recompute', 'num_microbatches_recompute_forward',
+ 'num_microbatches_recompute_steady_groups', 'num_microbatches_recompute_tail'])
+ return params(num_microbatches_recompute, num_microbatches_recompute_forward,
+ num_microbatches_recompute_steady_groups, num_microbatches_recompute_tail)
diff --git a/model/train/yoco_moe/mindspeed/core/pipeline_parallel/schedules.py b/model/train/yoco_moe/mindspeed/core/pipeline_parallel/schedules.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4b9f166c5c753c26844815997278caa58b9b1e8
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/pipeline_parallel/schedules.py
@@ -0,0 +1,184 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reversed.
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+import contextlib
+import time
+import torch
+from functools import wraps
+from megatron.core.enums import ModelType
+from megatron.training import get_args
+from megatron.core import parallel_state
+from megatron.core.utils import get_attr_wrapped_model, get_model_config, get_model_type
+from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
+from megatron.core.pipeline_parallel.schedules import set_current_microbatch
+from mindspeed.core.pipeline_parallel import flexible_schedules
+from mindspeed.core.pipeline_parallel.ripipe_schedules import forward_backward_ripipe_pipelining
+from mindspeed.core.pipeline_parallel import multiparameter_schedules
+from mindspeed.core.auto_parallel.mm_search.help import PROFILE_CONTENT
+
+LOSS_BACKWARD_SCALE = torch.tensor(1.0)
+
+
+def get_forward_backward_func_wrapper(get_forward_backward_func):
+ @wraps(get_forward_backward_func)
+ def wrapper(*args, **kwargs):
+ arguments = get_args()
+ if arguments.optimize_send_recv_comm and arguments.num_layers_per_virtual_pipeline_stage is None:
+ return flexible_schedules.forward_backward_pipelining_without_interleaving
+
+ if arguments.automated_pipeline_perf and arguments.pp_schedule_list:
+ return flexible_schedules.forward_backward_pipelining_without_interleaving
+
+ if (arguments.recompute_in_bubble or arguments.recompute_in_advance) and torch.is_grad_enabled():
+ return forward_backward_ripipe_pipelining
+
+ if parallel_state.get_pipeline_model_parallel_world_size() > 1 \
+ and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None \
+ and arguments.use_nanopipe:
+ return flexible_schedules.forward_backward_pipelining_with_interleaving_nano_pipe
+
+ if arguments.use_multiparameter_pipeline_model_parallel:
+ pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
+ if pipeline_model_parallel_size > 1 \
+ and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+ return multiparameter_schedules.forward_backward_pipelining_with_interleaving
+
+ return get_forward_backward_func(*args, **kwargs)
+ return wrapper
+
+
+def forward_step(
+ forward_step_func,
+ data_iterator,
+ model,
+ num_microbatches,
+ input_tensor,
+ forward_data_store,
+ config,
+ collect_non_loss_data=False,
+ checkpoint_activations_microbatch=None,
+ is_first_microbatch=False,
+ current_microbatch=None,
+):
+
+ """Forward step for passed-in model.
+
+ If first stage, input tensor is obtained from data_iterator, otherwise
+ passed-in input_tensor is used.
+
+ Returns output tensor."""
+ arguments = get_args()
+ if arguments.auto_parallel_profile:
+ torch.cuda.synchronize()
+ start_time = time.time()
+ torch.npu.reset_max_memory_allocated()
+ start_memory = torch.npu.memory_allocated()
+
+ if config.timers is not None:
+ config.timers('forward-compute', log_level=2).start()
+
+ if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
+ model.set_is_first_microbatch()
+ if current_microbatch is not None:
+ set_current_microbatch(model, current_microbatch)
+
+ unwrap_output_tensor = False
+ if not isinstance(input_tensor, list):
+ input_tensor = [input_tensor]
+ unwrap_output_tensor = True
+
+ set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
+ set_input_tensor(input_tensor)
+
+ if config.enable_autocast:
+ context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
+ else:
+ context_manager = contextlib.nullcontext()
+ with context_manager:
+ if checkpoint_activations_microbatch is None:
+ output_tensor, loss_func = forward_step_func(data_iterator, model)
+ else:
+ output_tensor, loss_func = forward_step_func(
+ data_iterator, model, checkpoint_activations_microbatch
+ )
+
+ num_tokens = torch.tensor(0, dtype=torch.int)
+ if parallel_state.is_pipeline_last_stage():
+ if not collect_non_loss_data:
+ outputs = loss_func(output_tensor)
+ if len(outputs) == 3:
+ output_tensor, num_tokens, loss_reduced = outputs
+ if not config.calculate_per_token_loss:
+ output_tensor /= num_tokens
+ output_tensor /= num_microbatches
+ else:
+ # preserve legacy loss averaging behavior (ie, over the number of microbatches)
+ assert len(outputs) == 2
+ output_tensor, loss_reduced = outputs
+ output_tensor /= num_microbatches
+ forward_data_store.append(loss_reduced)
+ else:
+ data = loss_func(output_tensor, non_loss_data=True)
+ forward_data_store.append(data)
+
+ if config.timers is not None:
+ config.timers('forward-compute').stop()
+
+ # Set the loss scale for the auxiliary loss of the MoE layer.
+ # Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
+ if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
+ # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
+ loss_scale = (
+ config.grad_scale_func(LOSS_BACKWARD_SCALE)
+ if config.grad_scale_func is not None
+ else torch.tensor(1.0)
+ )
+ # Set the loss scale
+ MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
+
+ if arguments.auto_parallel_profile:
+ torch.cuda.synchronize()
+ PROFILE_CONTENT["fwd_time"].append((time.time() - start_time) * 1000) # time: ms
+ PROFILE_CONTENT["act_mem"].append((torch.npu.memory_allocated() - start_memory) / 1024 ** 2) # memory: MB
+ max_mem = torch.npu.max_memory_allocated()
+
+ # If T5 model (or other model with encoder and decoder)
+ # and in decoder stack, then send encoder_hidden_state
+ # downstream as well.
+ model_type = get_model_type(model)
+ if (
+ parallel_state.is_pipeline_stage_after_split()
+ and model_type == ModelType.encoder_and_decoder
+ ):
+ return [output_tensor, input_tensor[-1]], num_tokens
+
+ if unwrap_output_tensor:
+ return output_tensor, num_tokens
+ return [output_tensor], num_tokens
+
+
+def get_tensor_shapes_wrapper(get_tensor_shapes):
+ @wraps(get_tensor_shapes)
+ def wrapper(*args, **kwargs):
+ # [s, b, h]
+ tensor_shapes = get_tensor_shapes(*args, **kwargs)
+ arguments = get_args()
+ if arguments.tp_2d:
+ tensor_shapes = [[tensor_shape[0] // arguments.tp_x, tensor_shape[1], tensor_shape[2] // arguments.tp_y]
+ for tensor_shape in tensor_shapes]
+
+ return tensor_shapes
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/simple_parallel_cfg.py b/model/train/yoco_moe/mindspeed/core/simple_parallel_cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1e4524f873b8205c2c0632b4facdd550abb0588
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/simple_parallel_cfg.py
@@ -0,0 +1,19 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from collections import namedtuple
+
+SimpleParallelCfg = namedtuple(
+ "SimpleParallelCfg", field_names=["dp", "pp", "tp", "cp", "ep", "tp_x", "tp_y"]
+)
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/singleton_meta.py b/model/train/yoco_moe/mindspeed/core/singleton_meta.py
new file mode 100644
index 0000000000000000000000000000000000000000..581dc2c95a117a37137be65e8028b8a13e9df193
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/singleton_meta.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import threading
+
+
+class SingletonMeta(type):
+ _instances = {}
+ _lock = threading.Lock()
+
+ def __call__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls not in cls._instances:
+ instance = super().__call__(*args, **kwargs)
+ cls._instances[cls] = instance
+ return cls._instances[cls]
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/__init__.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/ascend_turbo/__init__.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/ascend_turbo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/ascend_turbo/ascend_turbo_cfg.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/ascend_turbo/ascend_turbo_cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e0ed5acd0becaa95d7b2a7ca181550253e2430b
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/ascend_turbo/ascend_turbo_cfg.py
@@ -0,0 +1,40 @@
+class AscendConfig:
+ def __init__(self):
+ self.ColumnParallelLinear = None
+ self.RowParallelLinear = None
+ self.group_func = None
+ self.world_size_func = None
+
+ self.sequence_parallel_enabled = True
+ self.all_gather_recomputation = True
+
+ def set_sequence_parallel(self, sequence_parallel):
+ self.sequence_parallel = sequence_parallel
+
+ def set_all_gather_recomputation(self, all_gather_recomputation):
+ self.all_gather_recomputation = all_gather_recomputation
+
+ def set_group(self, group_func):
+ self.group_func = group_func
+
+ def get_group(self):
+ return self.group_func()
+
+ def set_world_size(self, world_size_func):
+ self.world_size_func = world_size_func
+
+ def get_world_size(self):
+ return self.world_size_func()
+
+ def set_column_parallel_linear(self, column_parallel_linear):
+ self.ColumnParallelLinear = column_parallel_linear
+
+ def set_row_parallel_linear(self, row_parallel_linear):
+ self.RowParallelLinear = row_parallel_linear
+
+ def parallel_linear_plugin(self, column_parallel_forward, row_parallel_forward):
+ self.ColumnParallelLinear.forward = column_parallel_forward
+ self.RowParallelLinear.forward = row_parallel_forward
+
+
+ascend_turbo_cfg = AscendConfig()
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/ascend_turbo/initialize.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/ascend_turbo/initialize.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5f43167a9dc95d84be6c608be4fd287d1d70dcf
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/ascend_turbo/initialize.py
@@ -0,0 +1,94 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .ascend_turbo_cfg import ascend_turbo_cfg
+from .mc2_linears_seq_parallel import (ColumnSeqParallelLinear, RowSeqParallelLinear,
+ ColumnSeqParallelLinearWithFrozenWeight, RowSeqParallelLinearWithFrozenWeight)
+
+
+def column_parallel_forward(self, input_, weight=None):
+ if weight is None:
+ if self.weight is None:
+ raise RuntimeError(
+ "weight was not supplied to ColumnParallelLinear forward pass"
+ "and skip_weight_param_allocation is True."
+ )
+ weight = self.weight
+ else:
+ # Check the weight passed in is the correct shape
+ expected_shape = (self.output_size_per_partition, self.input_size)
+ if weight.shape != expected_shape:
+ raise RuntimeError(
+ f"supplied weight's shape is {tuple(weight.shape)},"
+ f"not {expected_shape} as expected"
+ )
+
+ bias = self.bias if not self.skip_bias_add else None
+
+ if not weight.requires_grad:
+ output = ColumnSeqParallelLinearWithFrozenWeight.apply(
+ input_, weight, bias, ascend_turbo_cfg.get_group()
+ )
+ else:
+ output = ColumnSeqParallelLinear.apply(
+ input_, weight, bias, ascend_turbo_cfg.get_group()
+ )
+
+ output_bias = self.bias if self.skip_bias_add else None
+ return output, output_bias
+
+
+def row_parallel_forward(self, input_):
+ if not self.weight.requires_grad:
+ output = RowSeqParallelLinearWithFrozenWeight.apply(
+ input_, self.weight, None, ascend_turbo_cfg.get_group()
+ )
+ else:
+ output = RowSeqParallelLinear.apply(
+ input_, self.weight, None, ascend_turbo_cfg.get_group()
+ )
+
+ if not self.skip_bias_add:
+ output = output + self.bias if self.bias is not None else output
+ output_bias = None
+ else:
+ output_bias = self.bias
+
+ return output, output_bias
+
+
+def initialize_cfg_from_framework():
+ from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
+ from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size
+
+ ascend_turbo_cfg.set_group(get_tensor_model_parallel_group)
+ ascend_turbo_cfg.set_world_size(get_tensor_model_parallel_world_size)
+
+ ascend_turbo_cfg.set_column_parallel_linear(ColumnParallelLinear)
+ ascend_turbo_cfg.set_row_parallel_linear(RowParallelLinear)
+
+ ascend_turbo_cfg.parallel_linear_plugin(
+ column_parallel_forward, row_parallel_forward
+ )
+
+
+def initialize_cfg_from_args(args):
+ if not args.sequence_parallel or args.tensor_model_parallel_size == 1 or args.use_pipe_experts:
+ return
+ if args.use_nanopipe:
+ return
+ ascend_turbo_cfg.set_sequence_parallel(args.sequence_parallel)
+ ascend_turbo_cfg.set_all_gather_recomputation(True)
+ initialize_cfg_from_framework()
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/ascend_turbo/mc2_linears_seq_parallel.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/ascend_turbo/mc2_linears_seq_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..b571022beb5ea0ec36a760e96e62198f4e6aca15
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/ascend_turbo/mc2_linears_seq_parallel.py
@@ -0,0 +1,352 @@
+import torch
+import torch_npu
+from megatron.training import get_args
+from .ascend_turbo_cfg import ascend_turbo_cfg
+
+
+class ColumnSeqParallelLinear(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias, group):
+ ctx.save_for_backward(input_)
+ ctx.use_bias = bias is not None
+ ctx.weight = weight
+ ctx.gradient_accumulation_fusion = get_args().gradient_accumulation_fusion
+
+ rank = torch.distributed.get_rank(group)
+ hcomm_info = None
+ if torch.__version__ > "2.0":
+ global_rank = torch.distributed.get_global_rank(group, rank)
+ hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(
+ global_rank
+ )
+
+ else:
+ hcomm_info = group.get_hccl_comm_name(rank)
+
+ x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
+
+ world_size = ascend_turbo_cfg.get_world_size()
+ # npu_all_gather_base_mm currently do not support bias
+ output, all_gather_grad_output = torch_npu.npu_all_gather_base_mm(
+ x,
+ weight.t(),
+ hcomm_info,
+ world_size,
+ bias=None,
+ gather_index=0,
+ gather_output=(not ascend_turbo_cfg.all_gather_recomputation),
+ )
+
+ if bias is not None:
+ output = output + bias
+
+ output = output.view(
+ int(output.shape[0] / input_.shape[1]), input_.shape[1], output.shape[1]
+ )
+
+ ctx.all_gather_output = all_gather_grad_output
+ ctx.world_size = world_size
+ ctx.group = group
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_ = ctx.saved_tensors[0]
+ weight = ctx.weight
+
+ grad_output_ = grad_output.reshape(
+ grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+ )
+
+ if ascend_turbo_cfg.all_gather_recomputation:
+ dim_size = list(input_.size())
+ dim_size[0] = dim_size[0] * ctx.world_size
+ all_gather_output = torch.empty(
+ dim_size,
+ dtype=input_.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ all_gather_work = torch.distributed._all_gather_base(
+ all_gather_output, input_.contiguous(), group=ctx.group, async_op=True
+ )
+ else:
+ all_gather_output = ctx.all_gather_output
+
+ grad_input = grad_output_.matmul(weight)
+ grad_input = grad_input.reshape(
+ grad_output.shape[0], grad_output.shape[1], weight.shape[1]
+ )
+
+ sub_grad_input = torch.empty(
+ list(input_.size()), dtype=input_.dtype, device=torch.cuda.current_device()
+ )
+ reduce_scatter_work = torch.distributed._reduce_scatter_base(
+ sub_grad_input, grad_input, group=ctx.group, async_op=True
+ )
+
+ if ascend_turbo_cfg.all_gather_recomputation:
+ all_gather_work.wait()
+ all_gather_output = all_gather_output.reshape(
+ all_gather_output.shape[0] * all_gather_output.shape[1],
+ all_gather_output.shape[2],
+ )
+
+ if ctx.gradient_accumulation_fusion and weight.main_grad.dtype == torch.float32:
+ from mindspeed.ops.npu_matmul_add import npu_matmul_add_fp32
+ npu_matmul_add_fp32(all_gather_output, grad_output_, weight.main_grad)
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=input_.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=input_.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ grad_weight = grad_output_.t().matmul(all_gather_output)
+
+ is_grad_bias_needed = ctx.needs_input_grad[2]
+ if is_grad_bias_needed and ctx.use_bias:
+ grad_bias = (
+ grad_output_.sum(dim=0)
+ if grad_output_.is_contiguous()
+ else grad_output_.t().sum(dim=1)
+ )
+ else:
+ grad_bias = None
+
+ reduce_scatter_work.wait()
+ return sub_grad_input, grad_weight, grad_bias, None
+
+
+class RowSeqParallelLinear(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias, group):
+ ctx.save_for_backward(input_)
+ ctx.use_bias = bias is not None
+ ctx.weight = weight
+ ctx.gradient_accumulation_fusion = get_args().gradient_accumulation_fusion
+
+ rank = torch.distributed.get_rank(group)
+ world_size = ascend_turbo_cfg.get_world_size()
+ hcomm_info = None
+ if torch.__version__ > "2.0":
+ global_rank = torch.distributed.get_global_rank(group, rank)
+ hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(
+ global_rank
+ )
+ else:
+ hcomm_info = group.get_hccl_comm_name(rank)
+
+ x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
+
+ # npu_mm_reduce_scatter_base currently do not support bias
+ output = torch_npu.npu_mm_reduce_scatter_base(
+ x, weight.t(), hcomm_info, world_size, reduce_op="sum", bias=None
+ )
+
+ if bias is not None:
+ output = output + bias
+
+ ctx.hcomm_info = hcomm_info
+ ctx.world_size = world_size
+
+ output = output.view(
+ int(output.shape[0] / input_.shape[1]), input_.shape[1], output.shape[1]
+ )
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_ = ctx.saved_tensors[0]
+ weight = ctx.weight
+ hcomm_info = ctx.hcomm_info
+ world_size = ctx.world_size
+
+ grad_output_ = grad_output.reshape(
+ grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+ )
+
+ grad_input, all_gather_grad_output = torch_npu.npu_all_gather_base_mm(
+ grad_output_, weight, hcomm_info, world_size, bias=None, gather_index=0
+ )
+
+ grad_input = grad_input.view_as(input_)
+
+ x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
+ if ctx.gradient_accumulation_fusion and weight.main_grad.dtype == torch.float32:
+ from mindspeed.ops.npu_matmul_add import npu_matmul_add_fp32
+ npu_matmul_add_fp32(x, all_gather_grad_output, weight.main_grad)
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=input_.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=input_.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ grad_weight = all_gather_grad_output.t().matmul(x)
+
+ is_grad_bias_needed = ctx.needs_input_grad[2]
+ if is_grad_bias_needed and ctx.use_bias:
+ grad_bias = (
+ grad_output.sum(dim=0)
+ if grad_output.is_contiguous()
+ else grad_output.t().sum(dim=1)
+ )
+ else:
+ grad_bias = None
+
+ return grad_input, grad_weight, grad_bias, None
+
+
+class ColumnSeqParallelLinearWithFrozenWeight(ColumnSeqParallelLinear):
+ @staticmethod
+ def forward(ctx, input_, weight, bias, group):
+ ctx.input_shape = input_.shape
+ ctx.use_bias = bias is not None
+ ctx.weight = weight
+
+ rank = torch.distributed.get_rank(group)
+ hcomm_info = None
+ if torch.__version__ > "2.0":
+ global_rank = torch.distributed.get_global_rank(group, rank)
+ hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(
+ global_rank
+ )
+
+ else:
+ hcomm_info = group.get_hccl_comm_name(rank)
+
+ x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
+
+ world_size = ascend_turbo_cfg.get_world_size()
+ # npu_all_gather_base_mm currently do not support bias
+ output, all_gather_grad_output = torch_npu.npu_all_gather_base_mm(
+ x,
+ weight.t(),
+ hcomm_info,
+ world_size,
+ bias=None,
+ gather_index=0,
+ gather_output=(not ascend_turbo_cfg.all_gather_recomputation),
+ )
+
+ if bias is not None:
+ output = output + bias
+
+ output = output.view(
+ int(output.shape[0] / input_.shape[1]), input_.shape[1], output.shape[1]
+ )
+ ctx.hcomm_info = hcomm_info
+ ctx.world_size = world_size
+ ctx.group = group
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_shape = ctx.input_shape
+ weight = ctx.weight
+
+ hcomm_info = ctx.hcomm_info
+ world_size = ctx.world_size
+ grad_output_ = grad_output.reshape(
+ grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+ )
+
+ sub_grad_input = torch_npu.npu_mm_reduce_scatter_base(
+ grad_output_, weight, hcomm_info, world_size, bias=None
+ )
+
+ sub_grad_input = sub_grad_input.view(input_shape)
+
+ return sub_grad_input, None, None, None
+
+
+class RowSeqParallelLinearWithFrozenWeight(RowSeqParallelLinear):
+ @staticmethod
+ def forward(ctx, input_, weight, bias, group):
+ ctx.input_shape = input_.shape
+ ctx.use_bias = bias is not None
+ ctx.weight = weight
+
+ rank = torch.distributed.get_rank(group)
+ world_size = ascend_turbo_cfg.get_world_size()
+ hcomm_info = None
+ if torch.__version__ > "2.0":
+ global_rank = torch.distributed.get_global_rank(group, rank)
+ hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(
+ global_rank
+ )
+ else:
+ hcomm_info = group.get_hccl_comm_name(rank)
+
+ x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
+
+ # npu_mm_reduce_scatter_base currently do not support bias
+ output = torch_npu.npu_mm_reduce_scatter_base(
+ x, weight.t(), hcomm_info, world_size, reduce_op="sum", bias=None
+ )
+
+ if bias is not None:
+ output = output + bias
+
+ ctx.hcomm_info = hcomm_info
+ ctx.world_size = world_size
+
+ output = output.view(
+ int(output.shape[0] / input_.shape[1]), input_.shape[1], output.shape[1]
+ )
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_shape = ctx.input_shape
+ weight = ctx.weight
+ hcomm_info = ctx.hcomm_info
+ world_size = ctx.world_size
+ grad_output_ = grad_output.reshape(
+ grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+ )
+
+ grad_input, _ = torch_npu.npu_all_gather_base_mm(
+ grad_output_, weight, hcomm_info, world_size, bias=None, gather_index=0
+ )
+
+ grad_input = grad_input.view(input_shape)
+
+ return grad_input, None, None, None
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/checkpoint_manager.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/checkpoint_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d6820e0960df31d9d4d5c95f96b24b4e473bdac
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/checkpoint_manager.py
@@ -0,0 +1,62 @@
+import torch
+
+
+class PipelineCheckpointManager:
+ instance = None
+
+ def __init__(self, num_of_chunks=2):
+ self.open_ri_pipe = False
+ self.do_pre_recompute = False
+ self.checkpoint_list = []
+ self.chunk_list = [[] for i in range(num_of_chunks)]
+ self.chunk_do_recompute = True
+
+ def batch_fin(self, chunk_idx):
+ self.chunk_list[chunk_idx].append(self.checkpoint_list)
+ self.checkpoint_list = []
+
+ def iter_fin(self):
+ if len(self.checkpoint_list) != 0:
+ raise RuntimeError("recompute list is not empty")
+
+ for batch_list_for_chunk in self.chunk_list:
+ for layer_list_for_batch in batch_list_for_chunk:
+ if len(layer_list_for_batch) != 0:
+ raise RuntimeError(
+ f"{torch.distributed.get_rank()} batch recompute list is not empty, got {layer_list_for_batch[0]} in index 0")
+ batch_list_for_chunk.clear()
+ self.open_ri_pipe = False
+ self.do_pre_recompute = False
+ self.chunk_do_recompute = True
+
+ def add_recompute(self, recompute_func):
+ self.checkpoint_list.append(recompute_func)
+
+ def recompute_next_force(self, chunk_idx):
+ layer_list_for_head_batch = []
+ while len(layer_list_for_head_batch) == 0:
+ if len(self.chunk_list[chunk_idx]) == 0:
+ print(f"rank-{torch.distributed.get_rank()} warning, poping empty list")
+ break
+ layer_list_for_head_batch = self.chunk_list[chunk_idx].pop(0)
+
+ for layer_func in layer_list_for_head_batch:
+ layer_func()
+
+ def recompute_next(self, chunk_idx):
+ layer_list_for_head_batch = self.chunk_list[chunk_idx].pop(0)
+ if self.do_pre_recompute:
+ for layer_func in layer_list_for_head_batch:
+ layer_func()
+
+ def disable_recompute(self):
+ self.chunk_do_recompute = False
+
+ def enable_recompute(self):
+ self.chunk_do_recompute = True
+
+
+def get_pipeline_checkpoint_manager(num_of_chunks=2):
+ if PipelineCheckpointManager.instance is None:
+ PipelineCheckpointManager.instance = PipelineCheckpointManager(num_of_chunks=num_of_chunks)
+ return PipelineCheckpointManager.instance
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/comm_autograd_function.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/comm_autograd_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e324c76706978b282adebf46d828492fd36a69c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/comm_autograd_function.py
@@ -0,0 +1,246 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import torch
+import torch.distributed
+
+from mindspeed.core.tensor_parallel.comm_group_api import CollectiveCommIntf
+from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm
+from mindspeed.core.tensor_parallel.comm_group_api import TPYCollectiveComm
+from mindspeed.core.tensor_parallel.comm_utils import _gather_along_last_dim
+from mindspeed.core.tensor_parallel.comm_utils import _split_along_first_dim
+from mindspeed.core.tensor_parallel.comm_utils import _split_along_last_dim
+from mindspeed.core.tensor_parallel.comm_utils import sync_gather_along_first_dim
+from mindspeed.core.tensor_parallel.comm_utils import sync_gather_along_last_dim
+from mindspeed.core.tensor_parallel.comm_utils import sync_reduce_scatter_along_first_dim
+
+
+class _SyncGatherAlongFirstDim(torch.autograd.Function):
+ """Gather the input from model parallel X region and concatinate."""
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return sync_gather_along_first_dim(input_, TPXCollectiveComm)
+
+ @staticmethod
+ def forward(ctx, input_, comm_intf: CollectiveCommIntf):
+ ctx.comm_intf = comm_intf
+ return sync_gather_along_first_dim(input_, comm_intf)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _split_along_first_dim(grad_output, ctx.comm_intf), None
+
+
+class _SyncGatherAlongLastDim(torch.autograd.Function):
+ """Gather the input from model parallel Y region and concatinate."""
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return sync_gather_along_last_dim(input_, TPYCollectiveComm)
+
+ @staticmethod
+ def forward(ctx, input_, comm_intf: CollectiveCommIntf):
+ ctx.comm_intf = comm_intf
+ return sync_gather_along_last_dim(input_, comm_intf)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _split_along_last_dim(grad_output, ctx.comm_intf), None
+
+
+def _reduce(input_, tp_intf: CollectiveCommIntf = TPXCollectiveComm):
+ """All-reduce the input tensor across model parallel group."""
+
+ # Bypass the function if we are using only 1 GPU.
+ if tp_intf.get_comm_group_world_size() == 1:
+ return input_
+
+ # All-reduce.
+ torch.distributed.all_reduce(input_, group=tp_intf.get_comm_group())
+ return input_
+
+
+class _ReduceFromModelParallelRegion(torch.autograd.Function):
+ """All-reduce the input from the model parallel region."""
+
+ @staticmethod
+ def symbolic(graph, input_, tp_intf: CollectiveCommIntf = TPXCollectiveComm):
+ return _reduce(input_, tp_intf), None
+
+ @staticmethod
+ def forward(ctx, input_, tp_intf: CollectiveCommIntf = TPXCollectiveComm):
+ return _reduce(input_, tp_intf)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output, None
+
+
+class _GatherFromParallelRegion(torch.autograd.Function):
+ """Gather the input from model parallel region and concatinate."""
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return _gather_along_last_dim(input_)
+
+ @staticmethod
+ def forward(ctx, input_, comm_intf: CollectiveCommIntf):
+ ctx.comm_intf = comm_intf
+ return _gather_along_last_dim(input_, comm_intf)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _split_along_last_dim(grad_output, ctx.comm_intf), None
+
+
+class _ScatterAlongLastDim(torch.autograd.Function):
+ """Split the input and keep only the corresponding chuck to the rank."""
+
+ @staticmethod
+ def symbolic(graph, input_, comm_intf: CollectiveCommIntf):
+ return _split_along_last_dim(input_, comm_intf)
+
+ @staticmethod
+ def forward(ctx, input_, comm_intf: CollectiveCommIntf):
+ ctx.comm_intf = comm_intf
+ return _split_along_last_dim(input_, comm_intf)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _gather_along_last_dim(grad_output, ctx.comm_intf), None
+
+
+class _ScatterAlongFirstDim(torch.autograd.Function):
+ """Split the input and keep only the corresponding chuck to the rank."""
+
+ @staticmethod
+ def symbolic(graph, input_, comm_intf: CollectiveCommIntf):
+ return _split_along_first_dim(input_, comm_intf)
+
+ @staticmethod
+ def forward(ctx, input_, comm_intf: CollectiveCommIntf):
+ ctx.comm_intf = comm_intf
+ return _split_along_first_dim(input_, comm_intf)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return sync_gather_along_first_dim(grad_output, ctx.comm_intf), None
+
+
+class _ScatterAlongFirstDimThenLastDim(torch.autograd.Function):
+ """Split the input and keep only the corresponding chuck to the rank."""
+
+ @staticmethod
+ def symbolic(graph, local_rank_input, first_dim_comm_intf, last_dim_comm_intf):
+ graph.first_dim_comm_intf = first_dim_comm_intf
+ graph.last_dim_comm_intf = last_dim_comm_intf
+
+ first_dim_split_output = _split_along_first_dim(local_rank_input, first_dim_comm_intf)
+ return _split_along_last_dim(first_dim_split_output, last_dim_comm_intf)
+
+ @staticmethod
+ def forward(ctx, local_rank_input, first_dim_comm_intf, last_dim_comm_intf):
+ ctx.first_dim_comm_intf = first_dim_comm_intf
+ ctx.last_dim_comm_intf = last_dim_comm_intf
+
+ first_dim_split_output = _split_along_first_dim(local_rank_input, first_dim_comm_intf)
+ return _split_along_last_dim(first_dim_split_output, last_dim_comm_intf)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ last_dim_gather_output = _gather_along_last_dim(grad_output, ctx.last_dim_comm_intf)
+ first_dim_gather_output = sync_gather_along_first_dim(
+ last_dim_gather_output, ctx.first_dim_comm_intf)
+ return first_dim_gather_output, None, None
+
+
+class _SyncGatherAlongFirstDimRS(torch.autograd.Function):
+ """Gather the input from model parallel X region and concatinate."""
+
+ @staticmethod
+ def symbolic(graph, input_, comm_intf: CollectiveCommIntf):
+ return sync_gather_along_first_dim(input_, comm_intf)
+
+ @staticmethod
+ def forward(ctx, input_, comm_intf: CollectiveCommIntf):
+ ctx.comm_intf = comm_intf
+ return sync_gather_along_first_dim(input_, comm_intf)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return sync_reduce_scatter_along_first_dim(grad_output, ctx.comm_intf), None
+
+
+class _SyncReduceScatterAlongFirstDim(torch.autograd.Function):
+ """Reduce scatter the input along first dim"""
+
+ @staticmethod
+ def symbolic(graph, input_, comm_intf: CollectiveCommIntf):
+ return sync_reduce_scatter_along_first_dim(input_, comm_intf)
+
+ @staticmethod
+ def forward(ctx, input_, comm_intf: CollectiveCommIntf):
+ ctx.comm_intf = comm_intf
+ return sync_reduce_scatter_along_first_dim(input_, comm_intf)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return sync_gather_along_first_dim(grad_output, ctx.comm_intf), None
+
+
+def auto_grad_sync_gather_along_first_dim(input_, comm_intf: CollectiveCommIntf):
+ return _SyncGatherAlongFirstDim.apply(input_, comm_intf)
+
+
+def auto_grad_sync_gather_along_last_dim(input_, comm_intf: CollectiveCommIntf):
+ return _SyncGatherAlongLastDim.apply(input_, comm_intf)
+
+
+def scatter_to_tensor_parallel_y_region(input_):
+ return _ScatterAlongLastDim.apply(input_)
+
+
+def auto_grad_scatter_along_last_dim(input_, comm_intf: CollectiveCommIntf):
+ return _ScatterAlongLastDim.apply(input_, comm_intf)
+
+
+def auto_grad_scatter_along_first_dim(input_, comm_intf: CollectiveCommIntf):
+ return _ScatterAlongFirstDim.apply(input_, comm_intf)
+
+
+def auto_grad_scatter_along_first_dim_then_last_dim(
+ local_rank_input: torch.Tensor,
+ first_dim_comm_intf: CollectiveCommIntf,
+ last_dim_comm_intf: CollectiveCommIntf,
+):
+ return _ScatterAlongFirstDimThenLastDim.apply(
+ local_rank_input, first_dim_comm_intf, last_dim_comm_intf
+ )
+
+
+def reduce_from_parallel_region(input_, tp_intf: CollectiveCommIntf = TPXCollectiveComm):
+ return _ReduceFromModelParallelRegion.apply(input_, tp_intf)
+
+
+def gather_from_parallel_region(input_, comm_intf: CollectiveCommIntf):
+ return _GatherFromParallelRegion.apply(input_, comm_intf)
+
+
+def auto_grad_sync_gather_along_first_dim_rs(input_, comm_intf: CollectiveCommIntf):
+ return _SyncGatherAlongFirstDimRS.apply(input_, comm_intf)
+
+
+def auto_grad_reduce_scatter_along_first_dim(input_, comm_intf: CollectiveCommIntf):
+ return _SyncReduceScatterAlongFirstDim.apply(input_, comm_intf)
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/comm_group_api.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/comm_group_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..2422b3bcabb3b288321325fd204c7c33610fc045
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/comm_group_api.py
@@ -0,0 +1,174 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import abc
+
+from megatron.core.parallel_state import get_context_parallel_group
+from megatron.core.parallel_state import get_context_parallel_rank
+from megatron.core.parallel_state import get_context_parallel_world_size
+from megatron.core.parallel_state import get_tensor_model_parallel_group
+from megatron.core.parallel_state import get_tensor_model_parallel_rank
+from megatron.core.parallel_state import get_tensor_model_parallel_world_size
+
+from mindspeed.core.parallel_state import get_tensor_model_parallel_group_for_nd1_dim1
+from mindspeed.core.parallel_state import get_tensor_model_parallel_group_for_nd1_dim1_rank
+from mindspeed.core.parallel_state import get_tensor_model_parallel_group_for_nd1_dim1_world_size
+from mindspeed.core.parallel_state import get_tensor_model_parallel_group_for_nd1_dim2
+from mindspeed.core.parallel_state import get_tensor_model_parallel_group_for_nd1_dim2_rank
+from mindspeed.core.parallel_state import get_tensor_model_parallel_group_for_nd1_dim2_world_size
+from mindspeed.core.parallel_state import get_tp_x_ep_group
+from mindspeed.core.parallel_state import get_tp_x_ep_group_rank
+from mindspeed.core.parallel_state import get_tp_x_ep_group_world_size
+from mindspeed.core.parallel_state import get_tp_x_ring_global_ranks
+from mindspeed.core.parallel_state import get_tp_x_sd_rcv_overlap_group
+from mindspeed.core.parallel_state import get_tp_y_ring_global_ranks
+from mindspeed.core.parallel_state import get_tp_y_sd_rcv_overlap_group
+
+
+class CollectiveCommIntf:
+ def __init__(self, comm_group_name):
+ self.comm_group_name = comm_group_name
+
+ @classmethod
+ @abc.abstractmethod
+ def get_comm_group_world_size(cls):
+ raise NotImplementedError
+
+ @classmethod
+ @abc.abstractmethod
+ def get_comm_group(cls):
+ raise NotImplementedError
+
+ @classmethod
+ @abc.abstractmethod
+ def get_comm_rank(cls):
+ raise NotImplementedError
+
+ def get_comm_group_name(self):
+ return self.comm_group_name
+
+
+class OverlapCollectiveIntf(CollectiveCommIntf):
+ @classmethod
+ @abc.abstractmethod
+ def get_ring_global_ranks(cls):
+ raise NotImplementedError
+
+
+class CPCollectiveComm(CollectiveCommIntf):
+ @classmethod
+ def get_comm_group_world_size(cls):
+ return get_context_parallel_world_size()
+
+ @classmethod
+ def get_comm_group(cls):
+ return get_context_parallel_group()
+
+ @classmethod
+ def get_comm_rank(cls):
+ return get_context_parallel_rank()
+
+
+class TPXCollectiveComm(CollectiveCommIntf):
+ def __init__(self, name="tp-x"):
+ super().__init__(name)
+
+ @classmethod
+ def get_comm_rank(cls):
+ return get_tensor_model_parallel_group_for_nd1_dim1_rank()
+
+ @classmethod
+ def get_comm_group_world_size(cls):
+ return get_tensor_model_parallel_group_for_nd1_dim1_world_size()
+
+ @classmethod
+ def get_comm_group(cls):
+ return get_tensor_model_parallel_group_for_nd1_dim1()
+
+
+class TPXEPCollectiveComm(CollectiveCommIntf):
+ def __init__(self, name="tp-x-ep"):
+ super().__init__(name)
+
+ @classmethod
+ def get_comm_rank(cls):
+ return get_tp_x_ep_group_rank()
+
+ @classmethod
+ def get_comm_group_world_size(cls):
+ return get_tp_x_ep_group_world_size()
+
+ @classmethod
+ def get_comm_group(cls):
+ return get_tp_x_ep_group()
+
+
+class TPXOverlapCollectiveComm(TPXCollectiveComm, OverlapCollectiveIntf):
+ def __init__(self):
+ super().__init__("tp-x-overlap")
+
+ @classmethod
+ def get_comm_group(cls):
+ return get_tp_x_sd_rcv_overlap_group()
+
+ @classmethod
+ def get_ring_global_ranks(cls):
+ return get_tp_x_ring_global_ranks()
+
+
+class TPYCollectiveComm(CollectiveCommIntf):
+ def __init__(self, name="tp-y"):
+ super().__init__(name)
+
+ @classmethod
+ def get_comm_rank(cls):
+ return get_tensor_model_parallel_group_for_nd1_dim2_rank()
+
+ @classmethod
+ def get_comm_group_world_size(cls):
+ return get_tensor_model_parallel_group_for_nd1_dim2_world_size()
+
+ @classmethod
+ def get_comm_group(cls):
+ return get_tensor_model_parallel_group_for_nd1_dim2()
+
+
+class TPYOverlapCollectiveComm(TPYCollectiveComm, OverlapCollectiveIntf):
+ def __init__(self):
+ super().__init__("y-overlap")
+
+ @classmethod
+ def get_comm_group(cls):
+ return get_tp_y_sd_rcv_overlap_group()
+
+ @classmethod
+ def get_ring_global_ranks(cls):
+ return get_tp_y_ring_global_ranks()
+
+
+class TPXYCollectiveComm(CollectiveCommIntf):
+ def __init__(self):
+ super().__init__("tp-xy")
+
+ @classmethod
+ def get_comm_rank(cls):
+ return get_tensor_model_parallel_rank()
+
+ @classmethod
+ def get_comm_group_world_size(cls):
+ return get_tensor_model_parallel_world_size()
+
+ @classmethod
+ def get_comm_group(cls):
+ return get_tensor_model_parallel_group()
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/comm_utils.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/comm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2360fa8d50908a19dca5e812768bb0cbf4c5a1c3
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/comm_utils.py
@@ -0,0 +1,271 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import torch
+from torch import Tensor
+from torch import distributed
+import torch.distributed as dist
+
+from megatron.core.parallel_state import get_global_memory_buffer
+from mindspeed.core.tensor_parallel.comm_group_api import CollectiveCommIntf
+from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm
+
+
+def _split_along_last_dim(
+ local_rank_input: Tensor, comm_intf: CollectiveCommIntf = TPXCollectiveComm
+):
+ """Split the tensor along its last dimension and keep the
+ corresponding slice."""
+
+ world_size = comm_intf.get_comm_group_world_size()
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return local_rank_input
+
+ # Split along last dimension.
+ last_dim = local_rank_input.dim() - 1
+ last_dim_size = local_rank_input.size()[last_dim] // world_size
+ # Split.
+ tensor_list = torch.split(local_rank_input, last_dim_size, dim=last_dim)
+
+ # Note: torch.split does not create contiguous tensors by default.
+ rank = comm_intf.get_comm_rank()
+ output = tensor_list[rank].contiguous()
+
+ return output
+
+
+def _split_along_first_dim(local_rank_input, comm_intf: CollectiveCommIntf = TPXCollectiveComm):
+ """Split the tensor along its first dimension and keep the
+ corresponding slice."""
+
+ world_size = comm_intf.get_comm_group_world_size()
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return local_rank_input
+
+ # Split along first dimension.
+ dim_size = local_rank_input.size()[0]
+ if dim_size % world_size:
+ raise AssertionError("First dimension of the tensor should be divisible by parallel size")
+ local_dim_size = dim_size // world_size
+ rank = comm_intf.get_comm_rank()
+ dim_offset = rank * local_dim_size
+
+ output = local_rank_input[dim_offset : dim_offset + local_dim_size].contiguous()
+
+ return output
+
+
+def _gather_along_last_dim(
+ local_rank_input: Tensor, ag_comm_intf: CollectiveCommIntf = TPXCollectiveComm
+):
+ """Gather tensors and concatinate along the last dimension."""
+
+ world_size = ag_comm_intf.get_comm_group_world_size()
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return local_rank_input
+
+ tensor_list = [torch.empty_like(local_rank_input) for _ in range(world_size)]
+ torch.distributed.all_gather(
+ tensor_list, local_rank_input, group=ag_comm_intf.get_comm_group(), async_op=False
+ )
+
+ # Note: torch.cat already creates a contiguous tensor.
+ last_dim = local_rank_input.dim() - 1
+ output = torch.cat(tensor_list, dim=last_dim).contiguous()
+ return output
+
+
+def sync_gather_along_last_dim(
+ local_rank_tensor: Tensor, ag_comm_intf: CollectiveCommIntf = TPXCollectiveComm
+):
+ """Gather tensors and concatinate along the last dimension synchronously.
+
+ :param local_rank_tensor: input of current rank.
+ :param ag_comm_intf: the communication process group interface.
+ :return: the AllGather-ed result.
+ """
+
+ world_size = ag_comm_intf.get_comm_group_world_size()
+ # Bypass the function if we are using only 1 GPU/NPU.
+ if world_size == 1:
+ return local_rank_tensor
+
+ gathered_tensors = [torch.empty_like(local_rank_tensor) for _ in range(world_size)]
+ torch.distributed.all_gather(
+ gathered_tensors,
+ local_rank_tensor.contiguous(),
+ group=ag_comm_intf.get_comm_group(),
+ async_op=False,
+ )
+
+ return torch.cat(gathered_tensors, dim=local_rank_tensor.dim() - 1).contiguous()
+
+
+def async_gather_tensors(
+ local_rank_input: Tensor,
+ ag_comm_intf: CollectiveCommIntf = TPXCollectiveComm,
+ buffer_name="mpu-async-tp-2d",
+):
+ """Gather tensors and concatinate along the last dimension asynchronously.
+
+ :param local_rank_input: input of current rank.
+ :param ag_comm_intf: the AllGather communication process group interface.
+ :param buffer_name: buffer name of str type.
+ :return: the AllGather op handle and tensor list storing the op result tensors.
+
+ Note: the result tensors may be handled as following according to your need:
+ output = torch.cat(gathered_tensors, dim=xx_dim).contiguous()
+ """
+
+ world_size = ag_comm_intf.get_comm_group_world_size()
+ # Bypass the function if we are using only 1 NPU/GPU.
+ if world_size == 1:
+ return None, local_rank_input
+
+ dim_size = list(local_rank_input.size())
+ dim_size[0] *= world_size
+
+ ag_out = torch.empty(dim_size, dtype=local_rank_input.dtype, device=torch.cuda.current_device())
+ handle = torch.distributed._all_gather_base(
+ ag_out, local_rank_input, group=ag_comm_intf.get_comm_group(), async_op=True
+ )
+
+ return handle, ag_out
+
+
+def sync_gather_along_first_dim(
+ local_rank_input: Tensor,
+ comm_intf: CollectiveCommIntf = TPXCollectiveComm,
+ buffer_name=None,
+):
+ """Gather tensors and concatinate along the first dimension."""
+
+ world_size = comm_intf.get_comm_group_world_size()
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return local_rank_input
+
+ dim_size = list(local_rank_input.size())
+ dim_size[0] *= world_size
+
+ if buffer_name is None:
+ output = torch.empty(dim_size, dtype=local_rank_input.dtype, device=torch.cuda.current_device())
+ else:
+ output = get_global_memory_buffer().get_tensor(dim_size, local_rank_input.dtype, buffer_name)
+ torch.distributed._all_gather_base(
+ output, local_rank_input.contiguous(), group=comm_intf.get_comm_group()
+ )
+
+ return output
+
+
+def sync_reduce_scatter_along_first_dim(
+ local_rank_input, comm_intf: CollectiveCommIntf = TPXCollectiveComm
+):
+ """Reduce-scatter the input tensor across specified parallel group."""
+ world_size = comm_intf.get_comm_group_world_size()
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return local_rank_input
+
+ dim_size = list(local_rank_input.size())
+ if dim_size[0] % world_size:
+ raise AssertionError("First dimension of the tensor should be divisible by tensor parallel size")
+
+ dim_size[0] = dim_size[0] // world_size
+
+ output = torch.empty(dim_size, dtype=local_rank_input.dtype, device=torch.cuda.current_device())
+ dist.reduce_scatter_tensor(
+ output, local_rank_input.contiguous(), group=comm_intf.get_comm_group(), async_op=False
+ )
+
+ return output
+
+
+def async_reduce_scatter_along_first_dim(
+ local_rank_input, comm_intf: CollectiveCommIntf = TPXCollectiveComm
+):
+ """Reduce-scatter the input tensor across parallel group specified by comm_intf."""
+ world_size = comm_intf.get_comm_group_world_size()
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return None, local_rank_input
+
+ dim_size = list(local_rank_input.size())
+ if dim_size[0] % world_size:
+ raise AssertionError("First dimension of the tensor should be divisible by parallel size")
+
+ dim_size[0] = dim_size[0] // world_size
+
+ rs_output = torch.empty(
+ dim_size, dtype=local_rank_input.dtype, device=torch.cuda.current_device()
+ )
+ handle = dist.reduce_scatter_tensor(
+ rs_output, local_rank_input.contiguous(), group=comm_intf.get_comm_group(), async_op=True
+ )
+ return handle, rs_output
+
+
+def async_gather_along_last_dim(input_, comm_intf: CollectiveCommIntf = TPXCollectiveComm):
+ world_size = comm_intf.get_comm_group_world_size()
+ # Bypass the function if we are using only 1 GPU/NPU.
+ if world_size == 1:
+ return None, input_
+
+ gathered_tensors = [torch.empty_like(input_) for _ in range(world_size)]
+ handle = torch.distributed.all_gather(
+ gathered_tensors, input_.contiguous(), group=comm_intf.get_comm_group(), async_op=True,
+ )
+
+ return handle, gathered_tensors
+
+
+def sync_reduce_scatter_along_last_dim(
+ local_rank_input, rs_comm_intf: CollectiveCommIntf = TPXCollectiveComm
+):
+ """Reduce-scatter the input tensor across specified parallel group."""
+ world_size = rs_comm_intf.get_comm_group_world_size()
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return local_rank_input
+
+ local_rank_input = local_rank_input.transpose(0, -1)
+ output = sync_reduce_scatter_along_first_dim(local_rank_input, rs_comm_intf)
+ return output.transpose(0, -1).contiguous()
+
+
+def async_reduce_scatter_along_last_dim(
+ local_rank_input, rs_comm_intf: CollectiveCommIntf = TPXCollectiveComm
+):
+ """Reduce-scatter the input tensor across model parallel group.
+
+ :param local_rank_input: input of local rank
+ :param rs_comm_intf: Reduce scatter comm intf.
+ :return:
+
+ Note: the result tensors should be handled as following:
+ rs_output = rs_output.transpose(0, 2).contiguous()
+
+ """
+ world_size = rs_comm_intf.get_comm_group_world_size()
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return None, local_rank_input
+
+ local_rank_input = local_rank_input.transpose(0, 2)
+ return async_reduce_scatter_along_first_dim(local_rank_input, rs_comm_intf)
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/cross_entropy.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/cross_entropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e42a86854c2aa1da570b9b8f4be7f1aef13a1c6
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/cross_entropy.py
@@ -0,0 +1,38 @@
+from typing import Tuple
+
+import torch
+
+
+def calculate_predicted_logits(
+ vocab_parallel_logits: torch.Tensor,
+ target: torch.Tensor,
+ logits_max: torch.Tensor,
+ vocab_start_index: int,
+ vocab_end_index: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # subtraction the maximum value.
+ # Use in-place to reduce memory pressure.
+ vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
+
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
+ target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+ masked_target = target.clone() - vocab_start_index
+ masked_target *= ~target_mask
+
+ # Get predicted-logits = logits[target].
+ # For Simplicity, we convert logits to a 2-D tensor with size
+ # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
+ partition_vocab_size = vocab_parallel_logits.size()[-1]
+ logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
+ masked_target_1d = masked_target.view(-1)
+ arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
+ predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+ predicted_logits_1d = predicted_logits_1d.clone().contiguous()
+ predicted_logits = predicted_logits_1d.view_as(target)
+ predicted_logits *= ~target_mask
+
+ exp_logits = vocab_parallel_logits
+ torch.exp(vocab_parallel_logits, out=exp_logits)
+ sum_exp_logits = exp_logits.sum(dim=-1)
+
+ return target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/layers.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d348d2c7b2b7ce899954cadeed2b80283b60db20
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/layers.py
@@ -0,0 +1,1969 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2024, Bytedance Inc. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import wraps
+from webbrowser import get
+import os
+import warnings
+from typing import List, Optional, Callable, Any
+
+import torch
+import torch_npu
+import torch.distributed
+import torch.nn.functional as F
+
+from megatron.core.transformer import TransformerConfig
+from .comm_group_api import TPXCollectiveComm
+from .comm_group_api import TPYCollectiveComm
+
+try:
+ import fused_weight_gradient_mlp_cuda
+except Exception:
+ warnings.warn("failed to generate the npu_matmul_add_fp32")
+from torch.nn.parameter import Parameter
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.nn.parameter import Parameter
+import megatron.core.parallel_state as ps
+from megatron.core.tensor_parallel.layers import (
+ _initialize_affine_weight_cpu,
+ _initialize_affine_weight_gpu
+)
+
+from megatron.core import parallel_state, tensor_parallel, mpu
+from megatron.training import get_args
+
+from megatron.core.tensor_parallel.mappings import (
+ copy_to_tensor_model_parallel_region,
+ gather_from_tensor_model_parallel_region,
+ reduce_from_tensor_model_parallel_region,
+ reduce_scatter_to_sequence_parallel_region,
+ scatter_to_tensor_model_parallel_region,
+ _reduce_scatter_along_first_dim,
+ _gather_along_first_dim,
+ _ReduceFromModelParallelRegion,
+)
+from megatron.core.tensor_parallel.layers import (
+ LinearWithGradAccumulationAndAsyncCommunication,
+ linear_with_grad_accumulation_and_async_allreduce,
+ linear_with_frozen_weight,
+)
+from megatron.core.parallel_state import (
+ get_global_memory_buffer,
+ get_tensor_model_parallel_group,
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ is_pipeline_first_stage,
+ get_data_parallel_world_size,
+ get_data_parallel_rank,
+)
+from megatron.core.tensor_parallel.layers import set_tensor_model_parallel_attributes
+from megatron.core.parallel_state import get_tensor_model_parallel_world_size
+from megatron.core.model_parallel_config import ModelParallelConfig
+from megatron.core.tensor_parallel.utils import VocabUtility, divide, split_tensor_along_last_dim
+from megatron.core.utils import (
+ make_tp_sharded_tensor_for_checkpoint,
+ prepare_input_tensors_for_wgrad_compute
+)
+from mindspeed.core.parallel_state import (
+ get_tensor_model_parallel_group_for_nd1_dim1,
+ get_tensor_model_parallel_group_for_nd1_dim2,
+ get_tensor_model_parallel_group_for_nd2_dim1,
+ get_tensor_model_parallel_group_for_nd2_dim2,
+ get_tensor_model_parallel_world_size_for_nd1_dim1,
+ get_tensor_model_parallel_world_size_for_nd1_dim2,
+ get_tensor_model_parallel_world_size_for_nd2_dim1,
+ get_tensor_model_parallel_world_size_for_nd2_dim2
+)
+from mindspeed.core.weight_grad_store import WeightGradStore
+from mindspeed.moe.async_comm_utils import get_fw_ag_output
+from mindspeed.moe.utils import get_slice_indices_from_disorder_to_order
+from .ascend_turbo.mc2_linears_seq_parallel import RowSeqParallelLinear
+
+
+def linear_with_grad_accumulation_and_async_allreduce_zero3(
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion: bool,
+ async_grad_allreduce: bool,
+ sequence_parallel: bool,
+ grad_output_buffer=None,
+ need_gather_param_in_bw=False):
+
+ args = [
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ async_grad_allreduce,
+ sequence_parallel,
+ grad_output_buffer,
+ need_gather_param_in_bw,
+ ]
+
+ if not linear_with_grad_accumulation_and_async_allreduce_zero3.warned:
+ if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
+ if sequence_parallel:
+ warnings.warn(
+ "When using sequence parallelism it is recommended to set the "
+ "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
+ "maximum speedup"
+ )
+ linear_with_grad_accumulation_and_async_allreduce_zero3.warned = True
+
+ if async_grad_allreduce:
+ warnings.warn(
+ "When using async grad allreduce it is recommended to set the "
+ "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
+ "maximum speedup"
+ )
+ linear_with_grad_accumulation_and_async_allreduce_zero3.warned = True
+
+ return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
+linear_with_grad_accumulation_and_async_allreduce_zero3.warned = False
+
+
+def linear_forward_zero3_wrapper(forward_func):
+ @wraps(forward_func)
+ def linear_forward_zero3(
+ ctx,
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ async_grad_allreduce,
+ sequence_parallel,
+ grad_output_buffer,
+ need_gather_param_in_bw=False):
+
+ ctx.need_gather_param_in_bw = need_gather_param_in_bw
+
+ return forward_func(
+ ctx,
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ async_grad_allreduce,
+ sequence_parallel,
+ grad_output_buffer)
+
+ return linear_forward_zero3
+
+
+def linear_backward_zero3_wrapper(func):
+ @wraps(func)
+ def linear_backward_zero3(ctx, grad_output):
+ ctx.gradient_accumulation_fusion = (ctx.gradient_accumulation_fusion and not ctx.need_gather_param_in_bw)
+ grad_input, grad_weight, grad_bias, _, _, _, _ = func(ctx, grad_output)
+ if ctx.need_gather_param_in_bw:
+ _, weight = ctx.saved_tensors
+ weight.full_grad = grad_weight
+ grad_weight = None
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None
+
+ return linear_backward_zero3
+
+
+def linear_forward_main_grad_wrapper(forward_func):
+ @wraps(forward_func)
+ def linear_forward_main_grad(ctx,
+ inputs,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ allreduce_dgrad,
+ wgrad_deferral_limit,
+ sequence_parallel,
+ grad_output_buffer,):
+ output = forward_func(ctx,
+ inputs,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ allreduce_dgrad,
+ wgrad_deferral_limit,
+ sequence_parallel,
+ grad_output_buffer,)
+ ctx.weight = weight
+ return output
+
+ return linear_forward_main_grad
+
+
+def linear_backward_main_grad_wrapper(backward_func):
+ @wraps(backward_func)
+ def linear_backward_main_grad(ctx, grad_output):
+ class NewCtx:
+ pass
+ new_ctx = NewCtx()
+ inputs, _ = ctx.saved_tensors
+ for key in dir(ctx):
+ if key == 'saved_tensors':
+ setattr(new_ctx, 'saved_tensors', (inputs, ctx.weight))
+ elif key.startswith('__') or key == 'saved_variables':
+ continue
+ else:
+ try:
+ getattr(ctx, key)
+ except AttributeError:
+ continue
+ setattr(new_ctx, key, getattr(ctx, key))
+ return backward_func(new_ctx, grad_output)
+
+ return linear_backward_main_grad
+
+
+def parallel_linear_init_zero3_wrapper(func):
+ @wraps(func)
+ def parallel_linear_init(self, *args, **kwargs):
+ global_args = get_args()
+ self.enable_zero3 = global_args.enable_zero3
+ func(self, *args, **kwargs)
+ if self.enable_zero3:
+ dp_size = get_data_parallel_world_size()
+ dp_rank = get_data_parallel_rank()
+ tmp_tensor = self.weight.chunk(dp_size, dim=0)[dp_rank]
+ self.weight = Parameter(
+ torch.empty(
+ tmp_tensor.shape, dtype=self.config.params_dtype
+ )
+ )
+ self.weight.data.copy_(tmp_tensor)
+ setattr(self.weight, 'enable_zero3', self.enable_zero3)
+
+ return parallel_linear_init
+
+
+def column_parallel_linear_forward_zero3(self, input_, weight=None):
+ """Forward of ColumnParallelLinear
+
+ Args:
+ input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
+
+ weight (optional): weight tensor to use, compulsory when
+ skip_weight_param_allocation is True.
+
+ Returns:
+ - output
+ - bias
+
+ """
+ if weight is None:
+ if self.weight is None:
+ raise RuntimeError(
+ "weight was not supplied to ColumnParallelLinear forward pass "
+ "and skip_weight_param_allocation is True."
+ )
+ weight = self.weight
+ else:
+ # Check the weight passed in is the correct shape
+ expected_shape = (self.output_size_per_partition, self.input_size)
+ if weight.shape != expected_shape:
+ raise RuntimeError(
+ f"supplied weight's shape is {tuple(weight.shape)}, "
+ f"not {expected_shape} as expected"
+ )
+
+ if self.config._cpu_offloading_context is not None:
+ if self.config._cpu_offloading_context.inside_context == True:
+ assert (
+ self.config.cpu_offloading == False
+ ), "CPU Offloading cannot be enabled while using non-TE modules"
+
+ bias = self.bias if not self.skip_bias_add else None
+
+ if (
+ self.async_tensor_model_parallel_allreduce
+ or self.sequence_parallel
+ or self.explicit_expert_comm
+ ):
+ input_parallel = input_
+ else:
+ input_parallel = copy_to_tensor_model_parallel_region(input_)
+
+ if self.config.defer_embedding_wgrad_compute:
+ self.embedding_activation_buffer.append(input_parallel)
+
+ # Matrix multiply.
+ if not weight.requires_grad:
+ self._forward_impl = linear_with_frozen_weight
+ else:
+ self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
+
+ output_parallel = self._forward_impl(
+ input=input_parallel,
+ weight=weight,
+ bias=bias,
+ gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+ async_grad_allreduce=False
+ if self.explicit_expert_comm
+ else self.async_tensor_model_parallel_allreduce,
+ sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
+ grad_output_buffer=self.grad_output_buffer
+ if self.config.defer_embedding_wgrad_compute
+ else None,
+ need_gather_param_in_bw=self.enable_zero3
+ )
+ if self.gather_output:
+ # All-gather across the partitions.
+ assert not self.sequence_parallel
+ output = gather_from_tensor_model_parallel_region(output_parallel)
+ else:
+ output = output_parallel
+ output_bias = self.bias if self.skip_bias_add else None
+ return output, output_bias
+
+
+def row_parallel_linear_forward_zero3(self, input_):
+
+ if self.config._cpu_offloading_context is not None:
+ if self.config._cpu_offloading_context.inside_context == True:
+ assert (
+ self.config.cpu_offloading == False
+ ), "CPU Offloading cannot be enabled while using non-TE modules"
+
+ # Set up backprop all-reduce.
+ if self.input_is_parallel:
+ input_parallel = input_
+ else:
+ assert not self.sequence_parallel
+ input_parallel = scatter_to_tensor_model_parallel_region(input_)
+ # Matrix multiply.
+ if not self.weight.requires_grad:
+ self._forward_impl = linear_with_frozen_weight
+ else:
+ self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
+ output_parallel = self._forward_impl(
+ input=input_parallel,
+ weight=self.weight,
+ bias=None,
+ gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+ async_grad_allreduce=False,
+ sequence_parallel=False,
+ need_gather_param_in_bw=self.enable_zero3
+ )
+
+ # All-reduce across all the partitions.
+ if self.explicit_expert_comm:
+ assert self.skip_bias_add
+ output_ = output_parallel
+ elif self.sequence_parallel:
+ output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
+ else:
+ output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+ if not self.skip_bias_add:
+ output = (output_ + self.bias) if self.bias is not None else output_
+ output_bias = None
+ else:
+ output = output_
+ output_bias = self.bias
+ return output, output_bias
+
+
+def vocab_parallel_embedding_forward(self, input_):
+ if self.tensor_model_parallel_size > 1:
+ # Build the mask.
+ input_mask = (input_ < self.vocab_start_index) | \
+ (input_ >= self.vocab_end_index)
+ # Mask the input.
+ masked_input = input_.clone() - self.vocab_start_index
+ masked_input *= ~input_mask
+ else:
+ masked_input = input_
+ # Get the embeddings.
+
+ if self.deterministic_mode:
+ output_parallel = self.weight[masked_input]
+ else:
+ # F.embedding currently has a non-deterministic backward function
+ # For higher accumulation accuracy for bf16 on NPU.
+ output_parallel = F.embedding(masked_input, self.weight)
+
+ # Mask the output embedding.
+ if self.tensor_model_parallel_size > 1:
+ output_parallel *= ~input_mask[..., None]
+ # Reduce across all the model parallel GPUs.
+ if self.reduce_scatter_embeddings:
+ # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
+ output_parallel = output_parallel.transpose(0, 1).contiguous()
+ output = reduce_scatter_to_sequence_parallel_region(output_parallel)
+ else:
+ # Reduce across all the model parallel GPUs.
+ output = reduce_from_tensor_model_parallel_region(output_parallel)
+ return output
+
+
+def row_parallel_nocomm_optimizer_wrapper(forward_func):
+ @wraps(forward_func)
+ def row_parallel_forward(*args, **kwargs):
+ global_args = get_args()
+ output = forward_func(*args, **kwargs)
+ recompute_num_layers = global_args.recompute_num_layers or 0
+
+ def is_need_avoid_infinite_recompute_loop():
+ return isinstance(output, tuple) and ((global_args.swap_attention and recompute_num_layers > 0)
+ or global_args.adaptive_memory_optimization)
+
+ if is_need_avoid_infinite_recompute_loop():
+ output, bias = output
+ if bias is not None:
+ # where only recompute mlp, training enters an infinite loop, this * 1 fix this bug
+ bias = bias * 1
+ return output, bias
+
+ return output
+ return row_parallel_forward
+
+
+class LinearWithGradAccumulationAndAsyncCommunicationPipeExperts(torch.autograd.Function):
+ """See linear_with_grad_accumulation_and_async_allreduce"""
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ async_grad_allreduce,
+ sequence_parallel,
+ grad_output_buffer,
+ wgrad_deferral_limit,
+ pipe_experts,
+ ampipe_degree
+ ):
+ ctx.save_for_backward(input, weight)
+ ctx.use_bias = bias is not None
+ ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+ ctx.async_grad_allreduce = async_grad_allreduce
+ ctx.sequence_parallel = sequence_parallel
+ ctx.grad_output_buffer = grad_output_buffer
+ ctx.wgrad_deferral_limit = wgrad_deferral_limit
+ ctx.pipe_experts = pipe_experts
+
+ if sequence_parallel:
+ global_args = get_args()
+ if global_args.use_ascend_mc2 and not pipe_experts:
+ from .ascend_turbo.ascend_turbo_cfg import ascend_turbo_cfg
+ group = get_tensor_model_parallel_group()
+ rank = get_tensor_model_parallel_rank()
+ ascend_turbo_cfg.set_world_size(get_tensor_model_parallel_world_size)
+ hcomm_info = None
+
+ if torch.__version__ > "2.0":
+ global_rank = torch.distributed.get_global_rank(group, rank)
+ hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(global_rank)
+ else:
+ hcomm_info = group.get_hccl_comm_name(rank)
+
+ x = input.reshape(input.shape[0] * input.shape[1], input.shape[2])
+ world_size = ascend_turbo_cfg.get_world_size()
+ output, _ = torch_npu.npu_all_gather_base_mm(
+ x,
+ weight.t(),
+ hcomm_info,
+ world_size,
+ bias=bias,
+ gather_index=0,
+ gather_output=(not ascend_turbo_cfg.all_gather_recomputation)
+ )
+ output = output.view(
+ output.shape[0] // input.shape[1], input.shape[1], output.shape[1]
+ )
+ elif pipe_experts:
+ total_input = get_fw_ag_output()[0]
+ output = torch.matmul(total_input, weight.t())
+ else:
+ world_size = get_tensor_model_parallel_world_size()
+ dim_size = list(input.size())
+ dim_size[0] = dim_size[0] * world_size
+
+ all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+ torch.distributed._all_gather_base(
+ all_gather_buffer, input, group=get_tensor_model_parallel_group()
+ )
+ total_input = all_gather_buffer
+ output = torch.matmul(total_input, weight.t())
+ else:
+ total_input = input
+ output = torch.matmul(total_input, weight.t())
+
+ if bias is not None:
+ output = output + bias
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ use_bias = ctx.use_bias
+ grad_output_buffer = ctx.grad_output_buffer
+ wgrad_deferral_limit = ctx.wgrad_deferral_limit
+
+ wgrad_compute = True
+ if grad_output_buffer is not None:
+ if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
+ grad_output_buffer.append(grad_output)
+ wgrad_compute = False
+
+ if wgrad_compute:
+ if ctx.sequence_parallel:
+ world_size = get_tensor_model_parallel_world_size()
+ dim_size = list(input.size())
+ dim_size[0] = dim_size[0] * world_size
+
+ if ctx.pipe_experts:
+ all_gather_buffer = torch.empty(dim_size, dtype=input.dtype, device=torch.cuda.current_device())
+ else:
+ all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+
+ handle = torch.distributed._all_gather_base(
+ all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # gather is scheduled before the input gradient computation
+ total_input = all_gather_buffer
+ else:
+ total_input = input
+ grad_input = grad_output.matmul(weight)
+
+ if ctx.sequence_parallel and wgrad_compute:
+ handle.wait()
+
+ if wgrad_compute:
+ grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
+ grad_output, total_input
+ )
+
+ if ctx.async_grad_allreduce:
+ # Asynchronous all-reduce
+ handle = torch.distributed.all_reduce(
+ grad_input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # all-reduce is scheduled before the weight gradient computation
+
+ if ctx.sequence_parallel:
+ assert not ctx.async_grad_allreduce
+ dim_size = list(input.size())
+ sub_grad_input = torch.empty(
+ dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
+ )
+ # reduce_scatter
+ handle = torch.distributed._reduce_scatter_base(
+ sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # reduce scatter is scheduled before the weight gradient computation
+
+ if ctx.gradient_accumulation_fusion:
+ if wgrad_compute:
+ if weight.main_grad.dtype == torch.float32:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+ total_input, grad_output, weight.main_grad
+ )
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+ total_input, grad_output, weight.main_grad
+ )
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ grad_weight = grad_output.t().matmul(total_input)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+ from mindspeed.moe.pipe_experts import get_async_bw_all_gather_count
+ if ctx.pipe_experts and get_async_bw_all_gather_count() != 2:
+ grad_output.storage().resize_(0)
+
+ if ctx.sequence_parallel:
+ handle.wait()
+ # Need to return None's as gradient has to flow for all the input arguments
+ # provided during forward
+ return sub_grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
+
+ if ctx.async_grad_allreduce:
+ handle.wait()
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
+
+
+class LinearWithGradAccumulationAndAsyncCommunication_nano(torch.autograd.Function):
+ """See linear_with_grad_accumulation_and_async_allreduce"""
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ async_grad_allreduce,
+ wgrad_deferral_limit,
+ sequence_parallel,
+ pipe_experts,
+ is_nano_row,
+ is_nano_column,
+ ):
+ ctx.weight = weight
+ ctx.save_for_backward(input)
+ ctx.is_nano_row = is_nano_row
+ ctx.is_nano_column = is_nano_column
+ ctx.use_bias = bias is not None
+ ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+ ctx.async_grad_allreduce = async_grad_allreduce
+ ctx.wgrad_deferral_limit = wgrad_deferral_limit
+ ctx.sequence_parallel = sequence_parallel
+ ctx.pipe_experts = pipe_experts
+ global_args = get_args()
+ if is_nano_row:
+ total_input = input
+ if sequence_parallel:
+ if pipe_experts:
+ output = torch.matmul(total_input, weight.t())
+ elif global_args.use_ascend_mc2:
+ from .ascend_turbo.ascend_turbo_cfg import ascend_turbo_cfg
+ rank = get_tensor_model_parallel_rank()
+ ascend_turbo_cfg.set_world_size(get_tensor_model_parallel_world_size)
+ world_size = ascend_turbo_cfg.get_world_size()
+ group = get_tensor_model_parallel_group()
+ hcomm_info = None
+ if torch.__version__ > "2.0":
+ global_rank = torch.distributed.get_global_rank(group, rank)
+ hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(global_rank)
+ else:
+ hcomm_info = group.get_hccl_comm_name(rank)
+
+ x = input.reshape(input.shape[0] * input.shape[1], input.shape[2])
+ output = torch_npu.npu_mm_reduce_scatter_base(
+ x, weight.t(), hcomm_info, world_size, reduce_op="sum", bias=bias
+ )
+ ctx.hcomm_info = hcomm_info
+ ctx.world_size = world_size
+ output = output.view(
+ output.shape[0] // input.shape[1], input.shape[1], output.shape[1]
+ )
+ return output
+ else:
+ output = torch.matmul(total_input, weight.t())
+ output = _reduce_scatter_along_first_dim(output)
+ else:
+ output = torch.matmul(total_input, weight.t())
+ if bias is not None:
+ output = output + bias
+ return output
+
+ if sequence_parallel:
+ if pipe_experts:
+ total_input = get_fw_ag_output()[0]
+ output = torch.matmul(total_input, weight.t())
+ elif global_args.use_ascend_mc2:
+ from .ascend_turbo.ascend_turbo_cfg import ascend_turbo_cfg
+ group = get_tensor_model_parallel_group()
+ rank = get_tensor_model_parallel_rank()
+ ascend_turbo_cfg.set_world_size(get_tensor_model_parallel_world_size)
+ hcomm_info = None
+ if torch.__version__ > "2.0":
+ global_rank = torch.distributed.get_global_rank(group, rank)
+ hcomm_info = group._get_backend(torch.device('npu')).get_hccl_comm_name(global_rank)
+ else:
+ hcomm_info = group.get_hccl_comm_name(rank)
+ x = input.reshape(input.shape[0] * input.shape[1], input.shape[2])
+ world_size = ascend_turbo_cfg.get_world_size()
+ output, _ = torch_npu.npu_all_gather_base_mm(
+ x,
+ weight.t(),
+ hcomm_info,
+ world_size,
+ bias=bias,
+ gather_index=0,
+ gather_output=(not ascend_turbo_cfg.all_gather_recomputation),
+ )
+ output = output.view(
+ output.shape[0] // input.shape[1], input.shape[1], output.shape[1]
+ )
+ else:
+ world_size = get_tensor_model_parallel_world_size()
+ dim_size = list(input.size())
+ dim_size[0] = dim_size[0] * world_size
+ all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+ torch.distributed._all_gather_base(
+ all_gather_buffer, input, group=get_tensor_model_parallel_group(),
+ )
+ total_input = all_gather_buffer
+ output = torch.matmul(total_input, weight.t())
+ else:
+ total_input = input
+ output = torch.matmul(total_input, weight.t())
+ if bias is not None:
+ output = output + bias
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ input = ctx.saved_tensors[0]
+ weight = ctx.weight
+ use_bias = ctx.use_bias
+ sequence_parallel = ctx.sequence_parallel
+ pipe_experts = ctx.pipe_experts
+ global_args = get_args()
+ grad_output_gathered = grad_output
+ grad_input = None
+ if ctx.is_nano_row:
+ if ctx.sequence_parallel:
+ if pipe_experts:
+ grad_input = grad_output.matmul(weight)
+ elif global_args.use_ascend_mc2:
+ hcomm_info = ctx.hcomm_info
+ world_size = ctx.world_size
+ grad_output_ = grad_output.reshape(
+ grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+ )
+ grad_input, grad_output_gathered = torch_npu.npu_all_gather_base_mm(
+ grad_output_, weight, hcomm_info, world_size, bias=None, gather_index=0
+ )
+
+ grad_input = grad_input.view_as(input)
+ else:
+ grad_output_gathered = _gather_along_first_dim(grad_output)
+ grad_input = grad_output_gathered.matmul(weight)
+ else:
+ grad_input = grad_output.matmul(weight)
+
+ if WeightGradStore.is_decoupleBlock:
+ if pipe_experts and ctx.sequence_parallel:
+ WeightGradStore.put(
+ input.clone().detach(),
+ None,
+ weight,
+ sequence_parallel,
+ in_row=True,
+ pipe_experts=True
+ )
+ else:
+ WeightGradStore.put(
+ input.clone().detach(),
+ grad_output.clone().detach(),
+ weight,
+ sequence_parallel,
+ in_row=True,
+ pipe_experts=False
+ )
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = None
+ else:
+ total_input = input
+ grad_output = grad_output_gathered.contiguous()
+ # Convert the tensonr shapes to 2D for execution compatibility
+ if len(grad_output.shape) != 2:
+ grad_output = grad_output.view(
+ grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+ )
+ total_input = total_input.view(
+ total_input.shape[0] * total_input.shape[1], total_input.shape[2]
+ )
+ if ctx.gradient_accumulation_fusion:
+ if weight.main_grad.dtype == torch.float32:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+ total_input, grad_output, weight.main_grad
+ )
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+ total_input, grad_output, weight.main_grad
+ )
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = grad_output.t().matmul(total_input)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
+
+ if WeightGradStore.is_decoupleBlock:
+ WeightGradStore.put(
+ input.clone().detach(),
+ grad_output.clone().detach(),
+ weight,
+ ctx.sequence_parallel
+ )
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = None
+ if not WeightGradStore.is_decoupleBlock:
+ if ctx.sequence_parallel:
+ world_size = get_tensor_model_parallel_world_size()
+ dim_size = list(input.size())
+ dim_size[0] = dim_size[0] * world_size
+
+ all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+ handle = torch.distributed._all_gather_base(
+ all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # gather is scheduled before the input gradient computation
+ total_input = all_gather_buffer
+ else:
+ total_input = input
+ grad_input = grad_output.matmul(weight)
+
+ if not WeightGradStore.is_decoupleBlock:
+ if ctx.sequence_parallel:
+ handle.wait()
+
+ # Doing gather + slicing during the NeMo forward pass can make this tensor
+ # not be contiguous. PyTorch only checks if the tensor is contiguous, and only
+ # clones it if it's not contiguous
+
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shape to 2D for execution compatibility
+ grad_output = grad_output.view(
+ grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+ )
+ total_input = total_input.view(
+ total_input.shape[0] * total_input.shape[1], total_input.shape[2]
+ )
+
+ if ctx.async_grad_allreduce:
+ # Asynchronous all_reduce
+ handle = torch.distributed.all_reduce(
+ grad_input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # all-reduce is scheduled before the weight gradient computation
+
+ if ctx.sequence_parallel:
+ assert not ctx.async_grad_allreduce
+ dim_size = list(input.size())
+ sub_grad_input = torch.empty(
+ dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
+ )
+ # reduce_scatter
+ handle = torch.distributed._reduce_scatter_base(
+ sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # reduce_scatter is scheduled before the weight gradient computation
+ if not WeightGradStore.is_decoupleBlock:
+ if ctx.gradient_accumulation_fusion:
+ if weight.main_grad.dtype == torch.float32:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+ total_input, grad_output, weight.main_grad
+ )
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+ total_input, grad_output, weight.main_grad
+ )
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ grad_weight = grad_output.t().matmul(total_input)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+ if ctx.sequence_parallel:
+ handle.wait()
+ return sub_grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
+
+ if ctx.async_grad_allreduce:
+ handle.wait()
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
+
+
+class LinearWithGradAccumulationAndAsyncCommunicationAmpipe(torch.autograd.Function):
+ """See linear_with_grad_accumulation_and_async_allreduce"""
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ allreduce_dgrad,
+ sequence_parallel,
+ grad_output_buffer,
+ wgrad_deferral_limit,
+ ampipe_degree,
+ is_dense_h_to_3h
+ ):
+ ctx.save_for_backward(input, weight)
+ ctx.use_bias = bias is not None
+ ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+ ctx.allreduce_dgrad = allreduce_dgrad
+ ctx.sequence_parallel = sequence_parallel
+ ctx.wgrad_deferral_limit = wgrad_deferral_limit
+ ctx.grad_output_buffer = grad_output_buffer
+ ctx.ampipe_degree = ampipe_degree
+ ctx.is_dense_h_to_3h = is_dense_h_to_3h
+ global_args = get_args()
+ ampipe_tp_sp_comm_overlap = global_args.ampipe_tp_sp_comm_overlap
+ ctx.ampipe_tp_sp_comm_overlap = ampipe_tp_sp_comm_overlap
+
+ if sequence_parallel:
+ if global_args.use_ascend_mc2 and ampipe_degree <= 1:
+ group = get_tensor_model_parallel_group()
+ world_size = get_tensor_model_parallel_world_size()
+ rank = torch.distributed.get_rank(group)
+ hcomm_info = None
+ if torch.__version__ > "2.0":
+ global_rank = torch.distributed.get_global_rank(group, rank)
+ hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(global_rank)
+ else:
+ hcomm_info = group.get_hccl_comm_name(rank)
+ x = input.reshape(input.shape[0] * input.shape[1], input.shape[2])
+ output, all_gather_grad_output = torch_npu.npu_all_gather_base_mm(
+ x,
+ weight.t(),
+ hcomm_info,
+ world_size,
+ bias=bias,
+ gather_index=0,
+ gather_output=False,
+ )
+ output = output.view(
+ int(output.shape[0] / input.shape[1]), input.shape[1], output.shape[1]
+ )
+ elif ampipe_degree > 1 and is_dense_h_to_3h:
+ input_list = input.chunk(ampipe_degree, dim=0)
+ output_list = []
+ for i in range(ampipe_degree):
+ input_chunk = input_list[i]
+ world_size = get_tensor_model_parallel_world_size()
+ dim_size = list(input_chunk.size())
+ dim_size[0] = dim_size[0] * world_size
+
+ all_gather_buffer = torch.empty(dim_size, dtype=input_chunk.dtype,
+ device=torch.cuda.current_device())
+ torch.distributed._all_gather_base(
+ all_gather_buffer, input_chunk, group=get_tensor_model_parallel_group()
+ )
+ output_chunk = torch.matmul(all_gather_buffer, weight.t())
+ output_list.append(output_chunk)
+
+ output = torch.cat(output_list, dim=0)
+ elif ampipe_degree > 1 and not is_dense_h_to_3h and ampipe_tp_sp_comm_overlap:
+ total_input = get_fw_ag_output().pop(0)
+ output = torch.matmul(total_input, weight.t())
+ if bias is not None:
+ output = output + bias
+ total_input.untyped_storage().resize_(0)
+
+ else:
+ world_size = get_tensor_model_parallel_world_size()
+ dim_size = list(input.size())
+ dim_size[0] = dim_size[0] * world_size
+
+ all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+ torch.distributed._all_gather_base(
+ all_gather_buffer, input, group=get_tensor_model_parallel_group()
+ )
+ total_input = all_gather_buffer
+ output = torch.matmul(total_input, weight.t())
+ else:
+ total_input = input
+
+ output = torch.matmul(total_input, weight.t())
+ if bias is not None:
+ output = output + bias
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ use_bias = ctx.use_bias
+ grad_output_buffer = ctx.grad_output_buffer
+ wgrad_deferral_limit = ctx.wgrad_deferral_limit
+
+ wgrad_compute = True
+ if grad_output_buffer is not None:
+ if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
+ grad_output_buffer.append(grad_output)
+ wgrad_compute = False
+
+ if wgrad_compute:
+ if ctx.sequence_parallel:
+ world_size = get_tensor_model_parallel_world_size()
+ dim_size = list(input.size())
+ dim_size[0] = dim_size[0] * world_size
+ if ctx.ampipe_degree > 1 and ctx.is_dense_h_to_3h:
+ new_indices = get_slice_indices_from_disorder_to_order(dim_size[0],
+ ctx.ampipe_degree,
+ device=torch.cuda.current_device())
+ grad_output = torch.index_select(grad_output, dim=0, index=new_indices)
+
+ all_gather_buffer = get_global_memory_buffer().get_tensor(
+ dim_size, input.dtype, "mpu"
+ )
+ handle = torch.distributed._all_gather_base(
+ all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # gather is scheduled before the input gradient computation
+ total_input = all_gather_buffer
+ else:
+ total_input = input
+ grad_input = grad_output.matmul(weight)
+
+ if ctx.sequence_parallel and wgrad_compute:
+ handle.wait()
+
+ if wgrad_compute:
+ grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
+ grad_output, total_input
+ )
+
+ if ctx.allreduce_dgrad:
+ # Asynchronous all-reduce
+ handle = torch.distributed.all_reduce(
+ grad_input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # all-reduce is scheduled before the weight gradient computation
+
+ if ctx.sequence_parallel:
+ assert not ctx.allreduce_dgrad
+ dim_size = list(input.size())
+ sub_grad_input = torch.empty(
+ dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
+ )
+ # reduce_scatter
+ handle = torch.distributed._reduce_scatter_base(
+ sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # reduce scatter is scheduled before the weight gradient computation
+
+ if ctx.gradient_accumulation_fusion:
+ if wgrad_compute:
+ if weight.main_grad.dtype == torch.float32:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+ total_input, grad_output, weight.main_grad
+ )
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+ total_input, grad_output, weight.main_grad
+ )
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ grad_weight = grad_output.t().matmul(total_input)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+ if ctx.sequence_parallel:
+ handle.wait()
+ # Need to return None's as gradient has to flow for all the input arguments
+ # provided during forward
+ return sub_grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
+
+ if ctx.allreduce_dgrad:
+ handle.wait()
+
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
+
+
+def linear_with_grad_accumulation_and_async_allreduce_moe(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ bias: Optional[torch.Tensor],
+ gradient_accumulation_fusion: bool,
+ async_grad_allreduce: bool,
+ sequence_parallel: bool,
+ pipe_experts=False,
+ grad_output_buffer: Optional[List[torch.Tensor]] = None,
+ wgrad_deferral_limit: Optional[int] = 0,
+ allreduce_dgrad: bool = None,
+ matmul_id: int = 1,
+ is_nano_row: bool = False,
+ is_nano_column: bool = False,
+ ampipe_degree: int = 1,
+ is_dense_h_to_3h: bool = False,
+) -> torch.Tensor:
+ """Linear layer execution with asynchronous communication and
+ gradient accumulation fusion in backprop.
+
+ This has the option to accumulate the result of backprop
+ calculation into an existing gradient buffer, preventing the need
+ to do an additional addition kernel after the gradient
+ calculation.
+
+ Additionally, the tensor parallel all reduce of the input
+ gradients can be done asynchronously with the calculation of
+ the weight gradients.
+
+ In the case of sequence parallelism, the reduce scatter of the
+ input gradients is done asynchronously with the calcluation of the
+ weight gradients.
+
+ Use of this module requires that the environment variable
+ CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
+ operations, noted in the code, that should be scheduled before
+ compute kernels to overlap the communication with the computation,
+ which is necessary for a speedup but not for correctness so that
+ ordering isn't imposed by the scheduler. Setting
+ CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
+ in the order they are called.
+
+ Args:
+
+ input (torch.Tensor required): input like torch.nn.functional.linear
+
+ weight (torch.Tensor required): weight like torch.nn.functional.linear
+
+ bias (torch.Tensor optional): bias like torch.nn.functional.linear
+
+ gradient_accumulation_fusion (bool required): Perform the gradient
+ accumulation fusion, requires the custom CUDA extension
+ fused_weight_gradient_mlp_cuda module. To use
+ gradient_accumulation_fusion you must install APEX with
+ --cpp_ext and --cuda_ext. For example: "pip install
+ --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
+ " Note that the extension requires CUDA>=11. Otherwise, you
+ must turn off gradient accumulation fusion."
+
+ async_grad_allreduce (bool required): Do the allreduce of input
+ gradients asyncronously with the computation of weight
+ gradients. If sequence_parallel is True, this must be
+ False, as no all reduce is performed.
+
+ sequence_parallel (bool required): Indicates that sequence
+ parallelism is used and thus in the forward pass the input is
+ all gathered, and the backward pass the input gradients are
+ reduce scattered.
+
+ grad_output_buffer (List[torch.Tensor] optional): Buffer used to save
+ output gradients when embedding table wgrad compute is deferred.
+ Defaults to None.
+ """
+ if allreduce_dgrad is None:
+ warnings.warn(
+ "async_grad_allreduce is deprecated and will be removed in a future release. use allreduce_dgrad instead."
+ )
+ allreduce_dgrad = async_grad_allreduce
+
+ args = [
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ allreduce_dgrad,
+ sequence_parallel,
+ grad_output_buffer,
+ wgrad_deferral_limit
+ ]
+
+ if not linear_with_grad_accumulation_and_async_allreduce_moe.warned:
+ if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
+ if sequence_parallel:
+ warnings.warn(
+ "When using sequence parallelism it is recommended to set the "
+ "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
+ "maximum speedup"
+ )
+ linear_with_grad_accumulation_and_async_allreduce_moe.warned = True
+
+ if allreduce_dgrad:
+ warnings.warn(
+ "When using async grad allreduce it is recommended to set the "
+ "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
+ "maximum speedup"
+ )
+ linear_with_grad_accumulation_and_async_allreduce_moe.warned = True
+
+ if get_args().use_nanopipe and parallel_state.get_pipeline_model_parallel_world_size() > 1 \
+ and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+ if get_args().use_nanopipe and (is_nano_row or is_nano_column):
+ args = [
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ wgrad_deferral_limit,
+ async_grad_allreduce,
+ sequence_parallel,
+ pipe_experts,
+ is_nano_row,
+ is_nano_column
+ ]
+ return LinearWithGradAccumulationAndAsyncCommunication_nano.apply(*args)
+ if pipe_experts:
+ return LinearWithGradAccumulationAndAsyncCommunicationPipeExperts.apply(*args, pipe_experts, ampipe_degree)
+ if ampipe_degree > 1:
+ return LinearWithGradAccumulationAndAsyncCommunicationAmpipe.apply(*args, ampipe_degree, is_dense_h_to_3h)
+
+ if get_args().use_nd_matmul:
+ args.append(pipe_experts)
+ args.append(matmul_id)
+ return LinearWithGradAccumulationAndAsyncCommunication_Nd.apply(*args)
+
+ return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
+
+
+linear_with_grad_accumulation_and_async_allreduce_moe.warned = False
+
+
+def parallel_linear_init_wrapper(init_func):
+ @wraps(init_func)
+ def parallel_linear_init_func(self, *args, pipe_experts: bool = False, in_nano: bool = False,
+ ampipe_degree: int = 1,
+ is_dense_h_to_3h: bool = False,
+ **kwargs):
+ output = init_func(self, *args, **kwargs)
+ self.pipe_experts = pipe_experts
+ self.in_nano = in_nano
+ self.ampipe_degree = ampipe_degree
+ self.is_dense_h_to_3h = is_dense_h_to_3h
+ return output
+ return parallel_linear_init_func
+
+
+def row_parallel_moe(self, input_):
+ """Forward of RowParallelLinear
+
+ Args:
+ input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
+
+ Returns:
+ - output
+ - bias
+ """
+
+ if self.config._cpu_offloading_context is not None:
+ if self.config._cpu_offloading_context.inside_context == True:
+ assert (
+ self.config.cpu_offloading == False
+ ), "CPU Offloading cannot be enabled while using non-TE modules"
+
+ # Set up backprop all-reduce.
+ global_args = get_args()
+ if global_args.use_ascend_mc2 and not self.pipe_experts and not self.in_nano:
+ output = Mc2RowSeqParallelLinear.apply(
+ input_, self.weight, None, get_tensor_model_parallel_group()
+ )
+
+ if not self.skip_bias_add:
+ output = output + self.bias if self.bias is not None else output
+ output_bias = None
+ else:
+ output_bias = self.bias
+
+ return output, output_bias
+
+ if self.input_is_parallel:
+ input_parallel = input_
+ else:
+ assert not self.sequence_parallel
+ input_parallel = scatter_to_tensor_model_parallel_region(input_)
+ # Matrix multiply.
+ if not self.weight.requires_grad:
+ self._forward_impl = linear_with_frozen_weight
+ else:
+ self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
+
+ if self.in_nano and self.sequence_parallel:
+ output_parallel = self._forward_impl(
+ input=input_parallel,
+ weight=self.weight,
+ bias=None,
+ gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+ async_grad_allreduce=False,
+ sequence_parallel=True,
+ pipe_experts=self.pipe_experts,
+ is_nano_row=self.in_nano,
+ )
+ output_ = output_parallel
+ elif self.ampipe_degree > 1:
+ output_parallel = self._forward_impl(
+ input=input_parallel,
+ weight=self.weight,
+ bias=None,
+ gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+ async_grad_allreduce=False,
+ sequence_parallel=False,
+ ampipe_degree=self.ampipe_degree,
+ pipe_experts=self.pipe_experts
+ )
+ ampipe_tp_sp_comm_overlap = get_args().ampipe_tp_sp_comm_overlap
+ if ampipe_tp_sp_comm_overlap or self.pipe_experts:
+ output_ = output_parallel
+ elif self.sequence_parallel:
+ output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
+ else:
+ output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+ else:
+ output_parallel = self._forward_impl(
+ input=input_parallel,
+ weight=self.weight,
+ bias=None,
+ gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+ async_grad_allreduce=False,
+ sequence_parallel=False,
+ pipe_experts=self.pipe_experts,
+ is_nano_row=self.in_nano,
+ )
+ # All-reduce across all the partitions or self.pipe_experts
+ if self.explicit_expert_comm or self.pipe_experts:
+ assert self.skip_bias_add
+ output_ = output_parallel
+ elif self.sequence_parallel:
+ output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
+ else:
+ output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+ if not self.skip_bias_add:
+ output = (output_ + self.bias) if self.bias is not None else output_
+ output_bias = None
+ else:
+ output = output_
+ output_bias = self.bias
+ return output, output_bias
+
+
+def column_parallel_moe(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
+ """Forward of ColumnParallelLinear
+
+ Args:
+ input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
+
+ weight (optional): weight tensor to use, compulsory when
+ skip_weight_param_allocation is True.
+
+ Returns:
+ - output
+ - bias
+
+ """
+ if weight is None:
+ if self.weight is None:
+ raise RuntimeError(
+ "weight was not supplied to ColumnParallelLinear forward pass "
+ "and skip_weight_param_allocation is True."
+ )
+ weight = self.weight
+ else:
+ # Check the weight passed in is the correct shape
+ expected_shape = (self.output_size_per_partition, self.input_size)
+ if weight.shape != expected_shape:
+ raise RuntimeError(
+ f"supplied weight's shape is {tuple(weight.shape)}, "
+ f"not {expected_shape} as expected"
+ )
+
+ if self.config._cpu_offloading_context is not None:
+ if self.config._cpu_offloading_context.inside_context == True:
+ assert (
+ self.config.cpu_offloading == False
+ ), "CPU Offloading cannot be enabled while using non-TE modules"
+
+ bias = self.bias if not self.skip_bias_add else None
+
+ if (
+ self.allreduce_dgrad
+ or self.sequence_parallel
+ or self.explicit_expert_comm
+ ):
+ input_parallel = input_
+ else:
+ input_parallel = copy_to_tensor_model_parallel_region(input_)
+
+ if self.config.defer_embedding_wgrad_compute:
+ if (
+ self.config.wgrad_deferral_limit == 0
+ or len(self.embedding_activation_buffer) < self.config.wgrad_deferral_limit
+ ):
+ self.embedding_activation_buffer.append(input_parallel)
+
+ # Matrix multiply.
+ if not weight.requires_grad:
+ self._forward_impl = linear_with_frozen_weight
+ else:
+ self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
+
+ output_parallel = self._forward_impl(
+ input=input_parallel,
+ weight=weight,
+ bias=bias,
+ gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+ async_grad_allreduce=False
+ if self.explicit_expert_comm
+ else self.allreduce_dgrad,
+ sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
+ grad_output_buffer=(
+ self.grad_output_buffer if self.config.defer_embedding_wgrad_compute else None
+ ),
+ wgrad_deferral_limit=(
+ self.config.wgrad_deferral_limit
+ if self.config.defer_embedding_wgrad_compute
+ else None
+ ),
+ pipe_experts=self.pipe_experts,
+ is_nano_column=self.in_nano,
+ ampipe_degree=self.ampipe_degree,
+ is_dense_h_to_3h=self.is_dense_h_to_3h
+ )
+ if self.gather_output:
+ # All-gather across the partitions.
+ assert not self.sequence_parallel
+ output = gather_from_tensor_model_parallel_region(output_parallel)
+ else:
+ output = output_parallel
+ output_bias = self.bias if self.skip_bias_add else None
+ return output, output_bias
+
+
+class Mc2RowSeqParallelLinear(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias, group):
+ ctx.save_for_backward(input_, weight)
+ ctx.use_bias = bias is not None
+
+ from .ascend_turbo.ascend_turbo_cfg import ascend_turbo_cfg
+ rank = get_tensor_model_parallel_rank()
+ ascend_turbo_cfg.set_world_size(get_tensor_model_parallel_world_size)
+ world_size = ascend_turbo_cfg.get_world_size()
+ hcomm_info = None
+
+ if torch.__version__ > "2.0":
+ global_rank = torch.distributed.get_global_rank(group, rank)
+ hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(
+ global_rank
+ )
+ else:
+ hcomm_info = group.get_hccl_comm_name(rank)
+
+ x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
+
+ output = torch_npu.npu_mm_reduce_scatter_base(
+ x, weight.t(), hcomm_info, world_size, reduce_op="sum", bias=bias
+ )
+
+ ctx.hcomm_info = hcomm_info
+ ctx.world_size = world_size
+
+ output = output.view(
+ output.shape[0] // input_.shape[1], input_.shape[1], output.shape[1]
+ )
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_, weight = ctx.saved_tensors
+ hcomm_info = ctx.hcomm_info
+ world_size = ctx.world_size
+
+ grad_output_ = grad_output.reshape(
+ grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+ )
+
+ grad_input, all_gather_grad_output = torch_npu.npu_all_gather_base_mm(
+ grad_output_, weight, hcomm_info, world_size, bias=None, gather_index=0
+ )
+ grad_input = grad_input.view_as(input_)
+
+ total_input = input_
+ total_input = total_input.view(
+ total_input.shape[0] * total_input.shape[1], total_input.shape[2]
+ )
+ grad_weight = all_gather_grad_output.t().matmul(total_input)
+
+ is_grad_bias_needed = ctx.needs_input_grad[2]
+ if is_grad_bias_needed and ctx.use_bias:
+ grad_bias = (
+ grad_output.sum(dim=0)
+ if grad_output.is_contiguous()
+ else grad_output.t().sum(dim=1)
+
+ )
+ else:
+ grad_bias = None
+
+ return grad_input, grad_weight, grad_bias, None
+
+
+def _initialize_affine_weight_cpu_2d(weight, partition_dim, stride=1, return_master_weight=False, *,
+ config: TransformerConfig):
+ """Initialize affine weight for model parallel when use tp-2d"""
+ set_tensor_model_parallel_attributes(
+ tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
+ )
+
+ if partition_dim == 1:
+ row_num = TPYCollectiveComm.get_comm_group_world_size()
+ col_num = TPXCollectiveComm.get_comm_group_world_size()
+ else:
+ row_num = TPXCollectiveComm.get_comm_group_world_size()
+ col_num = TPYCollectiveComm.get_comm_group_world_size()
+
+ # Initialize master weight
+ split_input_size, split_output_size = weight.size()
+ input_size = split_input_size * row_num
+ output_size = split_output_size * col_num
+
+ master_weight = torch.empty(input_size, output_size, dtype=torch.float, requires_grad=False)
+ config.init_method(master_weight)
+
+ master_weight = master_weight.to(dtype=config.params_dtype)
+
+ x = TPXCollectiveComm.get_comm_rank()
+ y = TPYCollectiveComm.get_comm_rank()
+
+ rows = torch.chunk(master_weight, row_num, dim=0)
+ if partition_dim == 1:
+ row_idx = y
+ col_idx = x
+ else:
+ row_idx = x
+ col_idx = y
+
+ row = rows[row_idx]
+ cols = torch.chunk(row, col_num, dim=1)
+ final_weight = cols[col_idx].contiguous()
+ weight.data.copy_(final_weight)
+
+ if return_master_weight:
+ return master_weight
+
+
+def _initialize_affine_weight_cpu_nd(
+ weight,
+ output_size,
+ input_size,
+ input_size_per_partition,
+ output_size_per_partition,
+ init_method,
+ stride=1,
+ return_master_weight=False,
+ *,
+ params_dtype=torch.float32
+):
+ """Initialize affine weight for model parallel when use nd-matmul"""
+ set_tensor_model_parallel_attributes(
+ tensor=weight, is_parallel=True, dim=0, stride=stride
+ )
+
+ # Initialize master weight
+ master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False)
+ init_method(master_weight)
+
+ master_weight = master_weight.to(dtype=params_dtype)
+ # Split and copy
+ rank = ps.get_tensor_model_parallel_rank()
+ world_size = ps.get_tensor_model_parallel_world_size()
+
+ def compute_target_rank(rank, row_num, col_num):
+ return rank % row_num * col_num + rank // row_num
+
+ # The weight positions of nd and megatron are different. So weight needs to be rearranged.
+ # This rearrangement is only to make the calculations of nd and megatron consistent.
+ # Even if this rearrangement is removed, it will not affect the correctness of nd calculation.
+ row_num = input_size // input_size_per_partition
+ col_num = output_size // output_size_per_partition
+ weight_list = torch.split(master_weight, master_weight.size()[0] // world_size, dim=0)
+ tensor_list = [weight_list[compute_target_rank(i, row_num, col_num)] for i in range(world_size)]
+ master_weight = torch.cat(tensor_list, dim=0)
+
+ weight_list_1 = torch.split(master_weight, input_size_per_partition, dim=1)
+ weight_1 = weight_list_1[rank // col_num]
+ weight_list_2 = torch.split(weight_1, output_size_per_partition, dim=0)
+ my_weight_list = weight_list_2[rank % col_num:: world_size]
+
+ with torch.no_grad():
+ torch.cat(my_weight_list, dim=0, out=weight)
+ if return_master_weight:
+ return master_weight
+ return None
+
+
+class LinearWithGradAccumulationAndAsyncCommunication_Nd(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx,
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ async_grad_allreduce,
+ wgrad_deferral_limit,
+ sequence_parallel,
+ grad_output_buffer,
+ pipe_experts,
+ matmul_id,
+ ):
+ if sequence_parallel:
+ raise AssertionError(
+ 'Nd_matmul cannot be used with sequence_parallel.'
+ 'If you want to train long sequences, '
+ 'you can use ulysess or context_parallel that is compatible with nd_matmul.'
+ )
+ ctx.use_bias = bias is not None
+ ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+ ctx.async_grad_allreduce = async_grad_allreduce
+ ctx.wgrad_deferral_limit = wgrad_deferral_limit
+ ctx.sequence_parallel = sequence_parallel
+ ctx.save_for_backward(input, weight)
+
+ if matmul_id == 1:
+ world_size1 = get_tensor_model_parallel_world_size_for_nd1_dim1()
+ comm_group1 = get_tensor_model_parallel_group_for_nd1_dim1()
+ world_size2 = get_tensor_model_parallel_world_size_for_nd1_dim2()
+ comm_group2 = get_tensor_model_parallel_group_for_nd1_dim2()
+ else:
+ world_size1 = get_tensor_model_parallel_world_size_for_nd2_dim1()
+ comm_group1 = get_tensor_model_parallel_group_for_nd2_dim1()
+ world_size2 = get_tensor_model_parallel_world_size_for_nd2_dim2()
+ comm_group2 = get_tensor_model_parallel_group_for_nd2_dim2()
+
+ ctx.world_size1 = world_size1
+ ctx.comm_group1 = comm_group1
+ ctx.world_size2 = world_size2
+ ctx.comm_group2 = comm_group2
+
+ last_dim = input.dim() - 1
+ total_input_list = [torch.empty_like(input) for _ in range(world_size1)]
+ torch.distributed.all_gather(total_input_list, input, group=comm_group1)
+ total_input = torch.cat(total_input_list, dim=last_dim)
+
+ output_parallel = torch.matmul(total_input, weight.t())
+ output_parallel = output_parallel.transpose(0, 2)
+
+ dim_size = list(output_parallel.size())
+ dim_size[0] //= world_size2
+ output = torch.empty(dim_size, dtype=output_parallel.dtype, device=torch.cuda.current_device())
+ torch.distributed._reduce_scatter_base(
+ output, output_parallel.contiguous(), group=comm_group2
+ )
+ output = output.transpose(0, 2).contiguous()
+ if bias is not None:
+ output = output + bias
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ world_size1 = ctx.world_size1
+ comm_group1 = ctx.comm_group1
+ world_size2 = ctx.world_size2
+ comm_group2 = ctx.comm_group2
+ input, weight = ctx.saved_tensors
+ use_bias = ctx.use_bias
+ last_dim = grad_output.dim() - 1
+
+ grad_output_ag_list = [torch.empty_like(grad_output) for _ in range(world_size2)]
+ torch.distributed.all_gather(grad_output_ag_list, grad_output.contiguous(), group=comm_group2)
+ grad_output_ag = torch.cat(grad_output_ag_list, dim=last_dim)
+
+ total_input_list = [torch.empty_like(input) for _ in range(world_size1)]
+ handle1 = torch.distributed.all_gather(total_input_list, input, group=comm_group1, async_op=True)
+
+ grad_bias = grad_output_ag.view(
+ grad_output_ag.shape[0] * grad_output_ag.shape[1], grad_output_ag.shape[2]
+ ).sum(dim=0) if use_bias else None
+
+ grad_input = grad_output_ag.matmul(weight)
+
+ grad_input = grad_input.transpose(0, 2)
+ dim_size = list(grad_input.size())
+ dim_size[0] = dim_size[0] // world_size1
+
+ handle1.wait()
+ total_input = torch.cat(total_input_list, dim=last_dim)
+
+ grad_input_rs = torch.empty(dim_size, dtype=grad_input.dtype, device=torch.cuda.current_device())
+
+ handle2 = torch.distributed._reduce_scatter_base(
+ grad_input_rs, grad_input.contiguous(), group=comm_group1, async_op=True
+ )
+
+ grad_output_ag = grad_output_ag.view(
+ grad_output_ag.shape[0] * grad_output_ag.shape[1], grad_output_ag.shape[2]
+ )
+ total_input = total_input.view(
+ total_input.shape[0] * total_input.shape[1], total_input.shape[2]
+ )
+
+ if ctx.gradient_accumulation_fusion:
+ if weight.main_grad.dtype == torch.float32:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+ total_input, grad_output_ag, weight.main_grad
+ )
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+ total_input, grad_output_ag, weight.main_grad
+ )
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ grad_weight = grad_output_ag.t().matmul(total_input)
+
+ handle2.wait()
+ grad_input_rs = grad_input_rs.transpose(0, 2).contiguous()
+ return grad_input_rs, grad_weight, grad_bias, None, None, None, None, None, None, None, None
+
+
+class Nd_ParallelLinear(torch.nn.Module):
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ *,
+ config: ModelParallelConfig,
+ init_method: Callable,
+ bias: bool,
+ input_is_parallel: bool,
+ skip_bias_add: bool,
+ stride: int = 1,
+ keep_master_weight_for_test: bool = False,
+ is_expert: bool = False,
+ tp_comm_buffer_name: str = None, # Not used
+ matmul_id: int = 1,
+ ):
+ """Nd_ParallelLinear is used to replace the columnParallelLinear and RowParallelLinear in Megatron TP.
+
+ Args:
+ matmul_id: which GEMM operation within the attention or FFN block.
+ if matmul_id is 1 in attention, which represents GEMM for compute QKV.
+ """
+ super(Nd_ParallelLinear, self).__init__()
+
+ self.input_size = input_size
+ self.output_size = output_size
+ self.input_is_parallel = input_is_parallel
+ if matmul_id == 1:
+ self.world_size_dim1 = get_tensor_model_parallel_world_size_for_nd1_dim1()
+ self.world_size_dim2 = get_tensor_model_parallel_world_size_for_nd1_dim2()
+ else:
+ self.world_size_dim1 = get_tensor_model_parallel_world_size_for_nd2_dim1()
+ self.world_size_dim2 = get_tensor_model_parallel_world_size_for_nd2_dim2()
+
+ self.matmul_id = matmul_id
+ self.input_size_per_partition = divide(input_size, self.world_size_dim2)
+ self.output_size_per_partition = divide(output_size, self.world_size_dim1)
+
+ self.skip_bias_add = skip_bias_add
+ self.config = config
+ self.is_expert = is_expert
+ self.expert_parallel = config.expert_model_parallel_size > 1
+ self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
+ self.sequence_parallel = config.sequence_parallel
+ if self.sequence_parallel:
+ raise RuntimeError(
+ 'Nd_matmul cannot be used with sequence_parallel.'
+ 'If you want to train long sequences, '
+ 'you can use ulysess or context_parallel that is compatible with nd_matmul.'
+ )
+
+ if config.use_cpu_initialization:
+ self.weight = torch.nn.Parameter(
+ torch.empty(self.output_size, self.input_size_per_partition, dtype=config.params_dtype)
+ )
+
+ if config.perform_initialization:
+ self.master_weight = _initialize_affine_weight_cpu_nd(
+ self.weight,
+ self.output_size,
+ self.input_size,
+ self.input_size_per_partition,
+ self.output_size_per_partition,
+ init_method,
+ stride=stride,
+ return_master_weight=keep_master_weight_for_test,
+ params_dtype=config.params_dtype
+ )
+ else:
+ self.weight = torch.nn.Parameter(
+ torch.empty(
+ self.output_size_per_partition,
+ self.input_size_per_partition,
+ device=torch.cuda.current_device(),
+ dtype=config.params_dtype
+ )
+ )
+ if config.perform_initialization:
+ _initialize_affine_weight_gpu(
+ self.weight,
+ init_method,
+ partition_dim=1,
+ stride=stride,
+ expert_parallel=(self.is_expert and self.expert_parallel)
+ )
+
+ setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
+
+ if bias:
+ if config.use_cpu_initialization:
+ self.bias = torch.nn.Parameter(
+ torch.empty(self.output_size, dtype=config.params_dtype)
+ )
+ else:
+ self.bias = torch.nn.Parameter(
+ torch.empty(
+ self.output_size,
+ device=torch.cuda.current_device(),
+ dtype=config.params_dtype
+ )
+ )
+
+ if config.perform_initialization:
+ with torch.no_grad():
+ self.bias.zero_()
+
+ setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel))
+ setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
+ else:
+ self.register_parameter('bias', None)
+
+ self.explicit_expert_comm = self.is_expert and (
+ self.sequence_parallel or self.expert_parallel
+ )
+
+ def forward(self, input_):
+ self._forward_impl = linear_with_grad_accumulation_and_async_allreduce_moe
+ output_ = self._forward_impl(
+ input=input_,
+ weight=self.weight,
+ bias=None,
+ gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+ async_grad_allreduce=False,
+ sequence_parallel=False,
+ matmul_id=self.matmul_id
+ )
+
+ if not self.skip_bias_add:
+ output = (output_ + self.bias) if self.bias is not None else output_
+ output_bias = None
+ else:
+ output = output_
+ output_bias = self.bias
+
+ return output, output_bias
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/__init__.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..66da6a3f4191ad935e0f07854e50701f397137d7
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/__init__.py
@@ -0,0 +1 @@
+# lcal coc (communication over computation)
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_parallel_linears_all_reduce.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_parallel_linears_all_reduce.py
new file mode 100644
index 0000000000000000000000000000000000000000..d15a2d57525f5ca03145cd883f3935965aca4ef3
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_parallel_linears_all_reduce.py
@@ -0,0 +1,73 @@
+from functools import reduce
+import torch
+import torch_npu
+
+from .min_comm_cfg import min_comm_config
+from .matmul_soc_friendly import get_aligned_mm_inputs
+from .rewrite_parallel_linears_all_reduce import RewriteColumnAllReduceFunction, RewriteRowAllReduceFunction
+from .coc_utils import set_context, CommunicationType, COCParallel
+from .coc_utils import reshape_to_2D, is_grad_needed, get_parallel_num, get_output_shape
+
+
+class COCColumnAllReduceFunction(RewriteColumnAllReduceFunction):
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_, weight = ctx.saved_tensors
+ is_grad_weight_needed, is_grad_bias_needed = is_grad_needed(ctx.needs_input_grad)
+ grad_input = grad_output.matmul(weight)
+ handle = torch.distributed.all_reduce(grad_input, group=min_comm_config.tp_group, async_op=True)
+ grad_weight, grad_bias = None, None
+ if is_grad_weight_needed:
+ grad_output = reshape_to_2D(grad_output)
+ grad_weight = grad_output.t().matmul(reshape_to_2D(input_))
+ handle.wait()
+ grad_bias = grad_output.sum(dim=0) if ctx.use_bias and is_grad_bias_needed else None
+ else:
+ handle.wait()
+
+ return grad_input, grad_weight, grad_bias
+
+
+class COCRowAllReduceFunction(RewriteRowAllReduceFunction):
+ @staticmethod
+ def forward(ctx, input_, weight, bias):
+ set_context(ctx, input_, weight, bias)
+ trans_weight = weight.t()
+
+ parallel_num = get_parallel_num(m=reduce(lambda x, y: x * y, input_.shape[:-1]),
+ k=trans_weight.shape[0],
+ n=trans_weight.shape[1])
+ if parallel_num == 1:
+ return RewriteRowAllReduceFunction.forward(ctx, input_, weight, bias)
+
+ output_orig_shape = get_output_shape(input_, trans_weight, 1, is_gather=True)
+ input_ = reshape_to_2D(input_)
+
+ if min_comm_config.matmul_soc_friendly_enabled:
+ input_, trans_weight = get_aligned_mm_inputs(input_, trans_weight, sp_coef=min_comm_config.tp_world_size,
+ parallel_num=parallel_num)
+
+ def compute_fcn(input_tensor, output_tensor):
+ torch.matmul(input_tensor, trans_weight, out=output_tensor)
+ return output_tensor
+
+ coc_all_gather = COCParallel(input_, CommunicationType.ALL_REDUCE, compute_fcn, compute_first=True,
+ weight_shape_list=list(trans_weight.shape))
+ output_ = coc_all_gather.run()
+ output_ = output_.reshape(output_orig_shape)
+ if bias is not None:
+ output_ = output_ + bias
+ return output_
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ total_input, weight = ctx.saved_tensors
+ is_grad_weight_needed, is_grad_bias_needed = is_grad_needed(ctx.needs_input_grad)
+ grad_input = grad_output.matmul(weight)
+ grad_weight, grad_bias = None, None
+ if is_grad_weight_needed:
+ grad_output = reshape_to_2D(grad_output)
+ grad_weight = grad_output.t().matmul(reshape_to_2D(total_input))
+ grad_bias = grad_output.sum(dim=0) if ctx.use_bias and is_grad_bias_needed else None
+
+ return grad_input, grad_weight, grad_bias
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_parallel_linears_all_reduce_fused.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_parallel_linears_all_reduce_fused.py
new file mode 100644
index 0000000000000000000000000000000000000000..3520f363d0a0779f976b7bf9f112c6f0e096701c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_parallel_linears_all_reduce_fused.py
@@ -0,0 +1,43 @@
+from functools import reduce
+
+from .coc_parallel_linears_all_reduce import COCRowAllReduceFunction
+from .rewrite_parallel_linears_all_reduce import RewriteRowAllReduceFunction
+from .coc_utils import get_parallel_num, get_output_shape, allocate_for_output
+from .coc_utils import set_context, allocate_for_output, reshape_to_2D, is_grad_needed
+
+
+class FusedCOCRowAllReduceFunction(RewriteRowAllReduceFunction):
+ @staticmethod
+ def forward(ctx, input_, weight, bias):
+ from mindspeed.ops.lcal_functional import coc_ops
+
+ set_context(ctx, input_, weight, bias)
+ parallel_num = get_parallel_num(m=reduce(lambda x, y: x * y, input_.shape[:-1]),
+ k=weight.shape[1],
+ n=weight.shape[0])
+ if parallel_num == 1:
+ return RewriteRowAllReduceFunction.forward(ctx, input_, weight, bias)
+ elif parallel_num in [2, 4, 8]:
+ return COCRowAllReduceFunction.forward(ctx, input_, weight, bias)
+
+ output_orig_shape = get_output_shape(input_, weight.t(), 1, is_gather=True)
+ input_ = reshape_to_2D(input_)
+ output_ = allocate_for_output(input_, weight.t(), 1, is_gather=True)
+
+ coc_ops.matmul_all_reduce(input_, weight, output_, bias)
+ output_ = output_.reshape(output_orig_shape)
+
+ return output_
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ total_input, weight = ctx.saved_tensors
+ is_grad_weight_needed, is_grad_bias_needed = is_grad_needed(ctx.needs_input_grad)
+ grad_input = grad_output.matmul(weight)
+ grad_weight, grad_bias = None, None
+ if is_grad_weight_needed:
+ grad_output = reshape_to_2D(grad_output)
+ grad_weight = grad_output.t().matmul(reshape_to_2D(total_input))
+ grad_bias = grad_output.sum(dim=0) if ctx.use_bias and is_grad_bias_needed else None
+
+ return grad_input, grad_weight, grad_bias
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_parallel_linears_sequence_parallel.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_parallel_linears_sequence_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..add98553ea9304aef66d4ed749a89ef9a3aa9b7e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_parallel_linears_sequence_parallel.py
@@ -0,0 +1,242 @@
+from functools import reduce
+import torch
+import torch_npu
+
+from megatron.training import get_args
+from mindspeed.ops.npu_matmul_add import npu_matmul_add_fp32, npu_matmul_add_fp16
+from .min_comm_cfg import min_comm_config
+from .matmul_soc_friendly import get_aligned_mm_inputs
+from .coc_utils import CommunicationType, COCParallel, get_output_shape
+from .coc_utils import shuffle_as_coc_reduce_scatter, shuffle_as_coc_all_gather
+from .coc_utils import set_context, reshape_to_2D, async_gather_along_first_dim, is_grad_needed, get_parallel_num
+from .rewrite_parallel_linears_sequence_parallel import RewriteColumnSeqParallelFunction, RewriteRowSeqParallelFunction
+
+ALIGN_SIZE = 512
+
+
+class COCColumnSeqParallelFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias):
+ ctx.save_for_backward(input_)
+ ctx.use_bias = bias is not None
+ ctx.weight = weight
+ trans_weight = weight.t()
+
+ parallel_num = get_parallel_num(m=reduce(lambda x, y: x * y, input_.shape[:-1]) * min_comm_config.tp_world_size,
+ k=trans_weight.shape[0],
+ n=trans_weight.shape[1])
+ if parallel_num == 1:
+ return RewriteColumnSeqParallelFunction.forward(ctx, input_, weight, bias)
+
+ output_orig_shape = get_output_shape(input_, trans_weight, min_comm_config.tp_world_size, is_gather=True)
+ gathered_input_shape = get_output_shape(input_, None, min_comm_config.tp_world_size, is_gather=True)
+ input_ = reshape_to_2D(input_)
+
+ if min_comm_config.matmul_soc_friendly_enabled:
+ input_, trans_weight = get_aligned_mm_inputs(input_, trans_weight, sp_coef=min_comm_config.tp_world_size,
+ parallel_num=parallel_num)
+
+ def compute_fcn(input_tensor, output_tensor):
+ torch.matmul(input_tensor, trans_weight, out=output_tensor)
+ return output_tensor
+
+ coc_parallel = COCParallel(input_, CommunicationType.ALL_GATHER, compute_fcn, compute_first=False,
+ weight_shape_list=list(trans_weight.shape), parallel_num=parallel_num)
+ output = coc_parallel.run()
+ output = shuffle_as_coc_reduce_scatter(output, min_comm_config.tp_world_size, parallel_num)
+ if not min_comm_config.all_gather_recomputation_enabled:
+ total_input = shuffle_as_coc_reduce_scatter(coc_parallel.comm_output, min_comm_config.tp_world_size,
+ parallel_num)
+ ctx.total_input = total_input.reshape(gathered_input_shape)
+ output = output.reshape(output_orig_shape)
+ if bias is not None:
+ output = output + bias
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_ = ctx.saved_tensors[0]
+ weight = ctx.weight
+ grad_input_orig_shape = get_output_shape(grad_output, weight, 1, is_gather=True)
+ grad_output = reshape_to_2D(grad_output)
+
+ is_grad_weight_needed, is_grad_bias_needed = is_grad_needed(ctx.needs_input_grad)
+ total_input_work, total_input = None, None
+
+ if is_grad_weight_needed:
+ if min_comm_config.all_gather_recomputation_enabled:
+ total_input_work, total_input = async_gather_along_first_dim(input_, min_comm_config.tp_group,
+ min_comm_config.tp_world_size)
+ else:
+ total_input = ctx.total_input
+
+ # if grad_output.shape[-1] is not 512B aligned, transpose its memory alignment but keep its shape
+ if grad_output.is_contiguous() and (grad_output.shape[-1] * grad_output.element_size()) % ALIGN_SIZE > 0:
+ grad_output = grad_output.t().contiguous().t()
+ grad_input = grad_output.matmul(weight)
+ grad_input = grad_input.reshape(grad_input_orig_shape)
+ sub_grad_input = torch.empty(list(input_.size()), dtype=input_.dtype, device=torch.cuda.current_device())
+ sub_grad_input_work = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
+ group=min_comm_config.tp_group, async_op=True)
+ grad_weight, grad_bias = None, None
+ if is_grad_weight_needed:
+ if min_comm_config.all_gather_recomputation_enabled:
+ total_input_work.wait()
+ total_input = reshape_to_2D(total_input)
+ if get_args().gradient_accumulation_fusion:
+ if weight.main_grad.dtype == torch.float32:
+ npu_matmul_add_fp32(
+ total_input, grad_output, weight.main_grad
+ )
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ npu_matmul_add_fp16(
+ total_input, grad_output, weight.main_grad
+ )
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=total_input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=total_input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ grad_weight = grad_output.t().matmul(total_input)
+ sub_grad_input_work.wait()
+ if is_grad_bias_needed and ctx.use_bias:
+ grad_bias = grad_output.sum(dim=0) if grad_output.is_contiguous() else grad_output.t().sum(dim=1)
+ else:
+ sub_grad_input_work.wait()
+ return sub_grad_input, grad_weight, grad_bias
+
+
+class COCRowSeqParallelFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias):
+ ctx.save_for_backward(input_)
+ ctx.use_bias = bias is not None
+ ctx.weight = weight
+ ctx.world_size = min_comm_config.tp_world_size
+ trans_weight = weight.t()
+
+ parallel_num = get_parallel_num(m=reduce(lambda x, y: x * y, input_.shape[:-1]),
+ k=trans_weight.shape[0],
+ n=trans_weight.shape[1])
+ if parallel_num == 1:
+ return RewriteRowSeqParallelFunction.forward(ctx, input_, weight, bias)
+
+ output_orig_shape = get_output_shape(input_, trans_weight, min_comm_config.tp_world_size, is_gather=False)
+ input_ = reshape_to_2D(input_)
+
+ if min_comm_config.matmul_soc_friendly_enabled:
+ input_, trans_weight = get_aligned_mm_inputs(input_, trans_weight, parallel_num=parallel_num)
+
+ def compute_fcn(input_tensor):
+ sub_output = torch.matmul(input_tensor, trans_weight)
+ return sub_output
+
+ input_ = shuffle_as_coc_all_gather(input_, ctx.world_size, parallel_num)
+ coc_reduce_scatter = COCParallel(input_, CommunicationType.REDUCE_SCATTER, compute_fcn, compute_first=True,
+ weight_shape_list=list(trans_weight.shape), parallel_num=parallel_num)
+ output_ = coc_reduce_scatter.run()
+ output_ = output_.reshape(output_orig_shape)
+ if bias is not None:
+ output_ = output_ + bias
+ return output_
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ total_input = ctx.saved_tensors[0]
+ weight = ctx.weight
+
+ parallel_num = get_parallel_num(
+ m=reduce(lambda x, y: x * y, grad_output.shape[:-1]) * min_comm_config.tp_world_size,
+ k=weight.shape[0],
+ n=weight.shape[1]
+ )
+ if parallel_num == 1:
+ return RewriteRowSeqParallelFunction.backward(ctx, grad_output)
+
+ grad_input_orig_shape = get_output_shape(grad_output, weight, min_comm_config.tp_world_size, is_gather=True)
+ grad_output = reshape_to_2D(grad_output)
+
+ if min_comm_config.matmul_soc_friendly_enabled:
+ grad_output, weight = get_aligned_mm_inputs(grad_output, weight, sp_coef=min_comm_config.tp_world_size,
+ parallel_num=parallel_num)
+
+ def compute_fcn(input_tensor, output_tensor):
+ torch.matmul(input_tensor, weight, out=output_tensor)
+ return output_tensor
+
+ is_grad_weight_needed, is_grad_bias_needed = is_grad_needed(ctx.needs_input_grad)
+
+ coc_all_gather = COCParallel(grad_output, CommunicationType.ALL_GATHER, compute_fcn, compute_first=False,
+ weight_shape_list=list(weight.shape), parallel_num=parallel_num)
+ grad_input = coc_all_gather.run()
+ grad_input = shuffle_as_coc_reduce_scatter(grad_input, ctx.world_size, parallel_num)
+
+ grad_input = grad_input.reshape(grad_input_orig_shape)
+
+ grad_weight, grad_bias = None, None
+
+ if is_grad_weight_needed:
+ grad_output = coc_all_gather.comm_output
+ grad_output = shuffle_as_coc_reduce_scatter(grad_output, ctx.world_size, parallel_num)
+ total_input = reshape_to_2D(total_input)
+ if get_args().gradient_accumulation_fusion:
+ if weight.main_grad.dtype == torch.float32:
+ npu_matmul_add_fp32(
+ total_input, grad_output, weight.main_grad
+ )
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ npu_matmul_add_fp16(
+ total_input, grad_output, weight.main_grad
+ )
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=total_input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=total_input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ grad_weight = grad_output.t().matmul(total_input)
+ if is_grad_bias_needed and ctx.use_bias:
+ grad_bias = grad_output.sum(dim=0) if grad_output.is_contiguous() else grad_output.t().sum(dim=1)
+
+ return grad_input, grad_weight, grad_bias
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_parallel_linears_sequence_parallel_fused.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_parallel_linears_sequence_parallel_fused.py
new file mode 100644
index 0000000000000000000000000000000000000000..6459ccded62f671ae340f6283726d5d983e14d2d
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_parallel_linears_sequence_parallel_fused.py
@@ -0,0 +1,153 @@
+from functools import reduce
+import torch
+import torch_npu
+
+from .min_comm_cfg import min_comm_config
+from .coc_utils import get_parallel_num, set_context, is_grad_needed, check_equal
+from .coc_utils import async_gather_along_first_dim, reshape_to_2D, allocate_for_output
+from .coc_parallel_linears_sequence_parallel import COCColumnSeqParallelFunction, COCRowSeqParallelFunction
+from .rewrite_parallel_linears_sequence_parallel import RewriteColumnSeqParallelFunction, RewriteRowSeqParallelFunction
+
+ALIGN_SIZE = 512
+
+
+class FusedCOCColumnSeqParallelFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias):
+ from mindspeed.ops.lcal_functional import coc_ops
+
+ set_context(ctx, input_, weight, bias)
+
+ parallel_num = get_parallel_num(reduce(lambda x, y: x * y, input_.shape[:-1]) * min_comm_config.tp_world_size,
+ weight.shape[1], weight.shape[0], default_parallel_num=-1)
+ if parallel_num == 1:
+ return RewriteColumnSeqParallelFunction.forward(ctx, input_, weight, bias)
+ elif parallel_num in [2, 4, 8]:
+ return COCColumnSeqParallelFunction.forward(ctx, input_, weight, bias)
+
+ output_shape = list(input_.shape)[:-1] + list([weight.shape[0]])
+ output_shape[0] = output_shape[0] * min_comm_config.tp_world_size
+ input_ = reshape_to_2D(input_)
+
+ output = allocate_for_output(input1=input_, input2=weight.t(),
+ tp_world_size=min_comm_config.tp_world_size, is_gather=True)
+
+ coc_ops.all_gather_matmul(input_, weight, output, bias)
+ output = output.reshape(output_shape)
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ from mindspeed.ops.lcal_functional import coc_ops
+
+ input_, weight = ctx.saved_tensors
+ check_equal(grad_output.shape[0] % min_comm_config.tp_world_size, 0,
+ error_info="m size must be multiple of world size")
+ sub_grad_input_shape = [grad_output.shape[0] // min_comm_config.tp_world_size] + \
+ list(grad_output.shape[1:-1]) + [weight.shape[-1]]
+ # manually make sure grad_output is 2D and its memory inner axis is 512B aligned
+ grad_output = reshape_to_2D(grad_output)
+ if grad_output.is_contiguous() and (grad_output.shape[-1] * grad_output.element_size()) % ALIGN_SIZE > 0:
+ grad_output = grad_output.t().contiguous().t()
+ sub_grad_input = allocate_for_output(input1=reshape_to_2D(input_))
+ is_grad_weight_needed, is_grad_bias_needed = is_grad_needed(ctx.needs_input_grad)
+ grad_weight, grad_bias = None, None
+
+ if is_grad_weight_needed:
+ if min_comm_config.all_gather_recomputation_enabled:
+ total_input_work, total_input = async_gather_along_first_dim(input_, min_comm_config.tp_group,
+ min_comm_config.tp_world_size)
+ else:
+ total_input = ctx.total_input
+ total_input = reshape_to_2D(total_input)
+
+ if min_comm_config.enable_coc_in_column_backward:
+ coc_ops.matmul_reduce_scatter(grad_output, weight, sub_grad_input, bias=None)
+ else:
+ grad_input = grad_output.matmul(weight)
+ sub_grad_input_work = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
+ group=min_comm_config.tp_group,
+ async_op=True)
+
+ if min_comm_config.all_gather_recomputation_enabled:
+ total_input_work.wait()
+
+ grad_weight = grad_output.t().matmul(total_input)
+ if is_grad_bias_needed and ctx.use_bias:
+ grad_bias = grad_output.sum(dim=0) if grad_output.is_contiguous() else grad_output.t().sum(dim=1)
+
+ if not min_comm_config.enable_coc_in_column_backward:
+ sub_grad_input_work.wait()
+
+ else:
+ grad_input = grad_output.matmul(weight)
+ torch.distributed._reduce_scatter_base(sub_grad_input, grad_input, group=min_comm_config.tp_group)
+
+ sub_grad_input = sub_grad_input.reshape(sub_grad_input_shape)
+ return sub_grad_input, grad_weight, grad_bias
+
+
+class FusedCOCRowSeqParallelFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias):
+ from mindspeed.ops.lcal_functional import coc_ops
+
+ set_context(ctx, input_, weight, bias)
+ ctx.world_size = min_comm_config.tp_world_size
+
+ parallel_num = get_parallel_num(reduce(lambda x, y: x * y, input_.shape[:-1]), weight.shape[1],
+ weight.shape[0], default_parallel_num=-1)
+ if parallel_num == 1:
+ return RewriteRowSeqParallelFunction.forward(ctx, input_, weight, bias)
+ elif parallel_num in [2, 4, 8]:
+ return COCRowSeqParallelFunction.forward(ctx, input_, weight, bias)
+
+ output_shape = list(input_.shape)[:-1] + list([weight.shape[0]])
+ output_shape[0] = output_shape[0] // min_comm_config.tp_world_size
+ input_ = reshape_to_2D(input_)
+
+ output = allocate_for_output(input_, weight.t(), min_comm_config.tp_world_size, is_gather=False)
+ coc_ops.matmul_reduce_scatter(input_, weight, output, bias)
+ output = output.reshape(output_shape)
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ from mindspeed.ops.lcal_functional import coc_ops
+
+ total_input, weight = ctx.saved_tensors
+
+ parallel_num = get_parallel_num(
+ reduce(lambda x, y: x * y, grad_output.shape[:-1]) * min_comm_config.tp_world_size,
+ weight.shape[0], weight.shape[1], default_parallel_num=-1
+ )
+ if parallel_num == 1:
+ return RewriteRowSeqParallelFunction.backward(ctx, grad_output)
+ elif parallel_num in [2, 4, 8]:
+ return COCRowSeqParallelFunction.backward(ctx, grad_output)
+
+ grad_input_shape = list(grad_output.shape)[:-1] + list([weight.shape[-1]])
+ grad_input_shape[0] = grad_input_shape[0] * min_comm_config.tp_world_size
+ grad_output = reshape_to_2D(grad_output)
+ total_input = reshape_to_2D(total_input)
+ grad_input = allocate_for_output(grad_output, weight, min_comm_config.tp_world_size, is_gather=True)
+
+ is_grad_weight_needed, is_grad_bias_needed = is_grad_needed(ctx.needs_input_grad)
+ grad_weight, grad_bias = None, None
+
+ if is_grad_weight_needed:
+ gathered_grad_output = allocate_for_output(grad_output, tp_world_size=min_comm_config.tp_world_size,
+ is_gather=True)
+ coc_ops.all_gather_matmul_v2(grad_output, weight, grad_input, gathered_grad_output, bias=None)
+
+ grad_weight = gathered_grad_output.t().matmul(total_input)
+ if is_grad_bias_needed and ctx.use_bias:
+ grad_bias = gathered_grad_output.sum(dim=0) if gathered_grad_output.is_contiguous() \
+ else gathered_grad_output.t().sum(dim=1)
+ else:
+ coc_ops.all_gather_matmul(grad_output, weight, grad_input, bias=None)
+
+ grad_input = grad_input.reshape(grad_input_shape)
+ return grad_input, grad_weight, grad_bias
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_utils.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d68f8ffe4d7275dac30cc5808b1c302470bdd09
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/coc_utils.py
@@ -0,0 +1,245 @@
+from enum import Enum
+import torch
+
+from .min_comm_cfg import min_comm_config
+
+
+def check_equal(a, b, error_info):
+ if a != b:
+ if torch.npu.current_device() == 0:
+ print(error_info)
+
+
+def print_tensor_value(name, value, device_id=0):
+ if min_comm_config.print_tensor_value_enabled and torch.npu.current_device() == device_id:
+ n = min_comm_config.parallel_num * min_comm_config.tp_world_size
+ per = value.shape[0] // n
+ slices = []
+ for k in range(n):
+ v = torch.flatten(value[k * per: (k + 1) * per])
+ slices.append(v[:5])
+ print(f"{name}, shape={value.shape}, value=\n{torch.cat(tuple(slices)).view(n, -1)}", flush=True)
+
+
+def set_context(ctx, input_, weight, bias):
+ ctx.save_for_backward(input_, weight)
+ ctx.use_bias = bias is not None
+
+
+def infer_matmul_out_shape(shape_a, shape_b):
+ shape_a[-1] = shape_b[-1]
+ return shape_a
+
+
+def reshape_to_2D(input_tensor):
+ # Convert the tensor shapes to 2D for execution compatibility
+ input_tensor = input_tensor.reshape(input_tensor.shape[0] * input_tensor.shape[1],
+ input_tensor.shape[2])
+ return input_tensor
+
+
+def async_gather_along_first_dim(input_, group, world_size):
+ dim_size = list(input_.size())
+ dim_size[0] = dim_size[0] * world_size
+ output_ = torch.empty(dim_size, dtype=input_.dtype, device=torch.npu.current_device(), requires_grad=False)
+ work = torch.distributed._all_gather_base(output_, input_.contiguous(), group=group, async_op=True)
+ return work, output_
+
+
+def shuffle_as_coc_reduce_scatter(input_, world_size, parallel_num):
+ per = input_.shape[0] // parallel_num // world_size
+ input_shape = list(input_.shape)
+ reshape_tensor = torch.reshape(input_, [parallel_num, world_size, per] + input_shape[1:])
+ return torch.reshape(reshape_tensor.transpose(0, 1), tuple(input_shape))
+
+
+def shuffle_as_coc_all_gather(input_, world_size, parallel_num):
+ per = input_.shape[0] // parallel_num // world_size
+ input_shape = list(input_.shape)
+ reshape_tensor = torch.reshape(input_, [world_size, parallel_num, per] + input_shape[1:])
+ return torch.reshape(reshape_tensor.transpose(0, 1), tuple(input_shape))
+
+
+def is_grad_needed(needs_input_grad):
+ is_grad_input_needed, is_grad_weight_needed, is_grad_bias_needed = needs_input_grad
+ if not is_grad_input_needed:
+ raise RuntimeError("To use COC, grad_input is necessary to compute. Check if optimizer update is turned off by \
+ mistake.")
+ if not is_grad_weight_needed and is_grad_bias_needed:
+ raise RuntimeError("To use COC, grad_weight must be needed if grad_bias is required.")
+ return is_grad_weight_needed, is_grad_bias_needed
+
+
+def get_parallel_num(m, k, n, default_parallel_num=min_comm_config.parallel_num):
+ parallel_num = default_parallel_num
+ shape_str = str([m, k, n])
+ if len(min_comm_config.customized_coc_dict) > 0 and str(shape_str) in min_comm_config.customized_coc_dict.keys():
+ parallel_num = min_comm_config.customized_coc_dict.get(shape_str)
+ if not min_comm_config.coc_fused_kernel and m < parallel_num:
+ return 1
+ if parallel_num not in [-1, 1, 2, 4, 8]:
+ raise RuntimeError("invalid parallel num, only support integer from 1, 2, 4 or 8.")
+ return parallel_num
+
+
+def get_output_shape(input1, input2=None, tp_world_size=1, is_gather=True):
+ check_equal(input1.dim() >= 2 and (input2 is None or input2.dim() == 2), True,
+ error_info="invalid matmul input shape for CoC")
+ output_shape = list(input1.shape)[:-1] + list([input2.shape[-1]]) if input2 is not None else list(input1.shape)
+ if not is_gather:
+ check_equal(output_shape[0] % tp_world_size == 0 and output_shape[0] >= tp_world_size, True,
+ error_info="invalid matmul m shape for CoC")
+ output_shape[0] = output_shape[0] * tp_world_size if is_gather else output_shape[0] // tp_world_size
+ return output_shape
+
+
+# input1 is required to be 2-dimensional here.
+def allocate_for_output(input1, input2=None, tp_world_size=1, is_gather=True):
+ if input2 is not None:
+ dim_size = list(input1.shape)[:-1] + list([input2.shape[1]])
+ else:
+ dim_size = list(input1.shape)
+ dim_size[0] = dim_size[0] * tp_world_size if is_gather else dim_size[0] // tp_world_size
+ output = torch.empty(dim_size, dtype=input1.dtype, device=torch.npu.current_device())
+ return output
+
+
+class CommunicationType(Enum):
+ ALL_GATHER = 0
+ ALL_REDUCE = 1
+ REDUCE_SCATTER = 2
+
+
+class COCParallel:
+ def __init__(self, input_data, comm_type, compute_fcn, compute_first=True, synchronize=True, weight_shape_list=None,
+ parallel_num=min_comm_config.parallel_num):
+ self.input_data = input_data
+ self.split_num = parallel_num
+ self.synchronize = synchronize
+ self.comm_type = comm_type
+ self.compute_fcn = compute_fcn
+ self.compute_first = compute_first
+ self.works = []
+ self.group = min_comm_config.tp_group
+ self.world_size = min_comm_config.tp_world_size
+ self.input_slice = input_data.shape[0] // self.split_num
+ self.init_output_space(input_data, weight_shape_list, compute_first)
+
+ def init_output_space(self, input_data, weight_shape_list, compute_first):
+ if weight_shape_list is None:
+ self.compute_output_shape_slice = list(input_data.shape)
+ else:
+ check_equal(input_data.shape[-1], weight_shape_list[0], error_info="In COCParallel, input_data should be of \
+ shape [m,k] and weight_shape_list should be [k,n]")
+ self.compute_output_shape_slice = infer_matmul_out_shape(list(input_data.shape), weight_shape_list)
+ self.output = self.allocate_output_memory()
+ self.output_slice = self.output.shape[0] // self.split_num
+ if compute_first:
+ self.comm_output = self.output
+ else:
+ self.comm_output = self.allocate_communicate_memory_for_communicate_first()
+ self.comm_slice = self.comm_output.shape[0] // self.split_num
+
+ def get_dim_size_after_comm(self, dim_size):
+ if self.comm_type == CommunicationType.ALL_GATHER:
+ dim_size[0] = dim_size[0] * self.world_size
+ elif self.comm_type == CommunicationType.REDUCE_SCATTER:
+ dim_size[0] = dim_size[0] // self.world_size
+ elif self.comm_type == CommunicationType.ALL_REDUCE:
+ pass
+ else:
+ raise ValueError("Invalid comm_type.")
+ return dim_size
+
+ def allocate_output_memory(self):
+ # No matter compute first or communicate first, the output shape remains the same
+ output_dim_size = self.get_dim_size_after_comm(self.compute_output_shape_slice)
+ output_ = torch.empty(output_dim_size, dtype=self.input_data.dtype,
+ device=torch.npu.current_device(), requires_grad=False)
+ return output_
+
+ def allocate_communicate_memory_for_communicate_first(self):
+ dim_size = list(self.input_data.shape)
+ dim_size = self.get_dim_size_after_comm(dim_size)
+ comm_output = torch.empty(dim_size, dtype=self.input_data.dtype,
+ device=torch.npu.current_device(), requires_grad=False)
+ return comm_output
+
+ def run_synchronize(self):
+ for work in self.works:
+ work.wait()
+ return self.comm_output
+
+ def run(self):
+ if self.compute_first:
+ return self.run_compute_first()
+ else:
+ return self.run_communicate_first()
+
+ def comm_fcn(self, i, input_):
+ if self.comm_type == CommunicationType.ALL_GATHER:
+ output_ = self.comm_output[i * self.comm_slice: (i + 1) * self.comm_slice]
+ work = torch.distributed._all_gather_base(output_, input_.contiguous(), group=self.group, async_op=True)
+ elif self.comm_type == CommunicationType.REDUCE_SCATTER:
+ output_ = self.comm_output[i * self.comm_slice: (i + 1) * self.comm_slice]
+ work = torch.distributed._reduce_scatter_base(output_, input_.contiguous(), group=self.group, async_op=True)
+ elif self.comm_type == CommunicationType.ALL_REDUCE:
+ # all_reduce interface currently only supports overwriting the same address of input
+ output_ = input_
+ work = torch.distributed.all_reduce(output_, group=self.group, async_op=True)
+ else:
+ raise ValueError("Invalid comm_type.")
+ return work, output_
+
+ def get_input_slice(self, i):
+ return self.input_data[i * self.input_slice: (i + 1) * self.input_slice]
+
+ def run_compute_first(self):
+ compute_outputs = []
+ for i in range(self.split_num):
+ input_slice = self.get_input_slice(i)
+ if self.comm_type == CommunicationType.ALL_REDUCE:
+ compute_output = self.output[i * self.comm_slice: (i + 1) * self.comm_slice]
+ self.compute_fcn(input_tensor=input_slice, output_tensor=compute_output)
+ else:
+ compute_output = self.compute_fcn(input_slice)
+ compute_outputs.append(compute_output)
+ work, _ = self.comm_fcn(i, compute_output)
+ self.works.append(work)
+
+ if self.synchronize:
+ return self.run_synchronize()
+ else:
+ return self.output, self.works
+
+ def get_output_slice(self, i):
+ return self.output[i * self.output_slice: (i + 1) * self.output_slice]
+
+ def run_communicate_first(self):
+ check_equal(self.synchronize, True, error_info="In COCParallel, must synchronize before return if communicate \
+ first")
+ pre_work = None
+ pre_output = None
+ outputs = []
+
+ for i in range(self.split_num):
+ input_slice = self.get_input_slice(i)
+ if self.comm_type == CommunicationType.ALL_REDUCE:
+ input_ = torch.empty_like(input_slice).copy_(input_slice)
+ else:
+ input_ = input_slice
+ work, output_i = self.comm_fcn(i, input_)
+ outputs.append(output_i)
+
+ self.works.append(work)
+
+ if pre_output is not None:
+ pre_work.wait()
+ self.compute_fcn(input_tensor=pre_output, output_tensor=self.get_output_slice(i - 1))
+
+ pre_work = work
+ pre_output = output_i
+
+ pre_work.wait()
+ self.compute_fcn(input_tensor=pre_output, output_tensor=self.get_output_slice(self.split_num - 1))
+ return self.output
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/matmul_soc_friendly.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/matmul_soc_friendly.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8ca65c8e6d4be3ac656ccd4cef0e6494faf86cb
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/matmul_soc_friendly.py
@@ -0,0 +1,136 @@
+from functools import reduce
+import torch
+from torch.nn import functional as F
+
+from .coc_utils import check_equal
+from .min_comm_cfg import min_comm_config
+
+
+def extract_info_from_mm_tensors(left, right):
+ m = reduce(lambda x, y: x * y, left.shape[:-1])
+ k = left.shape[-1]
+ check_equal(right.shape[0], k, error_info="For matmul_soc_friendly in CoC, the two input tensors left and right \
+ should be of shape [.., k] and [k, n] respectively")
+ n = reduce(lambda x, y: x * y, right.shape[1:])
+ return m, k, n
+
+
+def is_transposed(input_):
+ if input_.dim() < 2 or input_.dim() > 3:
+ raise RuntimeError("input tensor of is_tensor_transposed should be either 2- or 3-dimensional")
+ dim1 = input_.dim() - 1
+ dim2 = input_.dim() - 2
+ if input_.stride()[dim2] == 1 and input_.stride()[dim1] == reduce(lambda x, y: x * y, input_.shape[:-1]):
+ return True
+ else:
+ return False
+
+
+def ceil_div(a, b):
+ if b == 0:
+ raise ZeroDivisionError
+ return (a + b - 1) // b
+
+
+def ceil_coc(a, b):
+ if b == 0:
+ raise ZeroDivisionError
+ return ((a + b - 1) // b) * b
+
+
+# 512B aligned shape is soc friendly
+kPackage512 = 512
+kPackage32 = 32
+
+
+def compute_pad_num(single_dim_size, element_size, kPackage=kPackage512):
+ least_size = ceil_coc(single_dim_size, ceil_div(kPackage, element_size))
+ pad_num = least_size - single_dim_size
+ return pad_num
+
+
+# pad_dim could be in the form of 3 / 2 / 1 or -1 / -2 / -3
+def pad_tensor(input_, pad_num, pad_dim):
+ dim_size = input_.dim()
+ pad_list = [0] * (dim_size * 2)
+ pad_list[pad_dim * (-2) - 1] += pad_num
+ input_ = F.pad(input_, tuple(pad_list), mode='constant', value=0) if pad_num > 0 else input_
+ return input_
+
+
+def process_with_k_aligned(left, right, mn_aligned, is_left_transposed, is_right_transposed):
+ if is_left_transposed:
+ left = left.contiguous()
+ if not mn_aligned and not is_right_transposed:
+ main_grad = right.main_grad
+ right = right.t().contiguous().t()
+ right.main_grad = main_grad
+ return left, right
+
+
+def process_left_with_padding_k(left, is_left_transposed, k_pad_num):
+ if is_left_transposed:
+ left = pad_tensor(left.permute(2, 0, 1), k_pad_num, 0)
+ left = left.permute(1, 2, 0).contiguous()
+ else:
+ left = pad_tensor(left, k_pad_num, 2)
+ return left
+
+
+def process_right_with_padding_k(right, is_right_transposed, k_pad_num):
+ if is_right_transposed:
+ right = pad_tensor(right.t(), k_pad_num, 1)
+ right = right.t()
+ else:
+ right = pad_tensor(right, k_pad_num, 0)
+ return right
+
+
+def process_with_padding_k(left, right, is_left_transposed, is_right_transposed, k_pad_num):
+ left = process_left_with_padding_k(left, is_left_transposed, k_pad_num)
+ right = process_right_with_padding_k(right, is_right_transposed, k_pad_num)
+ return left, right
+
+
+def get_aligned_mm_inputs(left, right, sp_coef=1, parallel_num=min_comm_config.parallel_num):
+ """Get properly aligned tensors for matmul, according to soc friendly properties.
+
+ Inputs
+ left: the left tensor of matmul, in the shape of [m,k].
+ right: the right tensor of matmul, in the shape of [k,n].
+ sp_coef: the coefficient for compensating m due to any expected collective communications before the matmul.
+ parallel_num: the number of parts to divide the left tensor in, by row.
+
+ Outputs:
+ left: the properly processed left tensor for matmul, in the shape of [m,k].
+ right: the properly processed right tensor for matmul, in the shape of [k,n].
+
+ """
+
+ # The dtype of left and right tensors for matmul should be the same
+ check_equal(left.element_size(), right.element_size(), error_info="In matmul_soc_friendly of CoC, the dtype of \
+ left and right tensors for matmul should be the same")
+ element_size = left.element_size()
+
+ m, k, n = extract_info_from_mm_tensors(left, right)
+
+ # check if the shape of left or right matches its memory alignment
+ is_left_transposed = is_transposed(left)
+ is_right_transposed = is_transposed(right)
+
+ # After communication (if applicable) and dividing left tensor, check if m-dim and n-dim are both 512B aligned
+ is_mn_aligned_512b = ((m * sp_coef // parallel_num) * element_size) % kPackage512 == 0 and (
+ n * element_size) % kPackage512 == 0
+ # Check if k-dim is 512B aligned
+ is_k_aligned_512b = (k * element_size) % kPackage512 == 0
+ # Check if k-dim is 32B aligned
+ is_k_aligned_32b = (k * element_size) % kPackage32 == 0
+ # Compute the required amount of padding for k-dim, if already aligned then gives 0
+ k_pad_num = compute_pad_num(k, element_size, kPackage=kPackage512)
+
+ if is_k_aligned_512b:
+ return process_with_k_aligned(left, right, is_mn_aligned_512b, is_left_transposed, is_right_transposed)
+ elif is_mn_aligned_512b and not is_k_aligned_32b and min_comm_config.k_min <= k <= min_comm_config.k_max:
+ return process_with_padding_k(left, right, is_left_transposed, is_right_transposed, k_pad_num)
+
+ return left, right
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/min_comm_cfg.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/min_comm_cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a042fd7cd49acc27e7de408a2bd5afb3b21ac97
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/min_comm_cfg.py
@@ -0,0 +1,224 @@
+import ast
+import os
+from enum import Enum
+import torch
+import torch_npu
+import torch.nn.functional as F
+from megatron.training import get_args
+
+
+def column_forward(self, input_, weight, column_parallel_function=None, check_fcn=None):
+ if check_fcn is not None:
+ check_fcn()
+ bias = self.bias if not self.skip_bias_add else None
+ input_parallel = input_
+ use_weight = self.weight if weight is None else weight
+ if hasattr(self, "norm") and self.norm:
+ use_weight = F.normalize(self.weight)
+ output_parallel = column_parallel_function.apply(
+ input_parallel,
+ use_weight,
+ bias
+ )
+ output = output_parallel
+ output_bias = self.bias if self.skip_bias_add else None
+ return output, output_bias
+
+
+def row_forward(self, input_, row_parallel_function=None, check_fcn=None):
+ if check_fcn is not None:
+ check_fcn()
+ input_parallel = input_
+ output_parallel = row_parallel_function.apply(
+ input_parallel,
+ self.weight,
+ None
+ )
+ output = output_parallel
+ if not self.skip_bias_add:
+ output = output + self.bias if self.bias is not None else output
+ output_bias = None
+ else:
+ output_bias = self.bias
+ return output, output_bias
+
+
+class ModuleType(Enum):
+ ORIGINAL_ALL_REDUCE = 0
+ ORIGINAL_SEQ_PARALLEL = 1
+ REWRITE_ALL_REDUCE = 2
+ REWRITE_SEQ_PARALLEL = 3
+ COC_FOR_ALL_REDUCE = 4
+ COC_FOR_SEQ_PARALLEL = 5
+
+
+class MinCommConfig:
+ def __init__(self):
+ # basic settings acquired from environmental variables
+ # default module_type is ModuleType.ORIGINAL_SEQ_PARALLEL
+ global_args = get_args()
+
+ self.module_type: ModuleType = ModuleType.ORIGINAL_SEQ_PARALLEL
+ self.coc_mode = global_args.coc_mode
+ self.parallel_num = global_args.coc_parallel_num
+ self.coc_fused_kernel = global_args.coc_fused_kernel
+
+ # configurations registered from framework
+ self.ColumnParallelLinear = None
+ self.RowParallelLinear = None
+ self.column_parallel_forward = None
+ self.row_parallel_forward = None
+ self.tp_group_fcn = None
+ self.tp_world_size_fcn = None
+ self.tp_rank_fcn = None
+ self.all_reduce = None
+ self.reduce_scatter_along_first_dim = None
+ self.gather_along_first_dim = None
+ self.prefix = None
+ self.check_fcn = None
+ self.tp_enabled = True
+ self.sequence_parallel_enabled = True
+
+ # configurations manually set by users in user_config.py
+ self.k_min = 1024
+ self.k_max = 4096
+ self.all_gather_recomputation_enabled = False
+ self.print_tensor_value_enabled = False
+ self.matmul_soc_friendly_enabled = True
+ self.customized_coc_dict = {}
+ self.enable_coc_in_column_backward = True
+
+ def print_settings(self):
+ if self.coc_fused_kernel:
+ enable_coc_in_column_backward = True if self.enable_coc_in_column_backward else False
+ else:
+ enable_coc_in_column_backward = False
+ if self.coc_fused_kernel:
+ settings_dict = {
+ "is coc turned on": True,
+ "use script or use fused kernel": "fused kernel",
+ "is sequence parallel enabled": self.sequence_parallel_enabled,
+ "is coc enabled in column backward": enable_coc_in_column_backward
+ }
+ elif "ORIGINAL" in self.module_type.name:
+ settings_dict = {
+ "is coc turned on": False
+ }
+ else:
+ settings_dict = {
+ "is coc turned on": True,
+ "use script or use fused kernel": "script",
+ "coc mode": self.coc_mode,
+ "parallel num": self.parallel_num,
+ "module type": self.module_type.name,
+ "is sequence parallel enabled": self.sequence_parallel_enabled,
+ "if get aligned mm inputs": self.matmul_soc_friendly_enabled
+ }
+ if torch.npu.current_device() == 0:
+ print("\n-----------------------------COC Settings: ------------------------------------")
+ for key, value in settings_dict.items():
+ print(f"{key}: {value}")
+ print("-------------------------------------------------------------------------------\n")
+
+ @property
+ def tp_rank(self):
+ return self.tp_rank_fcn()
+
+ @property
+ def tp_group(self):
+ return self.tp_group_fcn()
+
+ @property
+ def tp_world_size(self):
+ return self.tp_world_size_fcn()
+
+ def register_tp_get_functions(self, tp_group_fcn, tp_world_size_fcn, tp_rank_fcn):
+ self.tp_group_fcn = tp_group_fcn
+ self.tp_world_size_fcn = tp_world_size_fcn
+ self.tp_rank_fcn = tp_rank_fcn
+
+ def register_class(self, column_parallel_linear, row_parallel_linear):
+ self.ColumnParallelLinear = column_parallel_linear
+ self.RowParallelLinear = row_parallel_linear
+
+ def register_mappings(self, _all_reduce, _reduce_scatter_along_first_dim, _gather_along_first_dim):
+ self.all_reduce = _all_reduce
+ self.reduce_scatter_along_first_dim = _reduce_scatter_along_first_dim
+ self.gather_along_first_dim = _gather_along_first_dim
+
+ def replace_forward_functions_by_autograd_class(self, column_autograd_class, row_autograd_class):
+ def column_parallel_forward(x, input_, weight=None):
+ return column_forward(x, input_, weight, column_parallel_function=column_autograd_class,
+ check_fcn=self.check_fcn)
+
+ def row_parallel_forward(x, y):
+ return row_forward(x, y, row_parallel_function=row_autograd_class, check_fcn=self.check_fcn)
+
+ self.column_parallel_forward = column_parallel_forward
+ self.row_parallel_forward = row_parallel_forward
+ self.ColumnParallelLinear.forward = self.column_parallel_forward
+ self.RowParallelLinear.forward = self.row_parallel_forward
+
+ def register_sequence_parallel_switch(self, sequence_parallel_enabled):
+ self.sequence_parallel_enabled = sequence_parallel_enabled
+
+ def register_check_fcn(self, check_fcn):
+ self.check_fcn = check_fcn
+
+ def register_customized_coc(self, customized_coc):
+ if len(customized_coc) == 0:
+ return
+ for coc_shape_yaml_str in customized_coc.keys():
+ key_list = ast.literal_eval(coc_shape_yaml_str)
+ coc_shape_key_str = str(key_list)
+ self.customized_coc_dict.update({coc_shape_key_str: customized_coc[coc_shape_yaml_str]})
+ print("self.customized_coc_dict: ", self.customized_coc_dict)
+
+ def register_matmul_soc_friendly_setting(self, matmul_soc_friendly, k_min, k_max):
+ self.matmul_soc_friendly_enabled = matmul_soc_friendly
+ self.k_min = k_min
+ self.k_max = k_max
+
+ def register_all_gather_recomputation_switch(self, all_gather_recomputation_enabled):
+ self.all_gather_recomputation_enabled = all_gather_recomputation_enabled
+
+ def register_print_tensor_value_switch(self, print_tensor_value_enabled):
+ self.print_tensor_value_enabled = print_tensor_value_enabled
+
+ def register_column_backward_coc_switch(self, enable_coc_in_column_backward):
+ self.enable_coc_in_column_backward = enable_coc_in_column_backward
+
+ def acquire_module_type(self, tp_size):
+ sequence_parallel_types = [ModuleType.ORIGINAL_SEQ_PARALLEL,
+ ModuleType.REWRITE_SEQ_PARALLEL,
+ ModuleType.COC_FOR_SEQ_PARALLEL]
+ all_reduce_types = [ModuleType.ORIGINAL_ALL_REDUCE,
+ ModuleType.REWRITE_ALL_REDUCE,
+ ModuleType.COC_FOR_ALL_REDUCE]
+
+ if self.parallel_num not in [1, 2, 4, 8]:
+ raise RuntimeError("coc_parallel_num must be either 1, 2, 4 or 8. Current value not supported")
+ if self.coc_mode not in [-1, 0, 1, 2]:
+ raise RuntimeError("coc_mode must be either 0, 1, or 2. Current value not supported")
+
+ if self.coc_mode == -1:
+ self.coc_mode = 0 if self.parallel_num == 1 else 2
+
+ if tp_size == 1:
+ self.coc_mode = 0
+ self.parallel_num = 1
+
+ if self.sequence_parallel_enabled:
+ self.module_type = sequence_parallel_types[self.coc_mode]
+ else:
+ self.module_type = all_reduce_types[self.coc_mode]
+
+ if "COC" in self.module_type.name:
+ self.prefix = f"module_{self.module_type.name}_parallel_num_{self.parallel_num}"
+ else:
+ self.prefix = f"module_{self.module_type.name}"
+
+ self.print_settings()
+
+
+min_comm_config = MinCommConfig()
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/rewrite_parallel_linears_all_reduce.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/rewrite_parallel_linears_all_reduce.py
new file mode 100644
index 0000000000000000000000000000000000000000..a31273cf4a094eb5c47b3e2a78b829916fcc4879
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/rewrite_parallel_linears_all_reduce.py
@@ -0,0 +1,59 @@
+import torch
+
+from .min_comm_cfg import min_comm_config
+from .coc_utils import set_context, reshape_to_2D, is_grad_needed
+
+
+class RewriteColumnAllReduceFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias):
+ set_context(ctx, input_, weight, bias)
+ output_parallel = torch.matmul(input_, weight.t())
+ if bias is not None:
+ output_parallel = output_parallel + bias
+ return output_parallel
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_, weight = ctx.saved_tensors
+ is_grad_weight_needed, is_grad_bias_needed = is_grad_needed(ctx.needs_input_grad)
+
+ grad_input = grad_output.matmul(weight)
+ handle = torch.distributed.all_reduce(grad_input, group=min_comm_config.tp_group, async_op=True)
+ grad_weight, grad_bias = None, None
+
+ if is_grad_weight_needed:
+ grad_output = reshape_to_2D(grad_output)
+ grad_weight = grad_output.t().matmul(reshape_to_2D(input_))
+ handle.wait()
+ grad_bias = grad_output.sum(dim=0) if is_grad_bias_needed and ctx.use_bias else None
+ else:
+ handle.wait()
+
+ return grad_input, grad_weight, grad_bias
+
+
+class RewriteRowAllReduceFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias):
+ set_context(ctx, input_, weight, bias)
+ output_ = torch.matmul(input_, weight.t())
+ output_parallel = min_comm_config.all_reduce(output_)
+ if bias is not None:
+ output_parallel = output_parallel + bias
+ return output_parallel
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ total_input, weight = ctx.saved_tensors
+ is_grad_weight_needed, is_grad_bias_needed = is_grad_needed(ctx.needs_input_grad)
+
+ grad_input = grad_output.matmul(weight)
+ grad_weight, grad_bias = None, None
+
+ if is_grad_weight_needed:
+ grad_output = reshape_to_2D(grad_output)
+ grad_weight = grad_output.t().matmul(reshape_to_2D(total_input))
+ grad_bias = grad_output.sum(dim=0) if is_grad_bias_needed and ctx.use_bias else None
+
+ return grad_input, grad_weight, grad_bias
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/rewrite_parallel_linears_sequence_parallel.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/rewrite_parallel_linears_sequence_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..8685be775a4b4738f5157c63bdf91a2c5d379266
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/rewrite_parallel_linears_sequence_parallel.py
@@ -0,0 +1,82 @@
+import torch
+
+from .min_comm_cfg import min_comm_config
+from .coc_utils import set_context, async_gather_along_first_dim, reshape_to_2D, is_grad_needed
+
+
+class RewriteColumnSeqParallelFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias):
+ set_context(ctx, input_, weight, bias)
+ dim_size = list(input_.size())
+ dim_size[0] = dim_size[0] * min_comm_config.tp_world_size
+
+ all_gather_buffer = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
+ torch.distributed._all_gather_base(all_gather_buffer, input_, group=min_comm_config.tp_group)
+ total_input = all_gather_buffer
+
+ output_parallel = torch.matmul(total_input, weight.t())
+ if bias is not None:
+ output_parallel = output_parallel + bias
+ return output_parallel
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_, weight = ctx.saved_tensors
+ is_grad_weight_needed, is_grad_bias_needed = is_grad_needed(ctx.needs_input_grad)
+ tp_group = min_comm_config.tp_group
+ if is_grad_weight_needed:
+ handle_all_gather, total_input = async_gather_along_first_dim(input_, tp_group,
+ min_comm_config.tp_world_size)
+ grad_input = grad_output.matmul(weight)
+ handle_all_gather.wait()
+ dim_size = list(input_.size())
+ sub_grad_input = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device(),
+ requires_grad=False)
+ # reduce_scatter
+ handle_reduce_scatter = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input, group=tp_group,
+ async_op=True)
+ grad_output = reshape_to_2D(grad_output)
+ grad_weight = grad_output.t().matmul(reshape_to_2D(total_input))
+ handle_reduce_scatter.wait()
+ grad_bias = grad_output.sum(dim=0) if is_grad_bias_needed and ctx.use_bias else None
+ else:
+ grad_input = grad_output.matmul(weight)
+ dim_size = list(input_.size())
+ sub_grad_input = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device(),
+ requires_grad=False)
+ # reduce_scatter
+ handle_reduce_scatter = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input, group=tp_group,
+ async_op=True)
+ handle_reduce_scatter.wait()
+ grad_weight, grad_bias = None, None
+ return sub_grad_input, grad_weight, grad_bias
+
+
+class RewriteRowSeqParallelFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias):
+ set_context(ctx, input_, weight, bias)
+ # ctx.world_size is needed for the case: rewrite forward (manually skipped) with coc backward
+ ctx.world_size = min_comm_config.tp_world_size
+ output_ = torch.matmul(input_, weight.t())
+ output_parallel = min_comm_config.reduce_scatter_along_first_dim(output_)
+ if bias is not None:
+ output_parallel = output_parallel + bias
+ return output_parallel
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ total_input, weight = ctx.saved_tensors
+ grad_output = min_comm_config.gather_along_first_dim(grad_output)
+ is_grad_weight_needed, is_grad_bias_needed = is_grad_needed(ctx.needs_input_grad)
+
+ grad_input = grad_output.matmul(weight)
+ grad_weight, grad_bias = None, None
+
+ if is_grad_weight_needed:
+ grad_output = reshape_to_2D(grad_output)
+ grad_weight = grad_output.t().matmul(reshape_to_2D(total_input))
+ grad_bias = grad_output.sum(dim=0) if is_grad_bias_needed and ctx.use_bias else None
+
+ return grad_input, grad_weight, grad_bias
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/user_config.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/user_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d613e90e44711a9991b24ed9c2cacc28416add14
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/lcal_coc/user_config.py
@@ -0,0 +1,110 @@
+import torch
+import torch_npu
+
+from .min_comm_cfg import min_comm_config, ModuleType
+from .coc_parallel_linears_all_reduce_fused import FusedCOCRowAllReduceFunction
+from .coc_parallel_linears_all_reduce import COCColumnAllReduceFunction, COCRowAllReduceFunction
+from .coc_parallel_linears_sequence_parallel import COCColumnSeqParallelFunction, COCRowSeqParallelFunction
+from .rewrite_parallel_linears_all_reduce import RewriteColumnAllReduceFunction, RewriteRowAllReduceFunction
+from .rewrite_parallel_linears_sequence_parallel import RewriteColumnSeqParallelFunction, RewriteRowSeqParallelFunction
+from .coc_parallel_linears_sequence_parallel_fused import FusedCOCColumnSeqParallelFunction, FusedCOCRowSeqParallelFunction
+
+
+coc_cfgs = {
+ 'recompute_all_gather': True,
+ 'matmul_soc_friendly': True,
+ 'print_tensor_value_open': False,
+ 'customized_coc': {},
+ 'enable_coc_in_column_backward': False,
+ 'k_min': 1024,
+ 'k_max': 4096,
+}
+
+
+def check_config_valid():
+ if min_comm_config.sequence_parallel_enabled:
+ if min_comm_config.module_type not in [ModuleType.ORIGINAL_SEQ_PARALLEL,
+ ModuleType.REWRITE_SEQ_PARALLEL,
+ ModuleType.COC_FOR_SEQ_PARALLEL]:
+ raise ValueError("In CoC, the config of sequence parallel is not valid")
+ else:
+ if min_comm_config.module_type not in [ModuleType.ORIGINAL_ALL_REDUCE,
+ ModuleType.REWRITE_ALL_REDUCE,
+ ModuleType.COC_FOR_ALL_REDUCE]:
+ raise ValueError("In CoC, the config of sequence parallel is not valid")
+
+
+def get_value_from_cfg(attr_name):
+ if attr_name not in coc_cfgs.keys():
+ raise RuntimeError("Lack attr_name: ", attr_name)
+ return coc_cfgs[attr_name]
+
+
+def print_on_device0(msg):
+ if torch.npu.current_device() == 0:
+ print(msg)
+
+
+def initialize_coc_from_cfg(cfg):
+ from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
+ from megatron.core.parallel_state import (
+ get_tensor_model_parallel_group,
+ get_tensor_model_parallel_world_size,
+ get_tensor_model_parallel_rank
+ )
+ from megatron.core.tensor_parallel.mappings import (
+ _reduce,
+ _reduce_scatter_along_first_dim,
+ _gather_along_first_dim
+ )
+ min_comm_config.register_tp_get_functions(get_tensor_model_parallel_group,
+ get_tensor_model_parallel_world_size,
+ get_tensor_model_parallel_rank)
+ min_comm_config.register_class(ColumnParallelLinear,
+ RowParallelLinear)
+ min_comm_config.register_mappings(_reduce,
+ _reduce_scatter_along_first_dim,
+ _gather_along_first_dim)
+ min_comm_config.register_sequence_parallel_switch(cfg.sequence_parallel)
+
+ min_comm_config.register_customized_coc(get_value_from_cfg('customized_coc'))
+ min_comm_config.register_matmul_soc_friendly_setting(get_value_from_cfg('matmul_soc_friendly'),
+ int(get_value_from_cfg('k_min')),
+ int(get_value_from_cfg('k_max')))
+ min_comm_config.register_all_gather_recomputation_switch(get_value_from_cfg('recompute_all_gather'))
+ min_comm_config.register_print_tensor_value_switch(get_value_from_cfg('print_tensor_value_open'))
+ min_comm_config.register_column_backward_coc_switch(get_value_from_cfg('enable_coc_in_column_backward'))
+ min_comm_config.register_check_fcn(check_config_valid)
+ min_comm_config.acquire_module_type(cfg.tensor_model_parallel_size)
+
+ map_type2autograd_class = {
+ ModuleType.REWRITE_SEQ_PARALLEL: [RewriteColumnSeqParallelFunction,
+ RewriteRowSeqParallelFunction],
+ ModuleType.REWRITE_ALL_REDUCE: [RewriteColumnAllReduceFunction,
+ RewriteRowAllReduceFunction],
+ ModuleType.COC_FOR_SEQ_PARALLEL: [COCColumnSeqParallelFunction,
+ COCRowSeqParallelFunction],
+ ModuleType.COC_FOR_ALL_REDUCE: [COCColumnAllReduceFunction,
+ COCRowAllReduceFunction]
+ }
+
+ if min_comm_config.coc_fused_kernel:
+ print_on_device0("COC REPLACE WITH COC FUSED KERNEL SCRIPT!")
+ if min_comm_config.sequence_parallel_enabled:
+ min_comm_config.replace_forward_functions_by_autograd_class(FusedCOCColumnSeqParallelFunction,
+ FusedCOCRowSeqParallelFunction)
+ else:
+ min_comm_config.replace_forward_functions_by_autograd_class(COCColumnAllReduceFunction,
+ FusedCOCRowAllReduceFunction)
+ elif "ORIGINAL" not in min_comm_config.module_type.name:
+ if "REWRITE" in min_comm_config.module_type.name:
+ print_on_device0("COC REPLACE WITH REWRITE SCRIPT!")
+ else:
+ print_on_device0("COC REPLACE WITH COC SCRIPT!")
+ parallel_linear_autograd_class = map_type2autograd_class.get(min_comm_config.module_type)
+ if parallel_linear_autograd_class is None:
+ raise RuntimeError("Module type is not matched.")
+ min_comm_config.replace_forward_functions_by_autograd_class(parallel_linear_autograd_class[0],
+ parallel_linear_autograd_class[1])
+ else:
+ print_on_device0("COC REPLACE NONE!")
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/mapping.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eef0d9be53bb51b066162445c5dbe7b7f44d7d3
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/mapping.py
@@ -0,0 +1,36 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from megatron.core.tensor_parallel.mappings import _reduce
+
+
+class _ReduceFromModelParallelRegion_Nd(torch.autograd.Function):
+ @staticmethod
+ def symbolic(graph, input_):
+ return _reduce(input_)
+
+ @staticmethod
+ def forward(ctx, input_):
+ return _reduce(input_)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _reduce(grad_output)
+
+
+def reduce_from_tensor_model_parallel_region_nd(input_):
+ return _ReduceFromModelParallelRegion_Nd.apply(input_)
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/random.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/random.py
new file mode 100644
index 0000000000000000000000000000000000000000..efaa148292c40bc2d88a729f68f112a6d236e4e7
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/random.py
@@ -0,0 +1,300 @@
+import os
+from functools import wraps
+from typing import List, Union
+import torch
+from torch import _C
+from torch_npu.npu import _lazy_call, device as device_ctx_manager
+from torch.utils.checkpoint import _get_autocast_kwargs
+from megatron.training import get_args
+from megatron.core.tensor_parallel.utils import gather_split_1d_tensor
+from megatron.core.tensor_parallel.random import get_cuda_rng_tracker
+from megatron.core.utils import safely_set_viewless_tensor_data
+from torch.utils.checkpoint import detach_variable
+from megatron.core.parallel_state import (
+ get_tensor_model_parallel_group,
+ get_tensor_model_parallel_world_size,
+ is_pipeline_last_stage,
+ get_virtual_pipeline_model_parallel_rank,
+)
+from mindspeed.core.tensor_parallel.checkpoint_manager import get_pipeline_checkpoint_manager
+
+
+def _set_cuda_rng_state(new_state, device=-1):
+ if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
+ # older PyTorch
+ def cb():
+ with device_ctx_manager(device):
+ _C._cuda_setRNGState(new_state)
+
+ else:
+ # newer PyTorch
+ if device == -1:
+ device = torch.device('cuda')
+ elif isinstance(device, str):
+ device = torch.device(device)
+ elif isinstance(device, int):
+ device = torch.device('cuda', device)
+
+ def cb():
+ idx = device.index
+ if idx is None:
+ idx = torch.cuda.current_device()
+ default_generator = torch.npu.default_generators[idx]
+ default_generator.set_state(new_state)
+
+ _lazy_call(cb)
+
+
+def checkpoint_function_backward(ctx, *args):
+ global_args = get_args()
+ if not torch.autograd._is_checkpoint_valid():
+ raise RuntimeError(
+ "Checkpointing is not compatible with .grad(), "
+ "please use .backward() if possible"
+ )
+ inputs = ctx.saved_tensors
+ if ctx.distribute_saved_activations:
+ safely_set_viewless_tensor_data(
+ inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape)
+ )
+
+ # Store the current states.
+ bwd_cpu_rng_state = torch.get_rng_state()
+ bwd_cuda_rng_state = torch.cuda.get_rng_state()
+ bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
+
+ # Set the states to what it used to be before the forward pass.
+ torch.set_rng_state(ctx.fwd_cpu_rng_state)
+ _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
+ get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
+
+ # Compute the forward pass.
+ flops_counter = None
+ if global_args.op_cal_tflops:
+ from mindspeed.core.training import get_flops_counter
+ flops_counter = get_flops_counter()
+ flops_counter.pause()
+
+ detached_inputs = detach_variable(inputs)
+ from mindspeed.auto_tuning.module.parse.recompute_parser import get_recompute_parser, call_hook_func
+ recompute_parser = get_recompute_parser()
+
+ if (
+ recompute_parser.skip_profiling_step <= recompute_parser.profiling_step <= recompute_parser.stop_profiling_step
+ and os.getenv('OOTB_OPTIMIZER_PROFILING', 'FALSE') == 'TRUE'):
+ call_hook_func()
+ with torch.enable_grad():
+ outputs = ctx.run_function(*detached_inputs)
+ # remove hook
+ for hook_handle in recompute_parser.modules_hooks:
+ hook_handle.remove()
+ recompute_parser.modules_hooks.clear()
+
+ if global_args.op_cal_tflops:
+ flops_counter.resume()
+
+ # Set the states back to what it was at the start of this function.
+ torch.set_rng_state(bwd_cpu_rng_state)
+ _set_cuda_rng_state(bwd_cuda_rng_state)
+ get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
+
+ if isinstance(outputs, torch.Tensor):
+ outputs = (outputs,)
+
+ # filter out non tensor outputs for backward pass
+ outputs, args = zip(*filter(lambda x: torch.is_tensor(x[0]) and x[0].grad_fn is not None, zip(outputs, args)))
+ torch.autograd.backward(outputs, args)
+ grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
+ return (None, None) + grads
+
+
+class CheckpointFunctionWithoutOutput(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, checkpoint, *args):
+ with torch.no_grad():
+ outputs = run_function(*args)
+
+ # Store everything
+ ctx.save_for_backward(*detach_variable(args))
+ checkpoint.ctx = ctx
+
+ return outputs
+
+ @staticmethod
+ def backward(ctx, *args):
+ inputs = ctx.saved_tensors
+ outputs = ctx.outputs
+ torch.autograd.backward(outputs, args)
+ ctx.outputs = None
+ grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs)
+ return (None, None) + grads
+
+
+class CheckpointWithoutOutput:
+ def __init__(self):
+ self.run_function = None
+ self.fwd_cpu_rng_state = None
+ self.fwd_cuda_rng_state = None
+ self.fwd_cuda_rng_state_tracker = None
+ self.outputs = None
+
+ def checkpoint(self, run_function, distribute_saved_activations, *args):
+ self.run_function = run_function
+
+ if distribute_saved_activations:
+ raise RuntimeError(
+ "CheckpointFunctionWithoutOutput does not support "
+ "distribute_saved_activations"
+ )
+
+ #Copy the rng states.
+ self.fwd_cpu_rng_state = torch.get_rng_state()
+ self.fwd_cuda_rng_state = torch.cuda.get_rng_state()
+ self.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
+
+ outputs = CheckpointFunctionWithoutOutput.apply(run_function, self, *args)
+ self.outputs = outputs
+ if isinstance(self.outputs, torch.Tensor):
+ self.outputs = (self.outputs,)
+
+ return outputs
+
+ def discard_output(self):
+ for output in self.outputs:
+ output.untyped_storage().resize_(0)
+
+ def recompute(self, _):
+ if not torch.autograd._is_checkpoint_valid():
+ raise RuntimeError(
+ "Checkpointing is not compatible with .grad(), "
+ "please use .backward() if possible"
+ )
+
+ # Store the current states.
+ cur_cpu_rng_state = torch.get_rng_state()
+ cur_cuda_rng_state = torch.cuda.get_rng_state()
+ cur_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
+
+ # Set the states to what it used to be before the forward pass.
+ torch.set_rng_state(self.fwd_cpu_rng_state)
+ _set_cuda_rng_state(self.fwd_cuda_rng_state)
+ get_cuda_rng_tracker().set_states(self.fwd_cuda_rng_state_tracker)
+
+ with torch.enable_grad():
+ outputs = self.run_function(*self.ctx.saved_tensors)
+ self.run_function = None
+ self.fwd_cpu_rng_state = None
+ self.fwd_cuda_rng_state = None
+ self.fwd_cuda_rng_state_tracker = None
+
+ # Set the states back to what it was at the start of this function.
+ torch.set_rng_state(cur_cpu_rng_state)
+ _set_cuda_rng_state(cur_cuda_rng_state)
+ get_cuda_rng_tracker().set_states(cur_cuda_rng_state_tracker)
+
+ if isinstance(outputs, torch.Tensor):
+ outputs = (outputs,)
+
+ for output, recomputation_output in zip(self.outputs, outputs):
+ output_size = recomputation_output.untyped_storage().size()
+ output.untyped_storage().resize_(output_size)
+ with torch.no_grad():
+ output.untyped_storage().copy_(recomputation_output.untyped_storage())
+
+ self.ctx.outputs = outputs
+ self.outputs = None
+ self.ctx = None
+
+
+
+class RngStateContext:
+ def __init__(self, cpu_rng_state, cuda_rng_state, cuda_rng_state_tracker):
+ self.fwd_cpu_rng_state = cpu_rng_state
+ self.fwd_cuda_rng_state = cuda_rng_state
+ self.fwd_cuda_rng_state_tracker = cuda_rng_state_tracker
+
+
+class CheckpointFunctionRipipe(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, distribute_saved_activations, *args):
+ fwd_rng_state = RngStateContext(torch.get_rng_state(), torch.cuda.get_rng_state(), get_cuda_rng_tracker().get_states())
+ with torch.no_grad():
+ outputs = run_function(*args)
+
+ # Store everything.
+ ctx.detached_inputs = detach_variable(args)
+
+ def recompute():
+ # Store the current states.
+ bwd_cpu_rng_state = torch.get_rng_state()
+ bwd_cuda_rng_state = torch.cuda.get_rng_state()
+ bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
+
+ # Set the states to what it used to be before the forward pass.
+ torch.set_rng_state(fwd_rng_state.fwd_cpu_rng_state)
+ _set_cuda_rng_state(fwd_rng_state.fwd_cuda_rng_state)
+ get_cuda_rng_tracker().set_states(fwd_rng_state.fwd_cuda_rng_state_tracker)
+
+ with torch.enable_grad():
+ outputs = run_function(*ctx.detached_inputs)
+ ctx.outputs = outputs
+
+ # Set the states back to what it was at the start of this function.
+ torch.set_rng_state(bwd_cpu_rng_state)
+ _set_cuda_rng_state(bwd_cuda_rng_state)
+ get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
+ if get_pipeline_checkpoint_manager().do_pre_recompute:
+ get_pipeline_checkpoint_manager().add_recompute(recompute)
+ ctx.recompute_func = recompute
+
+ return outputs
+
+ @staticmethod
+ def backward(ctx, *args):
+ if not torch.autograd._is_checkpoint_valid():
+ raise RuntimeError(
+ "Checkpointing is not compatible with .grad(), "
+ "please use .backward() if possible"
+ )
+ if not hasattr(ctx, 'outputs'):
+ if get_pipeline_checkpoint_manager().do_pre_recompute:
+ global_args = get_args()
+ vpp_rank = get_virtual_pipeline_model_parallel_rank()
+ # For last vpp chunk of last pp stage, we don't advance its recomputation.
+ if global_args.recompute_in_advance and is_pipeline_last_stage():
+ get_pipeline_checkpoint_manager().recompute_next(vpp_rank)
+ if not hasattr(ctx, 'outputs'):
+ raise RuntimeError(f"rank-{torch.distributed.get_rank()}: recompute is not done")
+ else:
+ ctx.recompute_func()
+
+ outputs = ctx.outputs
+ detached_inputs = ctx.detached_inputs
+ ctx.outputs = None
+ ctx.detached_inputs = None
+ ctx.recompute_func = None
+
+ if isinstance(outputs, torch.Tensor):
+ outputs = (outputs,)
+
+ # filter out non tensor outputs for backward pass
+ outputs, args = zip(*filter(lambda x: torch.is_tensor(x[0]), zip(outputs, args)))
+ torch.autograd.backward(outputs, args)
+ grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
+ return (None, None) + grads
+
+
+def checkpoint_wrapper(checkpoint):
+ @wraps(checkpoint)
+ def wrapper(function, distribute_saved_activations, *args):
+ if not get_pipeline_checkpoint_manager().open_ri_pipe:
+ return checkpoint(function, distribute_saved_activations, *args)
+ if not get_pipeline_checkpoint_manager().chunk_do_recompute:
+ return function(*args)
+
+ if distribute_saved_activations:
+ raise RuntimeError("no distributed")
+
+ return CheckpointFunctionRipipe.apply(function, distribute_saved_activations, *args)
+
+ return wrapper
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/layernorm_2d.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/layernorm_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c2d1ecf1c73c7ab738e71d66b56c8bc932a25ef
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/layernorm_2d.py
@@ -0,0 +1,179 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from typing import Any
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from torch.cuda.amp import custom_bwd
+from torch.cuda.amp import custom_fwd
+from torch.nn import Parameter
+
+from megatron.core.utils import divide
+from mindspeed.core.tensor_parallel.comm_group_api import CollectiveCommIntf
+from mindspeed.core.tensor_parallel.comm_group_api import TPYCollectiveComm
+
+
+class LayerNorm2D(torch.nn.Module):
+ """LayerNorm2D layer with row and column parallelism.
+
+ Arguments:
+ hidden_size (int): input normalized size from an expected input of size
+ eps: a value added to the denominator for numerical stability. Default: 1e-5
+ bias: (bool, optional): Whether to add a bias, defaults to ``True``.
+ dtype: (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+ last_dim_split_comm_intf: Reduce scatter comm intf.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ eps: float = 1e-5,
+ bias: bool = True,
+ dtype=None,
+ last_dim_split_comm_intf: CollectiveCommIntf = TPYCollectiveComm(),
+ ) -> None:
+ super(LayerNorm2D, self).__init__()
+ # layer norm config
+ self.hidden_size = hidden_size
+ self.epsilon = eps
+
+ # parallel setting
+ self.last_dim_split_comm_intf = last_dim_split_comm_intf
+ self.rs_comm_world_sz = self.last_dim_split_comm_intf.get_comm_group_world_size()
+ # partitioning dimension
+ self.partitioned_dim = divide(hidden_size, self.rs_comm_world_sz)
+ # create parameters
+ factory_kwargs = {"device": torch.cuda.current_device(), "dtype": dtype}
+
+ # [H/(xy)]
+ self.weight = Parameter(torch.ones(self.partitioned_dim, **factory_kwargs))
+ if bias:
+ # [H/(xy)]
+ self.bias = Parameter(torch.zeros(self.partitioned_dim, **factory_kwargs))
+ else:
+ self.bias = None
+
+ # set sequence parallelism flag on weight and bias parameters
+ setattr(self.weight, "2d_tp", True)
+ setattr(self.bias, "2d_tp", True)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return _ParallelLayerNorm2D.apply(
+ x,
+ self.weight,
+ self.bias,
+ self.epsilon,
+ self.hidden_size,
+ self.last_dim_split_comm_intf,
+ )
+
+
+class _ParallelLayerNorm2D(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ input_: Tensor,
+ weight,
+ bias,
+ epsilon,
+ hidden_size: int,
+ last_dim_split_comm_intf: CollectiveCommIntf
+ ) -> Tensor:
+ """
+
+ :param ctx:
+ :param input_: [s/(cp*x), b, H/y]
+ :param weight: [H/(xy)]
+ :param bias: [H/(xy)]
+ :param epsilon:
+ :param hidden_size: H
+ :param last_dim_split_comm_intf:
+ :return:
+ """
+ # [s/(cp*x), b, H/y]---> [s/(cp*x), b, 1]
+ e_x = torch.sum(input_, dim=-1, keepdim=True)
+ # [s/(cp*x), b, 1]
+ handle_ex = torch.distributed.all_reduce(
+ e_x, group=last_dim_split_comm_intf.get_comm_group(), async_op=True
+ )
+
+ # [s/(cp*x), b, H/y]---> [s/(cp*x), b, 1]
+ var_x = torch.sum(input_.float().pow(2), dim=-1, keepdim=True)
+ if handle_ex:
+ handle_ex.wait()
+
+ handle_var = torch.distributed.all_reduce(
+ var_x, group=last_dim_split_comm_intf.get_comm_group(), async_op=True
+ )
+
+ input_.sub_(e_x.div_(hidden_size))
+ e_x.mul_(e_x)
+ if handle_var:
+ handle_var.wait()
+
+ var_x = torch.rsqrt(var_x.div_(hidden_size).sub_(e_x).add_(epsilon))
+
+ ctx.hidden_size = hidden_size
+ ctx.last_dim_split_comm_intf = last_dim_split_comm_intf
+ # [s/(cp*x), b, H/y] * [s/(cp*x), b, 1] --> [s/(cp*x), b, H/y]
+ norm_x = torch.mul(input_, var_x)
+
+ if bias is not None:
+ # bias + weight * norm, [H/y] + [H/y] * [s/(cp*x), b, H/y]
+ output = torch.addcmul(bias, weight, norm_x)
+ else:
+ output = torch.mul(weight, norm_x)
+
+ ctx.save_for_backward(norm_x, var_x, bias, weight)
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
+ x, var_x, bias, weight = ctx.saved_tensors
+ # calculate grad_bias
+ if bias is None:
+ grad_bias = None
+ else:
+ grad_bias = output_grad.sum(dim=(0, 1))
+
+ # calculate grad_input
+ grad_norm_x = torch.mul(output_grad, weight)
+ output_grad_sum = torch.sum(grad_norm_x, dim=-1, keepdim=True)
+ handle_grad_sum = torch.distributed.all_reduce(
+ output_grad_sum, group=ctx.last_dim_split_comm_intf.get_comm_group(), async_op=True
+ )
+ output_grad_mul_x_sum = torch.sum(grad_norm_x * x, dim=-1, keepdim=True)
+
+ # calculate grad_weight
+ grad_weight = torch.mul(output_grad, x)
+ grad_weight = grad_weight.sum(dim=(0, 1))
+
+ if handle_grad_sum:
+ handle_grad_sum.wait()
+
+ handle_grad_mul_x = torch.distributed.all_reduce(
+ output_grad_mul_x_sum, group=ctx.last_dim_split_comm_intf.get_comm_group(), async_op=True
+ )
+ output_grad_sum.div_(ctx.hidden_size)
+ grad_input = grad_norm_x.sub(output_grad_sum)
+ if handle_grad_mul_x:
+ handle_grad_mul_x.wait()
+
+ grad_input = (grad_input - x * (output_grad_mul_x_sum / ctx.hidden_size)) * var_x
+ return grad_input, grad_weight, grad_bias, None, None, None
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/linear_2d_moe_split_along_first_dim.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/linear_2d_moe_split_along_first_dim.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ce89918fd71d185e7ec3554a15f702c69196fe1
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/linear_2d_moe_split_along_first_dim.py
@@ -0,0 +1,298 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+
+import torch
+import torch_npu
+from torch.cuda.amp import custom_bwd
+from torch.cuda.amp import custom_fwd
+
+from mindspeed.core.tensor_parallel.comm_autograd_function import CollectiveCommIntf
+from mindspeed.core.tensor_parallel.comm_group_api import OverlapCollectiveIntf, TPYCollectiveComm
+from mindspeed.core.tensor_parallel.comm_utils import async_gather_tensors, sync_gather_along_last_dim, \
+ _split_along_last_dim
+from mindspeed.core.tensor_parallel.comm_utils import async_reduce_scatter_along_first_dim
+from mindspeed.core.tensor_parallel.comm_utils import sync_gather_along_first_dim
+from mindspeed.core.tensor_parallel.comm_utils import sync_reduce_scatter_along_first_dim
+
+G_FORWARD_PADDING_SIZE = 0
+G_BACKWARD_PADDING_SIZE = 0
+
+
+class MoELinear2DFC1(torch.autograd.Function):
+ """2D Linear out axe communication implementation."""
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ activation_input,
+ weight,
+ bias,
+ ag_comm_intf: CollectiveCommIntf,
+ ag_overlap_comm_intf: OverlapCollectiveIntf,
+ rs_comm_intf: CollectiveCommIntf,
+ rs_overlap_comm_intf: OverlapCollectiveIntf,
+ enable_overlap_ag_with_matmul=False,
+ enable_overlap_matmul_with_rs=False,
+ gradient_accumulation_fusion=False,
+ enable_backward_overlap_ag_with_matmul=False,
+ partition_dim=0,
+ ):
+ """
+ :param ctx: context to save some tensors or vars for backward use.
+ :param activation_input: with shape: [s/(x*cp), b, h/y]
+ :param weight: with shape: [h/y, E/x], E means the output size.
+ :param bias: bias parameter tensor.
+ :param ag_comm_intf: AllGather communication process group interface.
+ :param ag_overlap_comm_intf: AllGather communication overlap send and recv comm group
+ :param rs_comm_intf: ReduceScatter communication process group interface.
+ :param rs_overlap_comm_intf: ReduceScatter communication overlap send and recv comm group
+ :param enable_overlap_ag_with_matmul: enable overlap all-gather with matmul in forward
+ :param enable_overlap_matmul_with_rs: enable overlap matmul with reduce-scatter in forward
+ :param gradient_accumulation_fusion: enable gradient accumulation fusion
+ :param enable_backward_overlap_ag_with_matmul: enable overlap all-gather with matmul
+ :return: forward result tensor.
+ """
+ ctx.weight = weight
+ ctx.use_bias = bias is not None
+ ctx.rs_comm_intf = rs_comm_intf
+ ctx.ag_comm_intf = ag_comm_intf
+ ctx.ag_overlap_comm_intf = ag_overlap_comm_intf
+ ctx.rs_overlap_comm_intf = rs_overlap_comm_intf
+ ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+ ctx.enable_backward_overlap_ag_with_matmul = enable_backward_overlap_ag_with_matmul
+
+ activation_input = activation_input.contiguous()
+ # [n, h] -> [n, h/y]
+ activation_input = _split_along_last_dim(activation_input, TPYCollectiveComm)
+ ctx.save_for_backward(activation_input)
+ # [N, h/y] @ [h/y, E/x] -> [N, E/x]
+ matmul_res = torch.matmul(activation_input, weight.npu().t())
+ matmul_res = matmul_res.contiguous()
+ n_tokens, h = matmul_res.shape
+ rs_size = rs_comm_intf.get_comm_group_world_size()
+ global G_FORWARD_PADDING_SIZE
+ remaining = n_tokens - n_tokens // rs_size * rs_size
+ G_FORWARD_PADDING_SIZE = rs_size - remaining if remaining else 0
+ if G_FORWARD_PADDING_SIZE != 0:
+ padding_tensor = torch.zeros(G_FORWARD_PADDING_SIZE, h, dtype=matmul_res.dtype,
+ device=matmul_res.device)
+ matmul_res = torch.cat((matmul_res, padding_tensor), dim=0)
+ matmul_res = matmul_res.contiguous()
+ # [N1, E/x] -> [N1/y, E/x]
+ matmul_res = sync_reduce_scatter_along_first_dim(matmul_res, rs_comm_intf)
+ return matmul_res
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ # activation_input shape: [n, h]
+ # weight shape: [h/y, E/x]
+ activation_input, = ctx.saved_tensors
+ weight = ctx.weight
+ use_bias = ctx.use_bias
+ # [N1/y, E/x]---AG(y)---> [N1, E/x]
+ grad_output = grad_output.contiguous()
+ global G_BACKWARD_PADDING_SIZE
+ total_grad_output = sync_gather_along_first_dim(grad_output, ctx.rs_comm_intf)
+ if G_BACKWARD_PADDING_SIZE != 0:
+ real_input_num = total_grad_output.shape[0] - G_BACKWARD_PADDING_SIZE
+ # [N1, E/x] --> [N, E/x]
+ total_grad_output = total_grad_output[:real_input_num, :]
+
+ # prepare total activation_input for computing grad weight.
+ # [N, h/y]
+ total_activation_input = activation_input.contiguous()
+
+ # [N, E/x] @ [E/x, H/y]--> [N, H/y] (partial x)
+ partial_grad_input = total_grad_output.matmul(weight).contiguous()
+ grad_input = partial_grad_input
+ if ctx.gradient_accumulation_fusion:
+ import fused_weight_gradient_mlp_cuda
+ total_grad_output = total_grad_output.contiguous()
+ if weight.main_grad.dtype == torch.float32:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+ total_activation_input, total_grad_output, weight.main_grad
+ )
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+ total_activation_input, total_grad_output, weight.main_grad
+ )
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=activation_input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=activation_input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ # [E/x, N] @ [N, h/y] ---> [E/x, h/y]
+ grad_weight = total_grad_output.t().matmul(total_activation_input)
+ grad_bias = total_grad_output.sum(dim=0) if use_bias else None
+ grad_input = sync_gather_along_last_dim(grad_input, ctx.rs_comm_intf)
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None
+
+
+class MoELinear2DFC2(torch.autograd.Function):
+ """2D Linear out axe communication implementation."""
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ activation_input,
+ weight,
+ bias,
+ ag_comm_intf: CollectiveCommIntf,
+ ag_overlap_comm_intf: OverlapCollectiveIntf,
+ rs_comm_intf: CollectiveCommIntf,
+ rs_overlap_comm_intf: OverlapCollectiveIntf,
+ enable_overlap_ag_with_matmul=False,
+ enable_overlap_matmul_with_rs=False,
+ gradient_accumulation_fusion=False,
+ enable_backward_overlap_ag_with_matmul=False,
+ partition_dim=0,
+ ):
+ """
+ :param ctx: context to save some tensors or vars for backward use.
+ :param activation_input: with shape: [s/(x*cp), b, h/y]
+ :param weight: with shape: [h/y, E/x], E means the output size.
+ :param bias: bias parameter tensor.
+ :param ag_comm_intf: AllGather communication process group interface.
+ :param ag_overlap_comm_intf: AllGather communication overlap send and recv comm group
+ :param rs_comm_intf: ReduceScatter communication process group interface.
+ :param rs_overlap_comm_intf: ReduceScatter communication overlap send and recv comm group
+ :param enable_overlap_ag_with_matmul: enable overlap all-gather with matmul in forward
+ :param enable_overlap_matmul_with_rs: enable overlap matmul with reduce-scatter in forward
+ :param gradient_accumulation_fusion: enable gradient accumulation fusion
+ :param enable_backward_overlap_ag_with_matmul: enable overlap all-gather with matmul
+ :return: forward result tensor.
+ """
+ ctx.save_for_backward(activation_input)
+ ctx.weight = weight
+ ctx.use_bias = bias is not None
+ ctx.rs_comm_intf = rs_comm_intf
+ ctx.ag_comm_intf = ag_comm_intf
+ ctx.ag_overlap_comm_intf = ag_overlap_comm_intf
+ ctx.rs_overlap_comm_intf = rs_overlap_comm_intf
+ ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+ ctx.enable_backward_overlap_ag_with_matmul = enable_backward_overlap_ag_with_matmul
+ activation_input = activation_input.contiguous()
+ # [N1/y, E/x] -> ag(y) -> [N1, E/x]
+ total_input = sync_gather_along_first_dim(activation_input, ag_comm_intf)
+ if G_FORWARD_PADDING_SIZE != 0:
+ real_input_num = total_input.shape[0] - G_FORWARD_PADDING_SIZE
+ # [N1, E/x] -> [N, E/x]
+ total_input = total_input[:real_input_num, :]
+ # [N, E/x] @ [E/x, h/y] -> [N, h/y] (partial x)
+ matmul_res = torch.matmul(total_input, weight.npu().t())
+ matmul_res = sync_gather_along_last_dim(matmul_res, TPYCollectiveComm)
+ return matmul_res
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ # activation_input shape: [N1/y, E/x]
+ # weight shape: [h/y, E/x]
+ activation_input, = ctx.saved_tensors
+ weight = ctx.weight
+ use_bias = ctx.use_bias
+ # [N, h] -> [N, h/y]
+ grad_output = grad_output.contiguous()
+ grad_output = _split_along_last_dim(grad_output, ctx.ag_comm_intf)
+
+ global G_BACKWARD_PADDING_SIZE
+ # [N1/y, E/x]---AG(y)--->[N1, E/x]
+ activation_input = activation_input.contiguous()
+ gather_input_handle, gathered_tensors = async_gather_tensors(
+ local_rank_input=activation_input, ag_comm_intf=ctx.ag_comm_intf
+ )
+ # [N, h/y] @ [E/x, H/y]--> [N, E/x] (partial y)
+ partial_grad_input = grad_output.matmul(weight).contiguous()
+ sb, h = partial_grad_input.shape
+ rs_size = ctx.ag_comm_intf.get_comm_group_world_size()
+
+ remaining = sb - sb // rs_size * rs_size
+ G_BACKWARD_PADDING_SIZE = rs_size - remaining if remaining else 0
+
+ if G_BACKWARD_PADDING_SIZE != 0:
+ padding_tensor = torch.zeros(G_BACKWARD_PADDING_SIZE, h, dtype=partial_grad_input.dtype,
+ device=partial_grad_input.device)
+ # [N, E/x] --> [N1, E/x]
+ partial_grad_input = torch.cat((partial_grad_input, padding_tensor), dim=0)
+ partial_grad_input = partial_grad_input.contiguous()
+ # [N1, E/x] --> [N1/y, E/x]
+ rs_grad_input_handle, grad_input = async_reduce_scatter_along_first_dim(
+ partial_grad_input, comm_intf=ctx.ag_comm_intf
+ )
+
+ if gather_input_handle:
+ gather_input_handle.wait()
+ # [N1, E/x]
+ total_activation_input = gathered_tensors.contiguous()
+ if G_BACKWARD_PADDING_SIZE != 0:
+ real_input_num = total_activation_input.shape[0] - G_BACKWARD_PADDING_SIZE
+ # [N1, E/x] -> [N, E/x]
+ total_activation_input = total_activation_input[:real_input_num, :]
+ total_activation_input = total_activation_input.contiguous()
+ if ctx.gradient_accumulation_fusion:
+ import fused_weight_gradient_mlp_cuda
+ grad_output = grad_output.contiguous()
+ if weight.main_grad.dtype == torch.float32:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+ total_activation_input, grad_output, weight.main_grad
+ )
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+ total_activation_input, grad_output, weight.main_grad
+ )
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=activation_input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=activation_input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ # [N, E/x] @ [E/x, h/y] --> [N, h/y] (partial x)
+ grad_weight = grad_output.t().matmul(total_activation_input)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+ if rs_grad_input_handle:
+ rs_grad_input_handle.wait()
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/linear_2d_split_along_first_dim.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/linear_2d_split_along_first_dim.py
new file mode 100644
index 0000000000000000000000000000000000000000..f482c107668db012f87d23dfb688d455ae805ea8
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/linear_2d_split_along_first_dim.py
@@ -0,0 +1,499 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import torch
+import torch_npu
+from torch import distributed as torch_dist
+from torch.cuda.amp import custom_bwd
+from torch.cuda.amp import custom_fwd
+from megatron.training import get_args
+
+from mindspeed.core.tensor_parallel.comm_autograd_function import CollectiveCommIntf
+from mindspeed.core.tensor_parallel.comm_group_api import OverlapCollectiveIntf
+from mindspeed.core.tensor_parallel.comm_utils import async_gather_tensors
+from mindspeed.core.tensor_parallel.comm_utils import async_reduce_scatter_along_first_dim
+from mindspeed.core.tensor_parallel.comm_utils import sync_gather_along_first_dim
+from mindspeed.core.tensor_parallel.comm_utils import sync_reduce_scatter_along_first_dim
+
+
+
+def get_comm_domain_rank(devid, ag_size, rs_size, para_type=0): # 在RS domain做agv2
+ if para_type == 0: # TFTF
+ if ag_size == 2: # RS=8, [0 1 2 ... 7], [8 9 10 ... 15]
+ return str(10 + devid // rs_size), devid % rs_size
+ else: # RS=2, [0, 8], [1, 9] ... [7, 15]
+ return str(20 + devid % ag_size), devid // ag_size
+ else: # FTFT
+ if ag_size == 2: # RS=8, [0 2 4 ... 14], [1 3 5 ... 15]
+ return str(10 + devid % ag_size), devid // ag_size
+ else: # RS=2, [0 1], [2 3], [4 5]...
+ return str(20 + devid // rs_size), devid % rs_size
+
+
+class Linear2DSplitAlongFirstDim(torch.autograd.Function):
+ """2D Linear out axe communication implementation."""
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ activation_input,
+ weight,
+ bias,
+ ag_comm_intf: CollectiveCommIntf,
+ ag_overlap_comm_intf: OverlapCollectiveIntf,
+ rs_comm_intf: CollectiveCommIntf,
+ rs_overlap_comm_intf: OverlapCollectiveIntf,
+ enable_overlap_ag_with_matmul=False,
+ enable_overlap_matmul_with_rs=False,
+ gradient_accumulation_fusion=False,
+ enable_backward_overlap_ag_with_matmul=False,
+ partition_dim=0,
+ ):
+ """
+ :param ctx: context to save some tensors or vars for backward use.
+ :param activation_input: with shape: [s/(x*cp), b, h/y]
+ :param weight: with shape: [h/y, E/x], E means the output size.
+ :param bias: bias parameter tensor.
+ :param ag_comm_intf: AllGather communication process group interface.
+ :param ag_overlap_comm_intf: AllGather communication overlap send and recv comm group
+ :param rs_comm_intf: ReduceScatter communication process group interface.
+ :param rs_overlap_comm_intf: ReduceScatter communication overlap send and recv comm group
+ :param enable_overlap_ag_with_matmul: enable overlap all-gather with matmul in forward
+ :param enable_overlap_matmul_with_rs: enable overlap matmul with reduce-scatter in forward
+ :param gradient_accumulation_fusion: enable gradient accumulation fusion
+ :param enable_backward_overlap_ag_with_matmul: enable overlap all-gather with matmul
+ :return: forward result tensor.
+ """
+ ctx.save_for_backward(activation_input)
+ ctx.weight = weight
+ ctx.use_bias = bias is not None
+ ctx.rs_comm_intf = rs_comm_intf
+ ctx.ag_comm_intf = ag_comm_intf
+ ctx.ag_overlap_comm_intf = ag_overlap_comm_intf
+ ctx.rs_overlap_comm_intf = rs_overlap_comm_intf
+ ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+ ctx.enable_backward_overlap_ag_with_matmul = enable_backward_overlap_ag_with_matmul
+
+ if enable_overlap_matmul_with_rs:
+ activation_input = activation_input.contiguous()
+ return Linear2DSplitAlongFirstDim._do_mm_overlap_reducescatter(
+ activation_input, weight.t(), bias, ag_comm_intf, rs_comm_intf
+ )
+
+ # first_linear forward: [s/cp, b, H/y] @ [H/y, e/x] -> [s/cp, b, e/x]
+ if enable_overlap_ag_with_matmul:
+ matmul_res, _ = Linear2DSplitAlongFirstDim._do_allgather_left_tensor_and_matmul_overlap(
+ ag_comm_intf,
+ ag_overlap_comm_intf,
+ part_left_tensor=activation_input,
+ full_right_tensor=weight.t(),
+ )
+
+ if bias is not None:
+ matmul_res += bias
+ elif get_args().coc_fused_kernel:
+ from mindspeed.ops.lcal_functional import coc_ops, TP2DConfig
+ inner_dim_is_ag = True
+ if partition_dim == 0:
+ inner_dim_is_ag = True
+ else:
+ inner_dim_is_ag = False
+ # [s/(x*cp), b, H/y] -> [s/cp, b, H/y] -> [s/(cp*y), b, H/x]
+ s, b, h = activation_input.shape
+ # Convert the tensor shapes to 2D for execution compatibility
+ activation_input = activation_input.view(
+ s * b, h
+ )
+ res_shape_0 = s * ag_comm_intf.get_comm_group_world_size() // rs_comm_intf.get_comm_group_world_size()
+ res_shape_1 = weight.shape[0]
+ matmul_res = torch.empty(res_shape_0, res_shape_1, dtype=activation_input.dtype, device=torch.cuda.current_device())
+ coc_ops.all_gather_matmul_reduce_scatter(activation_input, weight, matmul_res,
+ TP2DConfig(
+ ag_comm_intf.get_comm_group_world_size(),
+ rs_comm_intf.get_comm_group_world_size(),
+ inner_dim_is_ag),
+ bias=bias)
+ return matmul_res.view(-1, b, res_shape_1)
+ else:
+ # [s/(x*cp), b, H/y] -> [s/cp, b, H/y]
+ activation_input = activation_input.contiguous()
+ total_input = sync_gather_along_first_dim(activation_input, ag_comm_intf, buffer_name="mpu-sync-tp-2d")
+ # [s/cp, b, H/y] @ [H/y, e/x] -> [s/cp, b, e/x]
+ matmul_res = torch.matmul(total_input, weight.t())
+ # [s/cp, b, E/x] -> [s/(y*cp), b, E/x]
+ matmul_res = matmul_res.contiguous()
+ matmul_res = sync_reduce_scatter_along_first_dim(matmul_res, rs_comm_intf)
+ return matmul_res
+
+
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ """Backward implementation of Linear2DSplitAlongFirstDim, the computation and communication
+ overlap:
+
+ ----------------------------------------------------------------------------->time
+ | AG(grad_o, Y|X)
+ | AG(activation_input, X|Y)
+ | part_grad_act = MM(tot_grad_o, weight)
+ | RS(part_grad_act, X|Y)
+ | MM(tot_grad_o^T, tot_act_input)
+
+
+ :param ctx: context
+ :param grad_output: with shape: [s/cp, b, E/(xy)]
+ :return:grads of all the input para of forward function as a tuple
+ """
+ # activation_input shape: [s/(x*cp), b, h/y]
+ # weight shape: [h/y, E/x]
+ activation_input, = ctx.saved_tensors
+ weight = ctx.weight
+ use_bias = ctx.use_bias
+ s, b, h = grad_output.shape
+ # first we prepare the total inputs needed to compute grad_input, grad_weight.
+ # [s/(y*cp), b, E/x]---AG(y)---> [s/cp, b, E/x]
+ # Use sync AG to avoid communication competition, for the bandwidth is shared for A3.
+ grad_output = grad_output.contiguous()
+ if ctx.enable_backward_overlap_ag_with_matmul and get_args().coc_fused_kernel:
+ from mindspeed.ops.lcal_functional import coc_ops, CoCConfig
+ # prepare total activation_input for computing grad weight.
+ # [s/(x*cp), b, h/y]---AG(X)--->[s/cp, b, h/y]
+ activation_input = activation_input.contiguous()
+ gather_input_handle, gathered_tensors = async_gather_tensors(
+ local_rank_input=activation_input, ag_comm_intf=ctx.ag_comm_intf
+ )
+
+ # Convert the tensor shapes to 2D for execution compatibility
+ grad_output = grad_output.view(s * b, h)
+ ag_size = ctx.ag_comm_intf.get_comm_group_world_size()
+ rs_size = ctx.rs_comm_intf.get_comm_group_world_size()
+ res_shape_0 = s * b * rs_size
+
+ res_shape_1 = weight.shape[1]
+ partial_grad_input = torch.empty(res_shape_0, res_shape_1, dtype=grad_output.dtype, device=torch.cuda.current_device())
+
+ total_grad_output = torch.empty(res_shape_0, h, dtype=grad_output.dtype, device=torch.npu.current_device())
+ comm_domain, coc_rank = get_comm_domain_rank(total_grad_output.device.index, ag_size, rs_size)
+ coc_ops.set_comm_config(CoCConfig(coc_rank, rs_size, comm_domain))
+ coc_ops.all_gather_matmul_v2(input1=grad_output, input2=weight, output=partial_grad_input, comm_output=total_grad_output)
+ partial_grad_input = partial_grad_input.view(-1, b, partial_grad_input.shape[1])
+ else:
+ total_grad_output = sync_gather_along_first_dim(grad_output, ctx.rs_comm_intf, buffer_name="mpu-sync-tp-2d")
+ # prepare total activation_input for computing grad weight.
+ # [s/(x*cp), b, h/y]---AG(X)--->[s/cp, b, h/y]
+ activation_input = activation_input.contiguous()
+ gather_input_handle, gathered_tensors = async_gather_tensors(
+ local_rank_input=activation_input, ag_comm_intf=ctx.ag_comm_intf
+ )
+
+ # [s/cp, b, E/x] @ [E/x, H/y]--> [s/cp, b, H/y] (partial sum)
+ partial_grad_input = total_grad_output.matmul(weight).contiguous()
+
+ # Convert the tensor shapes to 2D for execution compatibility
+ sb = total_grad_output.shape[0] * total_grad_output.shape[1]
+ # [s/cp, b, E/x]--view--> [sb/cp, E/x]
+ total_grad_output = total_grad_output.view(sb, total_grad_output.shape[2])
+
+ # [s/cp, b, H/y] (partial sum)---RS(X)--->[s/cp, b, H/(xy)] (full sum)
+ rs_grad_input_handle, grad_input = async_reduce_scatter_along_first_dim(
+ partial_grad_input, comm_intf=ctx.ag_comm_intf
+ )
+
+ if gather_input_handle:
+ gather_input_handle.wait()
+
+ # [s/(x*cp), b, h/y]---AG(X)--->[s/cp, b, h/y]
+ total_activation_input = gathered_tensors
+ # [s/cp, b, h/y]--view--> [sb/cp, h/y]
+ total_activation_input = total_activation_input.view(-1, total_activation_input.shape[2])
+ if ctx.gradient_accumulation_fusion:
+ import fused_weight_gradient_mlp_cuda
+ total_grad_output = total_grad_output.contiguous()
+ if weight.main_grad.dtype == torch.float32:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+ total_activation_input, total_grad_output, weight.main_grad
+ )
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+ total_activation_input, total_grad_output, weight.main_grad
+ )
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=activation_input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=activation_input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ # [E/x, sb/cp] @ [sb/cp, h/y] ---> [E/x, h/y]
+ grad_weight = total_grad_output.t().matmul(total_activation_input)
+ grad_bias = total_grad_output.sum(dim=0) if use_bias else None
+
+ if rs_grad_input_handle:
+ rs_grad_input_handle.wait()
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None
+
+ @staticmethod
+ def _do_allgather_left_tensor_and_matmul_overlap(
+ ag_comm_intf, ag_overlap_comm_intf, part_left_tensor, full_right_tensor, return_ag_res=False
+ ):
+ cur_ag_rank = ag_comm_intf.get_comm_rank()
+ ag_world_sz = ag_comm_intf.get_comm_group_world_size()
+
+ # do tp-x times matmul and reduce the partial res.
+ matmul_res = [None] * ag_world_sz
+ cur_step_rcv_handle = None
+ ring_ag_ranks = ag_overlap_comm_intf.get_ring_global_ranks()
+ next_rank = ring_ag_ranks[(cur_ag_rank + ag_world_sz - 1) % ag_world_sz]
+ prev_rank = ring_ag_ranks[(cur_ag_rank + 1) % ag_world_sz]
+ ag_comm_group = ag_comm_intf.get_comm_group()
+ ag_overlap_comm_group = ag_overlap_comm_intf.get_comm_group()
+ cur_step_tensor_to_send = part_left_tensor
+
+ # 下一次要计算的数据(本次要从上一个 rank 接收的 tensor。)
+ cur_step_rcv_input = torch.empty_like(part_left_tensor)
+ all_ag_res = None
+ if return_ag_res:
+ all_ag_res = [None] * ag_world_sz
+ all_ag_res[cur_ag_rank] = part_left_tensor
+
+ # first_linear forward: [H/y, e/x] -> [H/(xy), e/x]
+ for step in range(ag_world_sz):
+ if step < ag_world_sz - 1 and cur_ag_rank % 2 == 0: # 偶数 rank 先发再收
+ torch_dist.isend(cur_step_tensor_to_send, next_rank, ag_comm_group)
+ cur_step_rcv_handle = torch_dist.irecv(
+ cur_step_rcv_input, prev_rank, ag_overlap_comm_group
+ )
+ elif step < ag_world_sz - 1 and cur_ag_rank % 2 == 1: # 奇数 rank 先收再发
+ cur_step_rcv_handle = torch_dist.irecv(cur_step_rcv_input, prev_rank, ag_comm_group)
+ torch_dist.isend(cur_step_tensor_to_send, next_rank, ag_overlap_comm_group)
+
+ # compute: part_left_tensor @ split_right(split by inner dim)
+ # [e/x, h/(xy)]
+ cur_tensor_idx = (step + cur_ag_rank) % ag_world_sz
+ if return_ag_res and step > 0:
+ all_ag_res[cur_tensor_idx] = cur_step_tensor_to_send.clone()
+
+ # first linear forward: [s/(x*cp), b, H/y] @ [H/y, e/x] -> [s/(x*cp), b, e/x]
+ cur_step_matmul_res = torch.matmul(cur_step_tensor_to_send, full_right_tensor)
+ matmul_res[cur_tensor_idx] = cur_step_matmul_res
+
+ if step < ag_world_sz - 1:
+ cur_step_rcv_handle.wait()
+ cur_step_tensor_to_send = cur_step_rcv_input.clone()
+
+ final_matmul_res = torch.cat(matmul_res)
+
+ return final_matmul_res, all_ag_res
+
+ @staticmethod
+ def _do_mm_overlap_reducescatter(activation_input, weight, bias, ag_comm_intf, rs_comm_intf):
+ # [s/(x*cp), b, H/y] -> [s/cp, b, H/y]
+ activation_input = activation_input.contiguous()
+ total_input = sync_gather_along_first_dim(activation_input, ag_comm_intf, buffer_name="mpu-sync-tp-2d")
+ # [s/cp, b, H/y] @ [H/y, e/x] -> [s/cp, b, e/x]
+ chunk_num = rs_comm_intf.get_comm_group_world_size()
+ rs_chunks = []
+ rs_handle_and_tmp_tensors = []
+ # convert tuple to list to free used tensors ahead.
+ seq_len, b, h = total_input.size()
+ chunk_size = seq_len // chunk_num
+ input_chunks = torch.reshape(total_input.view(chunk_size, -1, h).transpose(0, 1), (chunk_num, -1, h))
+ rs_res = torch.empty((chunk_size, b, weight.size(1)), dtype=weight.dtype, device=weight.device)
+ for idx in range(chunk_num):
+ input_chunk = input_chunks[idx].reshape(chunk_size, -1, h)
+ # [s/(cp*y), b, H/y] @ [H/y, e/x] -> [s/(cp*y), b, e/x]
+ chunk_matmul_res = torch.matmul(input_chunk, weight).contiguous()
+ if bias is not None:
+ chunk_matmul_res += bias
+
+ # [s/(cp*y), b, e/x]--rs--> [s/(cp*y*y), b, e/x]
+ rs_handle, rs_chunk = async_reduce_scatter_along_first_dim(
+ chunk_matmul_res, rs_comm_intf
+ )
+ rs_chunks.append(rs_chunk)
+ rs_handle_and_tmp_tensors.append((idx, rs_handle, chunk_matmul_res))
+
+ offset = 0
+ sub_chunk_size = chunk_size // chunk_num
+ for idx, rs_handle, chunk_matmul_res_tensor in rs_handle_and_tmp_tensors:
+ if rs_handle:
+ rs_handle.wait()
+ chunk_matmul_res_tensor.untyped_storage().resize_(0)
+ rs_res[offset:offset + sub_chunk_size] = rs_chunks[idx]
+ offset += sub_chunk_size
+
+ # [s / (cp * y * y), b, e / x] -> [s/(cp*y), b, e/x]
+ final_res = torch.reshape(rs_res.view(chunk_num, -1, weight.size(1)).transpose(0, 1), (chunk_size, -1, weight.size(1)))
+ return final_res
+
+ @staticmethod
+ def _backward_ag_overlap_with_mm(ctx, grad_output):
+ """Backward implementation of Linear2DSplitAlongFirstDim, the computation and communication
+ overlap:
+
+ ----------------------------------------------------------------------------->time
+ | send(grad_o-0, Y|X)
+ | recive(grad_o-1, Y|X)
+ | part_grad_act = MM(tot_grad_o-0, weight)
+ | part_grad_act = MM2(tot_grad_o-1, weight)
+ | RS(part_grad_act, X|Y)
+ | MM(tot_grad_o^T, tot_act_input)
+
+
+ :param ctx: context
+ :param grad_output: with shape: [s/cp, b, E/(xy)]
+ :return:grads of all the input para of forward function as a tuple
+ """
+ # activation_input shape: [s/(x*cp), b, h/y]
+ # weight shape: [h/y, E/x]
+ activation_input, = ctx.saved_tensors
+ weight = ctx.weight
+ use_bias = ctx.use_bias
+ # first we prepare the total inputs needed to compute grad_input, grad_weight.
+ # [s/(y*cp), b, E/x]---AG(y)---> [s/cp, b, E/x]
+ # Use sync AG to avoid communication competition, for the bandwidth is shared for A3.
+ rs_comm_intf = ctx.rs_comm_intf
+ rs_overlap_comm_intf = ctx.rs_overlap_comm_intf
+ grad_output = grad_output.contiguous()
+ cur_rs_rank = ctx.rs_comm_intf.get_comm_rank()
+ rs_world_sz = ctx.rs_comm_intf.get_comm_group_world_size()
+ # do tp-x times matmul and reduce the partial res.
+ matmul_res = [None] * rs_world_sz
+ cur_step_rcv_handle = None
+ ring_rs_ranks = rs_overlap_comm_intf.get_ring_global_ranks()
+ next_rank = ring_rs_ranks[(cur_rs_rank + rs_world_sz - 1) % rs_world_sz]
+ prev_rank = ring_rs_ranks[(cur_rs_rank + 1) % rs_world_sz]
+ rs_comm_group = rs_comm_intf.get_comm_group()
+ rs_overlap_comm_group = rs_overlap_comm_intf.get_comm_group()
+ cur_step_tensor_to_send = grad_output
+ # 下一次要计算的数据(本次要从上一个 rank 接收的 tensor。)
+ cur_step_rcv_input = torch.empty_like(grad_output)
+ # first_linear forward: [H/y, e/x] -> [H/(xy), e/x]
+ # collect total_grad_output
+ grad_output_list = [None] * rs_world_sz
+ grad_output_list[cur_rs_rank] = grad_output
+ gather_input_handle, gathered_tensors = None, None
+ for step in range(rs_world_sz):
+ if step < rs_world_sz - 1 and cur_rs_rank % 2 == 0: # 偶数 rank 先发再收
+ torch_dist.isend(cur_step_tensor_to_send, next_rank, rs_comm_group)
+ cur_step_rcv_handle = torch_dist.irecv(
+ cur_step_rcv_input, prev_rank, rs_overlap_comm_group
+ )
+ elif step < rs_world_sz - 1 and cur_rs_rank % 2 == 1: # 奇数 rank 先收再发
+ cur_step_rcv_handle = torch_dist.irecv(cur_step_rcv_input, prev_rank, rs_comm_group)
+ torch_dist.isend(cur_step_tensor_to_send, next_rank, rs_overlap_comm_group)
+
+ # compute: grad_output @ split_right(split by inner dim)
+ # [e/x, h/(xy)]
+ cur_tensor_idx = (step + cur_rs_rank) % rs_world_sz
+
+ # first linear forward: [s/(x*cp), b, H/y] @ [H/y, e/x] -> [s/(x*cp), b, e/x]
+ cur_step_matmul_res = torch.matmul(cur_step_tensor_to_send, weight)
+ matmul_res[cur_tensor_idx] = cur_step_matmul_res
+ if step > 0:
+ grad_output_list[cur_tensor_idx] = cur_step_tensor_to_send.clone()
+ if step < rs_world_sz - 1:
+ cur_step_rcv_handle.wait()
+ cur_step_tensor_to_send = cur_step_rcv_input.clone()
+ if step == 0:
+ # prepare total activation_input for computing grad weight.
+ # [s/(x*cp), b, h/y]---AG(X)--->[s/cp, b, h/y]
+ activation_input = activation_input.contiguous()
+ gather_input_handle, gathered_tensors = async_gather_tensors(
+ local_rank_input=activation_input, ag_comm_intf=ctx.ag_comm_intf
+ )
+
+ partial_grad_input = torch.cat(matmul_res)
+ # [s/cp, b, H/y] (partial sum)---RS(X)--->[s/cp, b, H/(xy)] (full sum)
+ rs_grad_input_handle, grad_input = async_reduce_scatter_along_first_dim(
+ partial_grad_input, comm_intf=ctx.ag_comm_intf
+ )
+
+ total_grad_output = torch.cat(grad_output_list, dim=0)
+
+ # Convert the tensor shapes to 2D for execution compatibility
+ sb = total_grad_output.shape[0] * total_grad_output.shape[1]
+ # [s/cp, b, E/x]--view--> [sb/cp, E/x]
+ total_grad_output = total_grad_output.view(sb, total_grad_output.shape[2])
+
+ if gather_input_handle:
+ gather_input_handle.wait()
+
+ # [s/(x*cp), b, h/y]---AG(X)--->[s/cp, b, h/y]
+ total_activation_input = gathered_tensors
+ # [s/cp, b, h/y]--view--> [sb/cp, h/y]
+ total_activation_input = total_activation_input.view(sb, total_activation_input.shape[2])
+ if ctx.gradient_accumulation_fusion:
+ import fused_weight_gradient_mlp_cuda
+ total_grad_output = total_grad_output.contiguous()
+ if weight.main_grad.dtype == torch.float32:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+ total_activation_input, total_grad_output, weight.main_grad
+ )
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+ total_activation_input, total_grad_output, weight.main_grad
+ )
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=activation_input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=activation_input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ # [E/x, sb/cp] @ [sb/cp, h/y] ---> [E/x, h/y]
+ grad_weight = total_grad_output.t().matmul(total_activation_input)
+ grad_bias = total_grad_output.sum(dim=0) if use_bias else None
+
+ if rs_grad_input_handle:
+ rs_grad_input_handle.wait()
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None, None
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/norm_factory.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/norm_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..1544371e47ee56280a6af3d70c60a0fad973f5e0
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/norm_factory.py
@@ -0,0 +1,73 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from functools import wraps
+from typing import List
+
+import torch
+import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors
+from torch._utils import _unflatten_dense_tensors
+
+from megatron.core.transformer import TransformerConfig
+from megatron.core.utils import get_attr_wrapped_model
+from megatron.training import get_args
+from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm
+from mindspeed.core.tensor_parallel.comm_group_api import TPYCollectiveComm
+from mindspeed.core.tensor_parallel.tp_2d.layernorm_2d import LayerNorm2D
+from mindspeed.core.tensor_parallel.tp_2d.rms_norm_2d import RMSNorm2D
+
+
+def _allreduce_layernorm_grads_wrapper(function):
+ @wraps(function)
+ def wrapper(model: List[torch.nn.Module], config: TransformerConfig):
+ function(model, config)
+ layer_norm_2d_grads = []
+ for model_chunk in model:
+ for name, param in get_attr_wrapped_model(model_chunk, "named_parameters")():
+ if param.requires_grad and getattr(param, "2d_tp", False):
+ layer_norm_2d_grad = param.main_grad
+ layer_norm_2d_grads.append(layer_norm_2d_grad.data)
+
+ if layer_norm_2d_grads:
+ coalesced = _flatten_dense_tensors(layer_norm_2d_grads)
+ torch.distributed.all_reduce(coalesced, group=TPXCollectiveComm.get_comm_group())
+ for buf, synced in zip(
+ layer_norm_2d_grads, _unflatten_dense_tensors(coalesced, layer_norm_2d_grads)
+ ):
+ buf.copy_(synced)
+
+ return wrapper
+
+
+def get_norm_tp_2d(config):
+ args = get_args()
+ if args.normalization == "LayerNorm":
+ return LayerNorm2D(
+ config.hidden_size,
+ eps=config.layernorm_epsilon,
+ last_dim_split_comm_intf=TPYCollectiveComm(),
+ )
+ elif args.normalization == "RMSNorm":
+ if args.apply_layernorm_1p:
+ raise NotImplementedError(
+ "RMSNorm does not currently support the layernorm_1p formulation."
+ )
+ return RMSNorm2D(
+ config.hidden_size,
+ eps=config.layernorm_epsilon,
+ last_dim_split_comm_intf=TPYCollectiveComm(),
+ )
+ else:
+ raise Exception(f"unsupported norm type '{args.normalization}'.")
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/parallel_linear_2d.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/parallel_linear_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6ff0293f7d3723e0ffa4f0e59b3ae64ca3dc298
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/parallel_linear_2d.py
@@ -0,0 +1,204 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from typing import Any, Callable
+
+import torch
+
+from megatron.core import ModelParallelConfig
+from megatron.core.tensor_parallel.layers import _initialize_affine_weight_gpu
+from megatron.core.utils import divide
+from mindspeed.core.tensor_parallel.comm_group_api import CollectiveCommIntf
+from mindspeed.core.tensor_parallel.comm_group_api import OverlapCollectiveIntf
+from mindspeed.core.tensor_parallel.layers import _initialize_affine_weight_cpu_2d
+from mindspeed.core.tensor_parallel.tp_2d.linear_2d_moe_split_along_first_dim import MoELinear2DFC1, MoELinear2DFC2
+from mindspeed.core.tensor_parallel.tp_2d.linear_2d_split_along_first_dim import Linear2DSplitAlongFirstDim
+
+
+class ParallelLinear2D(torch.nn.Module):
+ """Linear2D layer with row and column parallelism.
+
+ The linear layer is defined as Y = XA + b. A is parallelized along
+ its second dimension as A = [A_1, ..., A_p].
+
+ Arguments:
+ input_size: first dimension of matrix A.
+ output_size: second dimension of matrix A.
+
+ Keyword Arguments
+ bias: If true, add bias
+ gather_output: If true, call all-gather on output and make Y available
+ to all GPUs, otherwise, every GPU will have its output
+ which is Y_i = XA_i
+ init_method: method to initialize weights. Note that bias is always set
+ to zero.
+ stride: For the strided linear layers.
+ keep_master_weight_for_test: This was added for testing and should be
+ set to False. It returns the master weights
+ used for initialization.
+ skip_bias_add: If True, do not add the bias term, instead
+ return it to be added by the caller. This
+ enables performance optimations where bias can
+ be fused with other elementwise operations.
+ skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed
+ as a keyword argument `weight` during the forward pass. Note
+ that this does not affect bias, which will be allocated if
+ bias is True. Defaults to False.
+ is_expert: If True, the layer is treated as an MoE expert layer.
+ config: ModelParallelConfig object
+ tp_comm_buffer_name: Communication buffer name is not used in
+ non-Transformer-Engine modules.
+ partition_dim: divide with dim, column parallel set 0, row parallel set 1
+ enable_backward_overlap_ag_with_matmul: enable overlap all-gather with matmul
+
+ """
+
+ def __init__(
+ self,
+ input_size,
+ output_size,
+ *,
+ config: ModelParallelConfig,
+ init_method: Callable,
+ add_bias=True,
+ gather_output=False,
+ stride=1,
+ keep_master_weight_for_test=False,
+ skip_bias_add=True,
+ skip_weight_param_allocation: bool = False,
+ is_expert: bool = False,
+ ag_comm_intf: CollectiveCommIntf = None,
+ ag_sd_rcv_overlap_comm_intf: OverlapCollectiveIntf = None,
+ rs_comm_intf: CollectiveCommIntf = None,
+ rs_sd_rcv_overlap_comm_intf: OverlapCollectiveIntf = None,
+ enable_overlap_ag_with_matmul=False,
+ enable_overlap_matmul_with_rs=False,
+ partition_dim: int = 0,
+ enable_backward_overlap_ag_with_matmul=False,
+ ):
+ super().__init__()
+ self.mp_config: ModelParallelConfig = config
+ self.para_init_method = init_method
+ self.stride = stride
+ self.keep_master_weight_for_test = keep_master_weight_for_test
+ self.add_bias = add_bias
+ self.input_size = input_size
+ self.output_size = output_size
+ self.ag_comm_intf = ag_comm_intf
+ self.rs_comm_intf = rs_comm_intf
+ self.ag_comm_world_sz = ag_comm_intf.get_comm_group_world_size()
+ self.rs_comm_world_sz = rs_comm_intf.get_comm_group_world_size()
+ # when AG comm group is small, do overlap AG with matmul.
+ self.enable_overlap_ag_with_matmul = enable_overlap_ag_with_matmul
+ self.enable_overlap_matmul_with_rs = enable_overlap_matmul_with_rs
+ self.ag_overlap_comm_intf = ag_sd_rcv_overlap_comm_intf
+ self.rs_sd_rcv_overlap_comm_intf = rs_sd_rcv_overlap_comm_intf
+
+ if input_size % self.rs_comm_world_sz:
+ raise AssertionError("input size should be divisible by tp-y")
+ if output_size % self.ag_comm_world_sz:
+ raise AssertionError("output size should be divisible by tp-x")
+
+ self.input_size_per_partition = divide(input_size, self.rs_comm_world_sz)
+ self.output_size_per_partition = divide(output_size, self.ag_comm_world_sz)
+ self.skip_bias_add = skip_bias_add
+ self.is_expert = is_expert
+ self.expert_parallel = config.expert_model_parallel_size > 1
+ self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
+ self.enable_backward_overlap_ag_with_matmul = enable_backward_overlap_ag_with_matmul
+ if config.sequence_parallel:
+ raise RuntimeError(
+ "Nd_matmul cannot be used with sequence_parallel."
+ "If you want to train long sequences, "
+ "you can use ulysess or context_parallel that is compatible with nd_matmul."
+ )
+ self.partition_dim = partition_dim
+ self.init_linear_weights()
+
+ def init_linear_weights(self):
+ init_with_cpu = self.mp_config.use_cpu_initialization
+ device = None if init_with_cpu else torch.cuda.current_device()
+
+ self.weight = torch.nn.Parameter(
+ torch.empty(
+ self.output_size_per_partition,
+ self.input_size_per_partition,
+ device=device,
+ dtype=self.mp_config.params_dtype,
+ )
+ )
+ if self.add_bias:
+ self.bias = torch.nn.Parameter(
+ torch.empty(self.output_size_per_partition, dtype=self.mp_config.params_dtype, device=device)
+ )
+ else:
+ self.register_parameter("bias", None)
+
+ if init_with_cpu and self.mp_config.perform_initialization:
+ _initialize_affine_weight_cpu_2d(self.weight, self.partition_dim, stride=self.stride,
+ return_master_weight=self.keep_master_weight_for_test,
+ config=self.mp_config)
+ elif self.mp_config.perform_initialization:
+ _initialize_affine_weight_gpu(
+ self.weight,
+ self.para_init_method,
+ partition_dim=self.partition_dim,
+ stride=self.stride,
+ expert_parallel=(self.is_expert and self.expert_parallel),
+ )
+
+ setattr(self.weight, "allreduce", not (self.is_expert and self.expert_parallel))
+
+ if self.add_bias and self.mp_config.perform_initialization:
+ with torch.no_grad():
+ self.bias.zero_()
+
+ setattr(self.bias, "allreduce", not (self.is_expert and self.expert_parallel))
+ setattr(self.bias, "sequence_parallel", False)
+
+ def set_extra_state(self, state: Any):
+ """ Extra state is ignored """
+
+ def get_extra_state(self) -> None:
+ """ Keep compatibility with TE state dict. """
+ return None
+
+ def forward(self, activation_input):
+ if self.is_expert:
+ if self.partition_dim == 0:
+ linear_func = MoELinear2DFC1
+ else:
+ linear_func = MoELinear2DFC2
+ else:
+ linear_func = Linear2DSplitAlongFirstDim
+ matmul_output = linear_func.apply(
+ activation_input,
+ self.weight,
+ self.bias,
+ self.ag_comm_intf,
+ self.ag_overlap_comm_intf,
+ self.rs_comm_intf,
+ self.rs_sd_rcv_overlap_comm_intf,
+ self.enable_overlap_ag_with_matmul,
+ self.enable_overlap_matmul_with_rs,
+ self.gradient_accumulation_fusion,
+ self.enable_backward_overlap_ag_with_matmul,
+ self.partition_dim,
+ )
+
+ if not self.skip_bias_add:
+ output = (matmul_output + self.bias) if self.bias is not None else matmul_output
+ output_bias = None
+ else:
+ output = matmul_output
+ output_bias = self.bias
+
+ return output, output_bias
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/rms_norm_2d.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/rms_norm_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..77e0868e7b380626e41796e188d0689f969d831d
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/tp_2d/rms_norm_2d.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from typing import Any
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from torch import nn
+from torch.cuda.amp import custom_bwd
+from torch.cuda.amp import custom_fwd
+from megatron.core.utils import divide
+from mindspeed.core.tensor_parallel.comm_group_api import CollectiveCommIntf
+from mindspeed.core.tensor_parallel.comm_group_api import TPYCollectiveComm
+
+
+class RMSNorm2D(torch.nn.Module):
+
+ def __init__(self,
+ hidden_size: int,
+ eps: float = 1e-6,
+ last_dim_split_comm_intf: CollectiveCommIntf = TPYCollectiveComm()):
+ """RMS Normaliation 2d module
+
+ Args:
+ hidden_size (int): The width of input, i.e. hidden size
+ eps (float): epsilon to use for the norm, default to 1e-6
+ last_dim_split_comm_intf: All-reduce at last dim comm intf.
+ """
+ super().__init__()
+ self.eps = eps
+ self.hidden_size = hidden_size
+ self.last_dim_split_comm_intf = last_dim_split_comm_intf
+ self.last_dim_split_comm_world_sz = self.last_dim_split_comm_intf.get_comm_group_world_size()
+ # partitioning dimension
+ self.partitioned_dim = divide(hidden_size, self.last_dim_split_comm_world_sz)
+ self.weight = nn.Parameter(torch.ones(self.partitioned_dim))
+
+ setattr(self.weight, "2d_tp", True)
+
+ def forward(self, x):
+ return _ParallelRMSNorm2D.apply(
+ x,
+ self.weight,
+ self.eps,
+ self.hidden_size,
+ self.last_dim_split_comm_intf,
+ )
+
+
+class _ParallelRMSNorm2D(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx: Any,
+ input_: Tensor,
+ weight,
+ epsilon,
+ hidden_size: int,
+ last_dim_split_comm_intf: CollectiveCommIntf,
+ ) -> Tensor:
+ # input_ inner: [s/cp, b, h/xy]
+ # input_ outer: [s/(cp*x), b, h/y]
+ ctx.last_dim_split_comm_intf = last_dim_split_comm_intf
+ ctx.hidden_size = hidden_size
+ pow_mean = input_.float().pow(2).mean(-1, keepdim=True)
+ torch.distributed.all_reduce(pow_mean, group=last_dim_split_comm_intf.get_comm_group())
+ var_x = torch.rsqrt(pow_mean.div_(last_dim_split_comm_intf.get_comm_group_world_size()) + epsilon)
+ norm_x = torch.mul(input_, var_x).type_as(input_)
+ ctx.save_for_backward(norm_x, var_x, weight)
+ return norm_x * weight
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
+ x, var_x, weight = ctx.saved_tensors
+ grad_norm_x = torch.mul(output_grad, weight)
+ output_grad_mul_x_sum = torch.sum(grad_norm_x * x, dim=-1, keepdim=True)
+ handle_grad = torch.distributed.all_reduce(
+ output_grad_mul_x_sum, group=ctx.last_dim_split_comm_intf.get_comm_group(), async_op=True)
+ # calculate grad_weight
+ grad_weight = torch.mul(output_grad, x)
+ if handle_grad:
+ handle_grad.wait()
+ grad_input = (grad_norm_x - x * (output_grad_mul_x_sum / ctx.hidden_size)) * var_x
+ return grad_input, grad_weight, None, None, None
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/__init__.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/adaptor.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/adaptor.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced8263706a40b303d718199e02d6b5a7838242a
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/adaptor.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+import torch
+from megatron.training import get_args
+from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_and_expert_parallel_group
+from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, _initialize_affine_weight_cpu, \
+ _initialize_affine_weight_gpu, set_tensor_model_parallel_attributes, _grad_accum_fusion_available, \
+ linear_with_grad_accumulation_and_async_allreduce, linear_with_frozen_weight
+from megatron.core.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region, \
+ reduce_from_tensor_model_parallel_region, gather_from_tensor_model_parallel_region, copy_to_tensor_model_parallel_region
+from megatron.core.tensor_parallel.mappings import scatter_to_sequence_parallel_region as megatron_scatter_to_sequence_parallel_region
+from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region as megatron_gather_from_sequence_parallel_region
+
+from .unaligned_column_parallel_linear import UnalignedColumnParallelLinear
+from .unaligned_row_parallel_linear import UnalignedRowParallelLinear
+from .unaligned_utils import unaligned_divide, unaligned_scatter_to_sequence_parallel_region, \
+ unaligned_reduce_scatter_to_sequence_parallel_region, unaligned_gather_from_sequence_parallel_region
+
+
+class UnalignedColumnParallelLinearAdaptor(UnalignedColumnParallelLinear, ColumnParallelLinear):
+ def __init__(self, *args, **kwargs):
+ config = kwargs['config']
+ explicit_expert_comm = config.tensor_model_parallel_size > 1 or config.expert_model_parallel_size > 1
+ if 'is_expert' not in kwargs:
+ kwargs['is_expert'] = False
+ if 'tp_comm_buffer_name' not in kwargs:
+ kwargs['tp_comm_buffer_name'] = None
+
+ if kwargs['is_expert'] and explicit_expert_comm and config.moe_extended_tp:
+ kwargs['parallel_group'] = get_tensor_and_expert_parallel_group()
+ else:
+ kwargs['parallel_group'] = get_tensor_model_parallel_group()
+
+ if kwargs['tp_comm_buffer_name'] == 'qkv':
+ kwargs['fusion_number'] = (config.hidden_size + 2 * config.kv_channels * config.num_query_groups) // config.num_query_groups
+ else:
+ kwargs['fusion_number'] = 1
+
+ if not config.variable_seq_lengths:
+ kwargs['seq_length'] = get_args().seq_length
+
+ kwargs['_initialize_affine_weight_cpu'] = _initialize_affine_weight_cpu
+ kwargs['_initialize_affine_weight_gpu'] = _initialize_affine_weight_gpu
+ kwargs['set_tensor_model_parallel_attributes'] = set_tensor_model_parallel_attributes
+ kwargs['linear_with_grad_accumulation_and_async_allreduce'] = linear_with_grad_accumulation_and_async_allreduce
+ kwargs['gather_from_tensor_model_parallel_region'] = gather_from_tensor_model_parallel_region
+ kwargs['copy_to_tensor_model_parallel_region'] = copy_to_tensor_model_parallel_region
+ kwargs['linear_with_frozen_weight'] = linear_with_frozen_weight
+ super(UnalignedColumnParallelLinearAdaptor, self).__init__(*args, **kwargs)
+
+
+class UnalignedRowParallelLinearAdaptor(UnalignedRowParallelLinear, RowParallelLinear):
+ def __init__(self, *args, **kwargs):
+ config = kwargs['config']
+ explicit_expert_comm = config.tensor_model_parallel_size > 1 or config.expert_model_parallel_size > 1
+ if 'is_expert' not in kwargs:
+ kwargs['is_expert'] = False
+ if 'tp_comm_buffer_name' not in kwargs:
+ kwargs['tp_comm_buffer_name'] = None
+
+ if kwargs['is_expert'] and explicit_expert_comm and config.moe_extended_tp:
+ kwargs['parallel_group'] = get_tensor_and_expert_parallel_group()
+ else:
+ kwargs['parallel_group'] = get_tensor_model_parallel_group()
+
+ if kwargs['tp_comm_buffer_name'] is not None and not kwargs['tp_comm_buffer_name'].startswith('fc'): # attention.linear_proj
+ kwargs['fusion_number'] = config.hidden_size // config.num_query_groups
+ else:
+ kwargs['fusion_number'] = 1
+
+ if not config.variable_seq_lengths:
+ kwargs['seq_length'] = get_args().seq_length
+
+ kwargs['_initialize_affine_weight_cpu'] = _initialize_affine_weight_cpu
+ kwargs['_initialize_affine_weight_gpu'] = _initialize_affine_weight_gpu
+ kwargs['linear_with_grad_accumulation_and_async_allreduce'] = linear_with_grad_accumulation_and_async_allreduce
+ kwargs['scatter_to_tensor_model_parallel_region'] = scatter_to_tensor_model_parallel_region
+ kwargs['linear_with_frozen_weight'] = linear_with_frozen_weight
+ kwargs['reduce_from_tensor_model_parallel_region'] = reduce_from_tensor_model_parallel_region
+ super(UnalignedRowParallelLinearAdaptor, self).__init__(*args, **kwargs)
+
+
+def divide_adaptor(numerator, denominator):
+ if numerator % denominator != 0:
+ rank = torch.distributed.get_rank(group=get_tensor_model_parallel_group())
+ return unaligned_divide(numerator, denominator, rank)
+ return numerator // denominator
+
+
+def scatter_to_sequence_parallel_region_adaptor(embeddings):
+ world_size = torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
+ if embeddings.size()[0] % world_size != 0:
+ return unaligned_scatter_to_sequence_parallel_region(embeddings, get_tensor_model_parallel_group())
+ else:
+ return megatron_scatter_to_sequence_parallel_region(embeddings)
+
+
+def reduce_scatter_to_sequence_parallel_region_adaptor(inputs):
+ group = get_tensor_model_parallel_group()
+ return unaligned_reduce_scatter_to_sequence_parallel_region(inputs, group)
+
+
+def gather_from_sequence_parallel_region_adaptor(inputs, tensor_parallel_output_grad=True):
+ world_size = torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
+
+ dim_size = torch.tensor(inputs.shape[0], dtype=torch.long, device=inputs.device)
+ torch.distributed.all_reduce(dim_size)
+ total_dim_size = dim_size.item()
+
+ group = get_tensor_model_parallel_group()
+ if total_dim_size % world_size != 0:
+ return unaligned_gather_from_sequence_parallel_region(inputs, group, tensor_parallel_output_grad)
+ else:
+ return megatron_gather_from_sequence_parallel_region(inputs, tensor_parallel_output_grad)
+
+
+def get_rotary_seq_len(
+ self,
+ inference_params,
+ transformer,
+ transformer_input,
+ transformer_config,
+) -> float:
+ if inference_params is not None:
+ rotary_seq_len = inference_params.max_sequence_length
+ else:
+ if transformer.input_tensor is not None:
+ rotary_seq_len = transformer.input_tensor.size(0)
+ else:
+ rotary_seq_len = transformer_input.size(0)
+
+ if transformer_config.sequence_parallel:
+ if not transformer_config.variable_seq_lengths:
+ rotary_seq_len = get_args().seq_length
+ else:
+ rotary_seq_len = torch.Tensor([rotary_seq_len]).cuda().int()
+ torch.distributed.all_reduce(rotary_seq_len, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group())
+ rotary_seq_len = rotary_seq_len.item()
+
+ rotary_seq_len *= transformer_config.context_parallel_size
+
+ return rotary_seq_len
+
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/unaligned_column_parallel_linear.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/unaligned_column_parallel_linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7438c2cf1d775a1834f128b2cdf6e1b64b93583
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/unaligned_column_parallel_linear.py
@@ -0,0 +1,250 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+import warnings
+from typing import Callable, Optional, List
+
+import torch
+from torch.nn import Parameter
+
+from mindspeed.core.tensor_parallel.unaligned_layers.unaligned_utils import unaligned_divide, \
+ unaligned_linear_with_grad_accumulation_and_async_allreduce
+
+
+class UnalignedColumnParallelLinear(torch.nn.Module):
+
+ def __init__(
+ self,
+ input_size,
+ output_size,
+ *,
+ config,
+ init_method: Callable,
+ bias=True,
+ gather_output=False,
+ stride=1,
+ keep_master_weight_for_test=False,
+ skip_bias_add=False,
+ skip_weight_param_allocation: bool = False,
+ embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
+ grad_output_buffer: Optional[List[torch.Tensor]] = None,
+ is_expert: bool = False,
+ tp_comm_buffer_name: str = None, # Not used
+ disable_grad_reduce: bool = False,
+
+ # unaligned parallel arguments
+ parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ fusion_number: int = 1, # the number of linear fused
+ seq_length: int = None,
+ _initialize_affine_weight_cpu: Callable = None,
+ _initialize_affine_weight_gpu: Callable = None,
+ set_tensor_model_parallel_attributes: Callable = None,
+ linear_with_grad_accumulation_and_async_allreduce=None,
+ copy_to_tensor_model_parallel_region=None,
+ linear_with_frozen_weight=None,
+ gather_from_tensor_model_parallel_region=None
+ ):
+ torch.nn.Module.__init__(self)
+ # Keep input parameters
+ self.input_size = input_size
+ self.output_size = output_size
+ self.gather_output = gather_output
+ # Divide the weight matrix along the last dimension.
+ self.skip_bias_add = skip_bias_add
+ self.is_expert = is_expert
+ self.expert_parallel = config.expert_model_parallel_size > 1
+ self.embedding_activation_buffer = embedding_activation_buffer
+ self.grad_output_buffer = grad_output_buffer
+ self.config = config
+ self.disable_grad_reduce = disable_grad_reduce
+
+ self.explicit_expert_comm = self.is_expert and (
+ config.tensor_model_parallel_size > 1 or self.expert_parallel
+ )
+
+ world_size = torch.distributed.get_world_size(group=parallel_group)
+ rank = torch.distributed.get_rank(group=parallel_group)
+
+ if self.output_size % fusion_number != 0:
+ raise AssertionError('output_size({}) must be divisible by fusion number({})'.format(self.output_size, fusion_number))
+ if fusion_number != 1:
+ self.output_size_per_partition = unaligned_divide(config.num_query_groups, world_size, rank)
+ self.output_size_per_partition *= fusion_number
+ else:
+ self.output_size_per_partition = unaligned_divide(self.output_size, world_size, rank)
+
+ # Parameters.
+ # Note: torch.nn.functional.linear performs XA^T + b and as a result
+ # we allocate the transpose.
+ # Initialize weight.
+ if not skip_weight_param_allocation:
+ if config.use_cpu_initialization:
+ self.weight = Parameter(
+ torch.empty(
+ self.output_size_per_partition, self.input_size, dtype=config.params_dtype
+ )
+ )
+ if config.perform_initialization:
+ self.master_weight = _initialize_affine_weight_cpu(
+ self.weight,
+ self.output_size,
+ self.input_size,
+ self.output_size_per_partition,
+ 0,
+ init_method,
+ stride=stride,
+ return_master_weight=keep_master_weight_for_test,
+ rank=rank,
+ world_size=world_size,
+ )
+ else:
+ self.weight = Parameter(
+ torch.empty(
+ self.output_size_per_partition,
+ self.input_size,
+ device=torch.cuda.current_device(),
+ dtype=config.params_dtype,
+ )
+ )
+ if config.perform_initialization:
+ _initialize_affine_weight_gpu(
+ self.weight,
+ init_method,
+ partition_dim=0,
+ stride=stride,
+ expert_parallel=(self.is_expert and self.expert_parallel),
+ )
+
+ setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
+ else:
+ self.weight = None
+
+ if bias:
+ if config.use_cpu_initialization:
+ self.bias = Parameter(
+ torch.empty(self.output_size_per_partition, dtype=config.params_dtype)
+ )
+ else:
+ self.bias = Parameter(
+ torch.empty(
+ self.output_size_per_partition,
+ device=torch.cuda.current_device(),
+ dtype=config.params_dtype,
+ )
+ )
+ set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
+ if config.perform_initialization:
+ # Always initialize bias to zero.
+ with torch.no_grad():
+ self.bias.zero_()
+ setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel))
+ else:
+ self.register_parameter('bias', None)
+
+ self.sequence_parallel = config.sequence_parallel
+ if self.sequence_parallel and world_size <= 1:
+ warnings.warn(
+ f"`sequence_parallel` is set to `True`, but tensor model parallel size is {world_size}. "
+ f"Disabling sequence parallel."
+ )
+ self.sequence_parallel = False
+
+ self.allreduce_dgrad = world_size > 1 and not self.sequence_parallel
+ self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
+
+ if self.allreduce_dgrad and self.sequence_parallel:
+ raise RuntimeError(
+ "`allreduce_dgrad` and `sequence_parallel` cannot be enabled at the same time."
+ )
+
+ self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
+
+ # Hook adding a default empty _extra_state for state dict
+ self._register_load_state_dict_pre_hook(
+ lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
+ f'{prefix}_extra_state'
+ )
+ )
+
+ self.seq_length = seq_length
+ self.copy_to_tensor_model_parallel_region = copy_to_tensor_model_parallel_region
+ self.linear_with_frozen_weight = linear_with_frozen_weight
+ self.parallel_group = parallel_group
+ self.gather_from_tensor_model_parallel_region = gather_from_tensor_model_parallel_region
+
+ def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
+ if weight is None:
+ if self.weight is None:
+ raise RuntimeError(
+ "weight was not supplied to ColumnParallelLinear forward pass "
+ "and skip_weight_param_allocation is True."
+ )
+ weight = self.weight
+ else:
+ # Check the weight passed in is the correct shape
+ expected_shape = (self.output_size_per_partition, self.input_size)
+ if weight.shape != expected_shape:
+ raise RuntimeError(
+ f"supplied weight's shape is {tuple(weight.shape)}, "
+ f"not {expected_shape} as expected"
+ )
+
+ if self.config._cpu_offloading_context is not None:
+ if self.config._cpu_offloading_context.inside_context == True:
+ assert (
+ self.config.cpu_offloading == False
+ ), "CPU Offloading cannot be enabled while using non-TE modules"
+
+ bias = self.bias if not self.skip_bias_add else None
+
+ if (
+ self.allreduce_dgrad
+ or self.sequence_parallel
+ or self.explicit_expert_comm
+ or self.disable_grad_reduce
+ ):
+ input_parallel = input_
+ else:
+ input_parallel = self.copy_to_tensor_model_parallel_region(input_)
+
+ if self.config.defer_embedding_wgrad_compute:
+ self.embedding_activation_buffer.append(input_parallel)
+
+ allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad
+ # Matrix multiply.
+ if not weight.requires_grad:
+ self._forward_impl = self.linear_with_frozen_weight
+ output_parallel = self._forward_impl(
+ input=input_parallel,
+ weight=weight,
+ bias=bias,
+ gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+ async_grad_allreduce=allreduce_dgrad,
+ sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
+ grad_output_buffer=self.grad_output_buffer
+ if self.config.defer_embedding_wgrad_compute
+ else None,
+ allreduce_dgrad=allreduce_dgrad
+ )
+ else:
+ self._forward_impl = unaligned_linear_with_grad_accumulation_and_async_allreduce
+ output_parallel = self._forward_impl(
+ input=input_parallel,
+ weight=weight,
+ bias=bias,
+ gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+ sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
+ grad_output_buffer=self.grad_output_buffer
+ if self.config.defer_embedding_wgrad_compute
+ else None,
+ allreduce_dgrad=allreduce_dgrad,
+ parallel_group=self.parallel_group,
+ seq_length=self.seq_length
+ )
+ if self.gather_output:
+ # All-gather across the partitions.
+ assert not self.sequence_parallel
+ output = self.gather_from_tensor_model_parallel_region(output_parallel)
+ else:
+ output = output_parallel
+ output_bias = self.bias if self.skip_bias_add else None
+ return output, output_bias
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/unaligned_row_parallel_linear.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/unaligned_row_parallel_linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..37668ccce1e0cda36c01621d1c4fc725db78df1d
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/unaligned_row_parallel_linear.py
@@ -0,0 +1,206 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+from typing import Callable, Optional
+
+import torch
+from torch.nn import Parameter
+
+from mindspeed.core.tensor_parallel.unaligned_layers.unaligned_utils import unaligned_divide, \
+ unaligned_reduce_scatter_to_sequence_parallel_region, unaligned_linear_with_grad_accumulation_and_async_allreduce
+
+
+class UnalignedRowParallelLinear(torch.nn.Module):
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ *,
+ config,
+ init_method: Callable,
+ bias: bool,
+ input_is_parallel: bool,
+ skip_bias_add: bool,
+ stride: int = 1,
+ keep_master_weight_for_test: bool = False,
+ is_expert: bool = False,
+ tp_comm_buffer_name: str = None, # Not used
+
+ # unaligned parallel arguments
+ parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ fusion_number: int = 1, # the number of linear fused
+ seq_length: int = None,
+ _initialize_affine_weight_cpu: Callable = None,
+ _initialize_affine_weight_gpu: Callable = None,
+ linear_with_grad_accumulation_and_async_allreduce=None,
+ scatter_to_tensor_model_parallel_region=None,
+ linear_with_frozen_weight=None,
+ reduce_from_tensor_model_parallel_region=None,
+ ):
+ torch.nn.Module.__init__(self)
+
+ # Keep input parameters
+ self.input_size = input_size
+ self.output_size = output_size
+ self.input_is_parallel = input_is_parallel
+ self.skip_bias_add = skip_bias_add
+ self.config = config
+ self.is_expert = is_expert
+ self.expert_parallel = config.expert_model_parallel_size > 1
+ self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
+ self.sequence_parallel = config.sequence_parallel
+ if self.sequence_parallel and not self.input_is_parallel:
+ raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")
+
+ self.explicit_expert_comm = self.is_expert and (
+ config.tensor_model_parallel_size > 1 or self.expert_parallel
+ )
+
+ # Divide the weight matrix along the last dimension.
+ world_size = torch.distributed.get_world_size(group=parallel_group)
+ rank = torch.distributed.get_rank(group=parallel_group)
+
+ if self.input_size % fusion_number != 0:
+ raise AssertionError('input_size({}) must be divisible by fusion number({})'.format(self.input_size, fusion_number))
+
+ if fusion_number != 1:
+ self.input_size_per_partition = unaligned_divide(config.num_query_groups, world_size, rank)
+ self.input_size_per_partition *= fusion_number
+ else:
+ self.input_size_per_partition = unaligned_divide(self.input_size, world_size, rank)
+
+ # Parameters.
+ # Note: torch.nn.functional.linear performs XA^T + b and as a result
+ # we allocate the transpose.
+ # Initialize weight.
+ if config.use_cpu_initialization:
+ self.weight = Parameter(
+ torch.empty(
+ self.output_size, self.input_size_per_partition, dtype=config.params_dtype
+ )
+ )
+ if config.perform_initialization:
+ self.master_weight = _initialize_affine_weight_cpu(
+ self.weight,
+ self.output_size,
+ self.input_size,
+ self.input_size_per_partition,
+ 1,
+ init_method,
+ stride=stride,
+ return_master_weight=keep_master_weight_for_test,
+ params_dtype=config.params_dtype,
+ rank=rank,
+ world_size=world_size,
+ )
+ else:
+ self.weight = Parameter(
+ torch.empty(
+ self.output_size,
+ self.input_size_per_partition,
+ device=torch.cuda.current_device(),
+ dtype=config.params_dtype,
+ )
+ )
+ if config.perform_initialization:
+ _initialize_affine_weight_gpu(
+ self.weight,
+ init_method,
+ partition_dim=1,
+ stride=stride,
+ expert_parallel=(self.is_expert and self.expert_parallel),
+ )
+ setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
+
+ if bias:
+ if config.use_cpu_initialization:
+ self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype))
+ else:
+ self.bias = Parameter(
+ torch.empty(
+ self.output_size,
+ device=torch.cuda.current_device(),
+ dtype=config.params_dtype,
+ )
+ )
+
+ if config.perform_initialization:
+ # Always initialize bias to zero.
+ with torch.no_grad():
+ self.bias.zero_()
+ setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel))
+ setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
+ else:
+ self.register_parameter('bias', None)
+
+ self._forward_impl = unaligned_linear_with_grad_accumulation_and_async_allreduce
+
+ # Hook adding a default empty _extra_state for state dict
+ self._register_load_state_dict_pre_hook(
+ lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
+ f'{prefix}_extra_state'
+ )
+ )
+
+ self.seq_length = seq_length
+ self.scatter_to_tensor_model_parallel_region = scatter_to_tensor_model_parallel_region
+ self.linear_with_frozen_weight = linear_with_frozen_weight
+ self.reduce_from_tensor_model_parallel_region = reduce_from_tensor_model_parallel_region
+ self.parallel_group = parallel_group
+
+ def forward(self, input_):
+ if self.config._cpu_offloading_context is not None:
+ if self.config._cpu_offloading_context.inside_context == True:
+ assert (
+ self.config.cpu_offloading == False
+ ), "CPU Offloading cannot be enabled while using non-TE modules"
+
+ # Set up backprop all-reduce.
+ if self.input_is_parallel:
+ input_parallel = input_
+ else:
+ assert not self.sequence_parallel
+ input_parallel = self.scatter_to_tensor_model_parallel_region(input_)
+ # Matrix multiply.
+ allreduce_dgrad = False
+ if not self.weight.requires_grad:
+ self._forward_impl = self.linear_with_frozen_weight
+ output_parallel = self._forward_impl(
+ input=input_parallel,
+ weight=self.weight,
+ bias=None,
+ gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+ async_grad_allreduce=allreduce_dgrad,
+ sequence_parallel=False,
+ grad_output_buffer=None,
+ allreduce_dgrad=allreduce_dgrad
+ )
+ else:
+ self._forward_impl = unaligned_linear_with_grad_accumulation_and_async_allreduce
+ output_parallel = self._forward_impl(
+ input=input_parallel,
+ weight=self.weight,
+ bias=None,
+ gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+ sequence_parallel=False,
+ grad_output_buffer=None,
+ allreduce_dgrad=allreduce_dgrad,
+ parallel_group=self.parallel_group,
+ seq_length=self.seq_length
+ )
+
+ # All-reduce across all the partitions.
+ if self.explicit_expert_comm:
+ assert self.skip_bias_add
+ output_ = output_parallel
+ elif self.sequence_parallel:
+ output_ = unaligned_reduce_scatter_to_sequence_parallel_region(output_parallel, self.parallel_group)
+ else:
+ output_ = self.reduce_from_tensor_model_parallel_region(output_parallel)
+ if not self.skip_bias_add:
+ output = (output_ + self.bias) if self.bias is not None else output_
+ output_bias = None
+ else:
+ output = output_
+ output_bias = self.bias
+ return output, output_bias
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/unaligned_utils.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/unaligned_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..98217089def4e181ba78762ac4d086d328daf311
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel/unaligned_layers/unaligned_utils.py
@@ -0,0 +1,304 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+import torch
+from mindspeed.ops.npu_matmul_add import npu_matmul_add_fp32, npu_matmul_add_fp16
+
+
+def unaligned_divide(numerator, world_size, rank):
+ res = numerator // world_size
+ if rank < numerator % world_size:
+ res += 1
+ return res
+
+
+def unaligned_split_along_first_dim(input_, group):
+ world_size = torch.distributed.get_world_size(group=group)
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return input_
+ rank = torch.distributed.get_rank(group=group)
+
+ # Split along first dimension.
+ dim_size = input_.size()[0]
+
+ local_dim_size = unaligned_divide(dim_size, world_size, rank)
+
+ less_dim_size = dim_size // world_size
+ dim_offset = rank * less_dim_size
+ if rank >= dim_size % world_size:
+ dim_offset += dim_size % world_size
+ else:
+ dim_offset += rank
+
+ output = input_[dim_offset: dim_offset + local_dim_size].contiguous()
+
+ return output
+
+
+def unaligned_gather_along_first_dim(input_, dim_size, group, async_op=False):
+ """Gather tensors and concatinate along the first dimension."""
+
+ world_size = torch.distributed.get_world_size(group=group)
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return input_
+
+ output = []
+ for rank in range(world_size):
+ rank_dim_size = dim_size // world_size
+ if rank < dim_size % world_size:
+ rank_dim_size += 1
+ output.append(torch.empty((int(rank_dim_size), *(input_.size()[1:])), dtype=input_.dtype,
+ device=torch.cuda.current_device()))
+
+ handle = torch.distributed.all_gather(output, input_.contiguous(), group=group, async_op=async_op)
+
+ def post_process():
+ if handle is not None:
+ handle.wait()
+ return torch.cat(output)
+
+ if async_op:
+ return post_process
+ return post_process()
+
+
+class UnalignedScatterToSequenceParallelRegion(torch.autograd.Function):
+ """Split the input and keep only the corresponding chuck to the rank."""
+
+ @staticmethod
+ def forward(ctx, input_, group):
+ ctx.dim_size = list(input_.size())[0]
+ ctx.parallel_group = group
+ return unaligned_split_along_first_dim(input_, group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return unaligned_gather_along_first_dim(grad_output, ctx.dim_size, ctx.parallel_group), None
+
+
+def unaligned_scatter_to_sequence_parallel_region(input_, group):
+ return UnalignedScatterToSequenceParallelRegion.apply(input_, group)
+
+
+def unaligned_reduce_scatter_along_first_dim(input_, group, async_op=False):
+ """Reduce-scatter the input tensor across model parallel group."""
+ world_size = torch.distributed.get_world_size(group=group)
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return input_
+ rank = torch.distributed.get_rank(group=group)
+
+ # Split along first dimension.
+ dim_size = input_.size()[0]
+
+ local_dim_size = unaligned_divide(dim_size, world_size, rank)
+
+ less_dim_size = dim_size // world_size
+ dim_offset = rank * less_dim_size
+ if rank >= dim_size % world_size:
+ dim_offset += dim_size % world_size
+ else:
+ dim_offset += rank
+
+ input_ = input_.contiguous()
+ handle = torch.distributed.all_reduce(input_, group=group, async_op=async_op)
+
+ def post_process():
+ if handle is not None:
+ handle.wait()
+ return input_[dim_offset: dim_offset + local_dim_size].contiguous()
+
+ if async_op:
+ return post_process
+ return post_process()
+
+
+class UnalignedReduceScatterToSequenceParallelRegion(torch.autograd.Function):
+ """Reduce scatter the input from the model parallel region."""
+
+ @staticmethod
+ def forward(ctx, input_, group):
+ ctx.dim_size = list(input_.size())[0]
+ ctx.parallel_group = group
+ return unaligned_reduce_scatter_along_first_dim(input_, group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return unaligned_gather_along_first_dim(grad_output, ctx.dim_size, ctx.parallel_group), None
+
+
+def unaligned_reduce_scatter_to_sequence_parallel_region(input_, group):
+ return UnalignedReduceScatterToSequenceParallelRegion.apply(input_, group)
+
+
+class UnalignedGatherFromSequenceParallelRegion(torch.autograd.Function):
+ """Gather the input from the sequence parallel region."""
+
+ @staticmethod
+ def forward(ctx, input_, dim_size, group, tensor_parallel_output_grad):
+ ctx.dim_size = dim_size
+ ctx.parallel_group = group
+ ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
+ return unaligned_gather_along_first_dim(input_, dim_size, group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if ctx.tensor_parallel_output_grad:
+ return (
+ unaligned_reduce_scatter_to_sequence_parallel_region(grad_output),
+ None,
+ None,
+ None
+ )
+
+ else:
+ return (
+ unaligned_split_along_first_dim(grad_output),
+ None,
+ None,
+ None
+ )
+
+
+class UnalignedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
+ """See linear_with_grad_accumulation_and_async_allreduce"""
+
+ @staticmethod
+ def forward(
+ ctx,
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ allreduce_dgrad,
+ sequence_parallel,
+ grad_output_buffer,
+
+ # unaligned parallel arguments
+ parallel_group,
+ seq_length=None
+ ):
+ ctx.save_for_backward(input, weight)
+ ctx.use_bias = bias is not None
+ ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+ ctx.allreduce_dgrad = allreduce_dgrad
+ ctx.sequence_parallel = sequence_parallel
+ ctx.grad_output_buffer = grad_output_buffer
+ ctx.parallel_group = parallel_group
+
+ if sequence_parallel:
+ if seq_length is None:
+ seq_len = torch.Tensor([list(input.size())[0]]).cuda()
+ torch.distributed.all_reduce(seq_len, group=parallel_group)
+ seq_length = seq_len.item()
+ total_input = unaligned_gather_along_first_dim(input, seq_length, parallel_group)
+ else:
+ total_input = input
+
+ output = torch.matmul(total_input, weight.t())
+ if bias is not None:
+ output = output + bias
+
+ ctx.seq_length = seq_length
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ use_bias = ctx.use_bias
+ grad_output_buffer = ctx.grad_output_buffer
+ parallel_group = ctx.parallel_group
+
+ wgrad_compute = True
+ post_process = None
+ total_input = None
+ if grad_output_buffer is not None:
+ grad_output_buffer.append(grad_output)
+ wgrad_compute = False
+
+ if wgrad_compute:
+ if ctx.sequence_parallel:
+ post_process = unaligned_gather_along_first_dim(input, ctx.seq_length, parallel_group, async_op=True)
+ else:
+ total_input = input
+ grad_input = grad_output.matmul(weight)
+
+ if ctx.sequence_parallel and wgrad_compute:
+ total_input = post_process()
+
+ if wgrad_compute and grad_output.dim() == 3:
+ grad_output = grad_output.contiguous()
+ grad_output = grad_output.view(-1, grad_output.shape[2])
+ total_input = total_input.view(-1, total_input.shape[2])
+
+ if ctx.allreduce_dgrad:
+ # Asynchronous all-reduce
+ handle = torch.distributed.all_reduce(grad_input, group=parallel_group, async_op=True)
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # all-reduce is scheduled before the weight gradient computation
+
+ if ctx.sequence_parallel:
+ assert not ctx.allreduce_dgrad
+ post_process = unaligned_reduce_scatter_along_first_dim(grad_input, parallel_group, async_op=True)
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # reduce scatter is scheduled before the weight gradient computation
+
+ if ctx.gradient_accumulation_fusion:
+ if wgrad_compute:
+ if weight.main_grad.dtype == torch.float32:
+ npu_matmul_add_fp32(total_input, grad_output, weight.main_grad)
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ npu_matmul_add_fp16(total_input, grad_output, weight.main_grad)
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ grad_weight = grad_output.t().matmul(total_input)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+ if ctx.sequence_parallel:
+ sub_grad_input = post_process()
+ # Need to return None's as gradient has to flow for all the input arguments
+ # provided during forward
+ return sub_grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+
+ if ctx.allreduce_dgrad:
+ handle.wait()
+
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+
+
+def unaligned_gather_from_sequence_parallel_region(input_, dim_size, group, tensor_parallel_output_grad):
+ return UnalignedGatherFromSequenceParallelRegion.apply(input_, dim_size, group, tensor_parallel_output_grad)
+
+
+def unaligned_linear_with_grad_accumulation_and_async_allreduce(
+ input, weight, bias, gradient_accumulation_fusion, allreduce_dgrad, sequence_parallel, grad_output_buffer,
+ parallel_group, seq_length=None
+):
+ args = [input, weight, bias, gradient_accumulation_fusion, allreduce_dgrad, sequence_parallel, grad_output_buffer,
+ parallel_group, seq_length]
+ return UnalignedLinearWithGradAccumulationAndAsyncCommunication.apply(*args)
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel_x_union_cp.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel_x_union_cp.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a914f02dd379a7f82739d63cd83761487746f89
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel_x_union_cp.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+
+import torch
+import torch.distributed
+
+from megatron.core.parallel_state import get_nccl_options
+
+from mindspeed.core.mindspeed_parallel_group import MindspeedParallelGroup
+from mindspeed.core.simple_parallel_cfg import SimpleParallelCfg
+from mindspeed.core.singleton_meta import SingletonMeta
+
+
+class TensorParallelXUnionCP(MindspeedParallelGroup, metaclass=SingletonMeta):
+ def __init__(
+ self,
+ parallel_cfg: SimpleParallelCfg = None,
+ pg_name: str = None,
+ overlap_gp_name: str = None,
+ nccl_comm_cfgs=None,
+ ):
+ super().__init__(parallel_cfg, pg_name, overlap_gp_name, nccl_comm_cfgs)
+
+ @staticmethod
+ def init_group(
+ parallel_cfg: SimpleParallelCfg,
+ pg_name: str,
+ overlap_gp_name: str = None,
+ nccl_comm_cfgs=None,
+ ):
+ pp = parallel_cfg.pp
+ tp = parallel_cfg.tp
+ cp = parallel_cfg.cp
+ tp_x = parallel_cfg.tp_x
+
+ rank = torch.distributed.get_rank()
+ world_size: int = torch.distributed.get_world_size()
+ num_pp_groups: int = world_size // pp
+ dp = world_size // (tp * pp * cp)
+
+ all_cp_grps = []
+ for i in range(pp):
+ for j in range(dp):
+ start_rank = i * num_pp_groups + j * tp * cp
+ end_rank = i * num_pp_groups + (j + 1) * tp * cp
+ for k in range(tp):
+ ranks = range(start_rank + k, end_rank, tp)
+ all_cp_grps.append(list(ranks))
+
+ all_tp_x_grps = []
+ num_tp_grps: int = world_size // tp
+ for i in range(num_tp_grps):
+ for j in range(tp // tp_x):
+ ranks = range(i * tp + j * tp_x, i * tp + (j + 1) * tp_x)
+ all_tp_x_grps.append(list(ranks))
+
+ # Build the tensor model-parallel-x-cp groups.
+ res_group, res_overlap_group, res_global_ranks = None, None, None
+ tp_x_cp_grp_ranks = []
+ tp_x_cp_rank_in_grp = []
+ for cp_grp in all_cp_grps:
+ for cp_rank in cp_grp:
+ for tp_x_grp in all_tp_x_grps:
+ if cp_rank in tp_x_grp and tp_x_grp not in tp_x_cp_rank_in_grp:
+ tp_x_cp_rank_in_grp += tp_x_grp
+
+ tp_x_cp_rank_in_grp = sorted(tp_x_cp_rank_in_grp)
+ if tp_x_cp_rank_in_grp not in tp_x_cp_grp_ranks:
+ tp_x_cp_grp_ranks.append(tp_x_cp_rank_in_grp)
+ tp_x_cp_rank_in_grp = []
+
+ cur_overlap_group = None
+ for tp_x_cp_ranks in tp_x_cp_grp_ranks:
+ cur_group = torch.distributed.new_group(
+ tp_x_cp_ranks, pg_options=get_nccl_options(pg_name, nccl_comm_cfgs)
+ )
+ # for send-recv in parallel
+ if overlap_gp_name:
+ cur_overlap_group = torch.distributed.new_group(
+ tp_x_cp_ranks, pg_options=get_nccl_options(overlap_gp_name, nccl_comm_cfgs),
+ )
+
+ if rank in tp_x_cp_ranks:
+ res_global_ranks = tp_x_cp_ranks
+ res_group = cur_group
+ res_overlap_group = cur_overlap_group
+
+ return res_group, res_global_ranks, res_overlap_group
diff --git a/model/train/yoco_moe/mindspeed/core/tensor_parallel_y_union_cp.py b/model/train/yoco_moe/mindspeed/core/tensor_parallel_y_union_cp.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9796f13b2e761ee96b16b92a919057ae601e8ce
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/tensor_parallel_y_union_cp.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+
+import torch
+import torch.distributed
+
+from megatron.core.parallel_state import get_nccl_options
+
+from mindspeed.core.mindspeed_parallel_group import MindspeedParallelGroup
+from mindspeed.core.simple_parallel_cfg import SimpleParallelCfg
+from mindspeed.core.singleton_meta import SingletonMeta
+
+
+class TensorParallelYUnionCP(MindspeedParallelGroup, metaclass=SingletonMeta):
+ def __init__(
+ self,
+ parallel_cfg: SimpleParallelCfg = None,
+ pg_name: str = None,
+ overlap_gp_name: str = None,
+ nccl_comm_cfgs=None,
+ ):
+ super().__init__(parallel_cfg, pg_name, overlap_gp_name, nccl_comm_cfgs)
+
+ @staticmethod
+ def init_group(
+ parallel_cfg: SimpleParallelCfg,
+ pg_name: str,
+ overlap_gp_name: str = None,
+ nccl_comm_cfgs=None,
+ ):
+ pp = parallel_cfg.pp
+ tp = parallel_cfg.tp
+ cp = parallel_cfg.cp
+ tp_x = parallel_cfg.tp_x
+
+ rank = torch.distributed.get_rank()
+ world_size: int = torch.distributed.get_world_size()
+ num_pp_groups: int = world_size // pp
+ dp = world_size // (tp * pp * cp)
+
+ all_cp_grps = []
+ for i in range(pp):
+ for j in range(dp):
+ start_rank = i * num_pp_groups + j * tp * cp
+ end_rank = i * num_pp_groups + (j + 1) * tp * cp
+ for k in range(tp):
+ ranks = range(start_rank + k, end_rank, tp)
+ all_cp_grps.append(list(ranks))
+
+ all_tp_y_grps = []
+ num_tp_grps: int = world_size // tp
+ for i in range(num_tp_grps):
+ num_tp_x_group = tp // tp_x
+ for j in range(tp // num_tp_x_group):
+ ranks = range(i * tp + j, (i + 1) * tp, tp_x)
+ all_tp_y_grps.append(list(ranks))
+
+ # Build the tensor model-parallel-y-cp groups.
+ res_group, res_overlap_group, res_global_ranks = None, None, None
+ tp_y_cp_grp_ranks = []
+ tp_y_cp_rank_in_grp = []
+ for cp_grp in all_cp_grps:
+ for cp_rank in cp_grp:
+ for tp_y_grp in all_tp_y_grps:
+ if cp_rank in tp_y_grp and tp_y_grp not in tp_y_cp_rank_in_grp:
+ tp_y_cp_rank_in_grp += tp_y_grp
+
+ tp_y_cp_rank_in_grp = sorted(tp_y_cp_rank_in_grp)
+ if tp_y_cp_rank_in_grp not in tp_y_cp_grp_ranks:
+ tp_y_cp_grp_ranks.append(tp_y_cp_rank_in_grp)
+ tp_y_cp_rank_in_grp = []
+
+ cur_overlap_group = None
+ for tp_y_cp_ranks in tp_y_cp_grp_ranks:
+ cur_group = torch.distributed.new_group(
+ tp_y_cp_ranks, pg_options=get_nccl_options(pg_name, nccl_comm_cfgs)
+ )
+ # for send-recv in parallel
+ if overlap_gp_name:
+ cur_overlap_group = torch.distributed.new_group(
+ tp_y_cp_ranks, pg_options=get_nccl_options(overlap_gp_name, nccl_comm_cfgs),
+ )
+
+ if rank in tp_y_cp_ranks:
+ res_global_ranks = tp_y_cp_ranks
+ res_group = cur_group
+ res_overlap_group = cur_overlap_group
+
+ return res_group, res_global_ranks, res_overlap_group
diff --git a/model/train/yoco_moe/mindspeed/core/training.py b/model/train/yoco_moe/mindspeed/core/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..53dc5ef1765c8cdbcb50a7c9ced8f74f9541c1cd
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/training.py
@@ -0,0 +1,544 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+"""Pretrain utilities."""
+
+import os
+import sys
+import gc
+import os
+from functools import wraps
+import torch
+import torch_npu
+from datetime import datetime
+from megatron.training import get_args
+from megatron.training import get_timers
+from megatron.training import is_last_rank
+from megatron.core import parallel_state
+from megatron.core.num_microbatches_calculator import get_num_microbatches
+from megatron.core.transformer.moe.moe_utils import track_moe_metrics
+from megatron.training import print_rank_0
+from megatron.training.arguments import parse_args
+from megatron.training.global_vars import (set_args, get_tensorboard_writer, get_wandb_writer,
+ get_one_logger)
+from megatron.training.training import num_floating_point_operations
+from megatron.training.utils import print_rank_last, report_memory
+from megatron.training.theoretical_memory_usage import report_theoretical_memory
+from mindspeed.core.auto_parallel.auto_parallel_apply import search_optimal_configuration
+from mindspeed.core.auto_parallel.auto_parallel_profiling import Profiling, OperateProfile
+from mindspeed.core.memory.auto_pipeline.autopipeline import autopipeline_profiling
+from mindspeed.core.performance.auto_pipeline_perf.autopipeline_perf import (autopipelineperf_profiling, check_out_of_memory,
+ calculate_num_of_activations, check_skip_profiling,
+ broadcast_skip_in_ranks)
+from mindspeed.core.performance.auto_pipeline_perf.optimpipeline_solver import solve_optimpipeline, broadcast_oom_in_ranks, broadcast_mbs_in_ranks, save_profiling_data
+from mindspeed.core.performance.auto_pipeline_perf.schedulepipeline_solver import (solve_pipelineschedule, broadcast_enable_schedule_in_ranks,
+ broadcast_scheduler_in_ranks, broadcast_layer_in_ranks,
+ all_gather_time, average_time_by_rank)
+from mindspeed.core.memory.auto_pipeline.autopipeline_apply import apply_autopipeline
+from mindspeed.core.memory.auto_pipeline.autopipeline_solver import solve_autopipeline, broadcast_policy_in_ranks, destroy_global_vars
+from mindspeed.arguments import parse_args_wrapper
+
+
+POLICY = None
+OPTIMIZED_MBS_LIST = None
+PP_SCHEDULE_LIST = None
+OPTIMAL_LAYERS = None
+ORIGIN_MBS = None
+DATA_PARALLEL_SIZE = 1
+ENABLE_SCHEDULER = False
+FLOPS_COUNTER = None
+RECORDED_COUNT = 0
+TRAVERSED_COUNT = 0
+
+
+def generated_flops_counter():
+ from torch_npu.utils.flops_count import FlopsCounter
+ global FLOPS_COUNTER
+ FLOPS_COUNTER = FlopsCounter()
+
+
+def get_flops_counter():
+ global FLOPS_COUNTER
+ if FLOPS_COUNTER is None:
+ generated_flops_counter()
+ return FLOPS_COUNTER
+
+
+def set_count(count):
+ global RECORDED_COUNT
+ global TRAVERSED_COUNT
+ RECORDED_COUNT = count[0]
+ TRAVERSED_COUNT = count[1]
+
+
+def get_count():
+ global RECORDED_COUNT
+ global TRAVERSED_COUNT
+ if RECORDED_COUNT == 0 and TRAVERSED_COUNT == 0:
+ flops_counter = get_flops_counter()
+ count = flops_counter.get_flops()
+ set_count(count)
+ return RECORDED_COUNT, TRAVERSED_COUNT
+
+
+def train_decorator(train):
+ @wraps(train)
+ def wrapper(*args, **kwargs):
+ args_ = get_args()
+ if args_.profile:
+ args_.profile_npu = True
+ args_.profile = False
+ else:
+ args_.profile_npu = False
+
+ is_profile = hasattr(args_, 'profile_npu') and args_.profile_npu \
+ and ((torch.distributed.get_rank() in args_.profile_ranks) or (-1 in args_.profile_ranks))
+ if is_profile:
+ active = args_.profile_step_end - args_.profile_step_start
+ skip_first = args_.profile_step_start
+
+ if args_.profile_with_cpu:
+ activities = [torch_npu.profiler.ProfilerActivity.NPU, torch_npu.profiler.ProfilerActivity.CPU]
+ else:
+ activities = [torch_npu.profiler.ProfilerActivity.NPU]
+
+ if args_.profile_level == 'level0':
+ profiler_level = torch_npu.profiler.ProfilerLevel.Level0
+ elif args_.profile_level == 'level1':
+ profiler_level = torch_npu.profiler.ProfilerLevel.Level1
+ elif args_.profile_level == 'level2':
+ profiler_level = torch_npu.profiler.ProfilerLevel.Level2
+ else:
+ raise ValueError(f"profiler_level only support level0, level1, level2, but gets {args_.profile_level}")
+
+ experimental_config = torch_npu.profiler._ExperimentalConfig(
+ aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
+ profiler_level=profiler_level,
+ l2_cache=False
+ )
+
+ with torch_npu.profiler.profile(
+ activities=activities,
+ record_shapes=args_.profile_record_shapes,
+ profile_memory=args_.profile_with_memory,
+ with_stack=args_.profile_with_stack,
+ experimental_config=experimental_config,
+ schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=active, repeat=1, skip_first=skip_first),
+ on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(args_.profile_save_path)
+ ) as prof:
+ args_.prof = prof
+ return train(*args, **kwargs)
+ else:
+ return train(*args, **kwargs)
+
+ return wrapper
+
+
+def train_step_decorator(train_step):
+ @wraps(train_step)
+ def wrapper(*args, **kwargs):
+ nonlocal train_step
+ args_ = get_args()
+ flop_count = None
+ if args_.op_cal_tflops:
+ flop_count = get_flops_counter()
+ flop_count.start()
+ if args_.profile_operator:
+ op_profile = OperateProfile(args_)
+ ret = train_step(*args, **kwargs)
+ op_profile.step()
+ elif args_.prof_file:
+ profiling = Profiling(args_)
+ train_step = profiling.hook_train_step(train_step)
+ ret = train_step(*args, **kwargs)
+ else:
+ ret = train_step(*args, **kwargs)
+ is_profile = args_.profile_npu and ((torch.distributed.get_rank() in args_.profile_ranks) or (-1 in args_.profile_ranks))
+ if is_profile:
+ args_.prof.step()
+ if args_.op_cal_tflops:
+ counts = flop_count.get_flops()
+ set_count(counts)
+ flop_count.stop()
+ return ret
+ return wrapper
+
+
+def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration,
+ loss_scale, report_memory_flag, skipped_iter,
+ grad_norm, params_norm, num_zeros_in_grad):
+ """Log training information such as losses, timing, ...."""
+ args = get_args()
+ timers = get_timers()
+ writer = get_tensorboard_writer()
+ wandb_writer = get_wandb_writer()
+ one_logger = get_one_logger()
+
+ # Advanced, skipped, and Nan iterations.
+ advanced_iters_key = 'advanced iterations'
+ skipped_iters_key = 'skipped iterations'
+ nan_iters_key = 'nan iterations'
+ # Advanced iterations.
+ if not skipped_iter:
+ total_loss_dict[advanced_iters_key] = total_loss_dict.get(
+ advanced_iters_key, 0) + 1
+ else:
+ if advanced_iters_key not in total_loss_dict:
+ total_loss_dict[advanced_iters_key] = 0
+ # Skipped iterations.
+ total_loss_dict[skipped_iters_key] = total_loss_dict.get(
+ skipped_iters_key, 0) + skipped_iter
+ # Update losses and set nan iterations
+ got_nan = False
+ for key in loss_dict:
+ if not skipped_iter:
+ total_loss_dict[key] = total_loss_dict.get(
+ key, torch.tensor([0.0], dtype=torch.float, device='cuda')) + loss_dict[key]
+ else:
+ value = loss_dict[key].float().sum().item()
+ is_nan = value == float('inf') or \
+ value == -float('inf') or \
+ value != value
+ got_nan = got_nan or is_nan
+ total_loss_dict[nan_iters_key] = total_loss_dict.get(
+ nan_iters_key, 0) + int(got_nan)
+
+ # Logging.
+ timers_to_log = [
+ 'forward-backward',
+ 'forward-compute',
+ 'backward-compute',
+ 'batch-generator',
+ 'forward-recv',
+ 'forward-send',
+ 'backward-recv',
+ 'backward-send',
+ 'forward-send-forward-recv',
+ 'forward-send-backward-recv',
+ 'backward-send-forward-recv',
+ 'backward-send-backward-recv',
+ 'forward-backward-send-forward-backward-recv',
+ 'layernorm-grads-all-reduce',
+ 'embedding-grads-all-reduce',
+ 'all-grads-sync',
+ 'params-all-gather',
+ 'optimizer-copy-to-main-grad',
+ 'optimizer-unscale-and-check-inf',
+ 'optimizer-clip-main-grad',
+ 'optimizer-count-zeros',
+ 'optimizer-inner-step',
+ 'optimizer-copy-main-to-model-params',
+ 'optimizer']
+
+ # Calculate batch size.
+ batch_size = args.micro_batch_size * args.data_parallel_size * \
+ get_num_microbatches()
+
+ # Track app tag & app tag ID
+ if one_logger:
+ job_name = os.environ.get('SLURM_JOB_NAME', None)
+ current_app_tag = f'{job_name}_{batch_size}_{args.world_size}'
+ one_logger.log_app_tag(current_app_tag)
+
+ total_iterations = total_loss_dict[advanced_iters_key] + \
+ total_loss_dict[skipped_iters_key]
+
+ # Tensorboard values.
+ # Timer requires all the ranks to call.
+ if args.log_timers_to_tensorboard and \
+ (iteration % args.tensorboard_log_interval == 0):
+ timers.write(timers_to_log, writer, iteration,
+ normalizer=total_iterations)
+ if writer and (iteration % args.tensorboard_log_interval == 0):
+ if wandb_writer:
+ wandb_writer.log({'samples vs steps': args.consumed_train_samples},
+ iteration)
+ if args.log_learning_rate_to_tensorboard:
+ writer.add_scalar('learning-rate', learning_rate, iteration)
+ if args.decoupled_lr is not None:
+ writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
+ writer.add_scalar('learning-rate vs samples', learning_rate,
+ args.consumed_train_samples)
+ if wandb_writer:
+ wandb_writer.log({'learning-rate': learning_rate}, iteration)
+ if args.log_batch_size_to_tensorboard:
+ writer.add_scalar('batch-size', batch_size, iteration)
+ writer.add_scalar('batch-size vs samples', batch_size,
+ args.consumed_train_samples)
+ if wandb_writer:
+ wandb_writer.log({'batch-size': batch_size}, iteration)
+ for key in loss_dict:
+ writer.add_scalar(key , loss_dict[key], iteration)
+ writer.add_scalar(key + ' vs samples', loss_dict[key],
+ args.consumed_train_samples)
+ if wandb_writer:
+ wandb_writer.log({key: loss_dict[key]}, iteration)
+ if args.log_loss_scale_to_tensorboard:
+ writer.add_scalar('loss-scale', loss_scale, iteration)
+ writer.add_scalar('loss-scale vs samples', loss_scale,
+ args.consumed_train_samples)
+ if wandb_writer:
+ wandb_writer.log({'loss-scale': loss_scale}, iteration)
+ if args.log_world_size_to_tensorboard:
+ writer.add_scalar('world-size', args.world_size, iteration)
+ writer.add_scalar('world-size vs samples', args.world_size,
+ args.consumed_train_samples)
+ if wandb_writer:
+ wandb_writer.log({'world-size': args.world_size}, iteration)
+ if grad_norm is not None:
+ writer.add_scalar('grad-norm', grad_norm, iteration)
+ writer.add_scalar('grad-norm vs samples', grad_norm,
+ args.consumed_train_samples)
+ if wandb_writer:
+ wandb_writer.log({'grad-norm': grad_norm}, iteration)
+ if num_zeros_in_grad is not None:
+ writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
+ writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
+ args.consumed_train_samples)
+ if wandb_writer:
+ wandb_writer.log({'num-zeros': num_zeros_in_grad}, iteration)
+ if params_norm is not None:
+ writer.add_scalar('params-norm', params_norm, iteration)
+ writer.add_scalar('params-norm vs samples', params_norm,
+ args.consumed_train_samples)
+ if wandb_writer:
+ wandb_writer.log({'params-norm': params_norm}, iteration)
+ if args.log_memory_to_tensorboard:
+ mem_stats = torch.cuda.memory_stats()
+ writer.add_scalar(
+ "mem-reserved-bytes",
+ mem_stats["reserved_bytes.all.current"],
+ iteration,
+ )
+ writer.add_scalar(
+ "mem-allocated-bytes",
+ mem_stats["allocated_bytes.all.current"],
+ iteration,
+ )
+ writer.add_scalar(
+ "mem-allocated-count",
+ mem_stats["allocation.all.current"],
+ iteration,
+ )
+ if args.num_experts is not None:
+ moe_loss_scale = 1 / get_num_microbatches()
+ track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging)
+
+ if iteration % args.log_interval == 0:
+ elapsed_time = timers('interval-time').elapsed(barrier=True)
+ elapsed_time_per_iteration = elapsed_time / total_iterations
+
+ throughput = num_floating_point_operations(args, batch_size) / (
+ elapsed_time_per_iteration * 10**12 * args.world_size)
+
+ # select all nodes info
+ counts_0, counts_1 = get_count()
+ counts_0_tensor = torch.tensor([counts_0], device="npu")
+ counts_1_tensor = torch.tensor([counts_1], device="npu")
+
+ torch.distributed.all_reduce(
+ counts_0_tensor, op=torch.distributed.ReduceOp.SUM
+ )
+ torch.distributed.all_reduce(
+ counts_1_tensor, op=torch.distributed.ReduceOp.SUM
+ )
+
+ mfu = counts_0_tensor.cpu().item() / (10 ** 12 * elapsed_time_per_iteration * args.world_size)
+ hfu = counts_1_tensor.cpu().item() / (10 ** 12 * elapsed_time_per_iteration * args.world_size)
+
+ if args.log_timers_to_tensorboard:
+ if writer:
+ writer.add_scalar('iteration-time',
+ elapsed_time_per_iteration, iteration)
+ if wandb_writer:
+ wandb_writer.log({'iteration-time': elapsed_time_per_iteration},
+ iteration)
+ log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
+ log_string += ' iteration {:8d}/{:8d} |'.format(
+ iteration, args.train_iters)
+ log_string += ' consumed samples: {:12d} |'.format(
+ args.consumed_train_samples)
+ log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
+ elapsed_time_per_iteration * 1000.0)
+ if args.log_throughput:
+ log_string += f' theoretical throughput per NPU (TFLOP/s/NPU): {throughput:.1f} |'
+ log_string += f' actual throughput per NPU (TFLOP/s/NPU): {mfu:.1f} |'
+ log_string += f' actual throughput per NPU with recompute (TFLOP/s/NPU): {hfu:.1f} |'
+ if args.log_timers_to_tensorboard:
+ if writer:
+ writer.add_scalar('throughput', throughput, iteration)
+ if wandb_writer:
+ wandb_writer.log({'throughput': throughput}, iteration)
+ assert learning_rate is not None
+ # Decoupled_learning_rate should be not None only on first and last pipeline stage.
+ log_string += ' learning rate: {:.6E} |'.format(learning_rate)
+ if args.decoupled_lr is not None and (parallel_state.is_pipeline_first_stage(ignore_virtual=True) or
+ parallel_state.is_pipeline_last_stage(ignore_virtual=True)):
+ assert decoupled_learning_rate is not None
+ log_string += ' decoupled learning rate: {:.6E} |'.format(decoupled_learning_rate)
+ else:
+ assert decoupled_learning_rate is None
+ log_string += ' global batch size: {:5d} |'.format(batch_size)
+ for key in total_loss_dict:
+ if key not in [advanced_iters_key, skipped_iters_key,
+ nan_iters_key]:
+ avg = total_loss_dict[key].item() / \
+ float(max(1, total_loss_dict[advanced_iters_key]))
+ if avg > 0.0:
+ log_string += ' {}: {:.6E} |'.format(key, avg)
+ total_loss_dict[key] = torch.tensor([0.0], dtype=torch.float, device='cuda')
+ log_string += ' loss scale: {:.1f} |'.format(loss_scale)
+ if grad_norm is not None:
+ log_string += ' grad norm: {:.3f} |'.format(grad_norm)
+ if num_zeros_in_grad is not None:
+ log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
+ if params_norm is not None:
+ log_string += ' params norm: {:.3f} |'.format(params_norm)
+ log_string += ' number of skipped iterations: {:3d} |'.format(
+ total_loss_dict[skipped_iters_key])
+ log_string += ' number of nan iterations: {:3d} |'.format(
+ total_loss_dict[nan_iters_key])
+ total_loss_dict[advanced_iters_key] = 0
+ total_loss_dict[skipped_iters_key] = 0
+ total_loss_dict[nan_iters_key] = 0
+ print_rank_last(log_string)
+ if report_memory_flag and learning_rate > 0.:
+ # Report memory after optimizer state has been initialized.
+ if torch.distributed.get_rank() == 0:
+ num_microbatches = get_num_microbatches()
+ report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
+ report_memory('(after {} iterations)'.format(iteration))
+ report_memory_flag = False
+ timers.log(timers_to_log, normalizer=args.log_interval)
+
+ return report_memory_flag
+
+
+def pretrain_decorator(pretrain):
+ @wraps(pretrain)
+ def wrapper(*args, **kwargs):
+ global POLICY
+ global OPTIMIZED_MBS_LIST
+ global PP_SCHEDULE_LIST
+ global OPTIMAL_LAYERS
+ global ORIGIN_MBS
+ global DATA_PARALLEL_SIZE
+ global ENABLE_SCHEDULER
+ new_parse_args = parse_args_wrapper(parse_args)
+ argument = new_parse_args(kwargs.get('extra_args_provider'), False)
+
+ if argument.auto_tuning:
+ set_args(argument)
+ print("pretrain_decorator set_args ========================================")
+
+ from mindspeed.auto_tuning.auto_tuning import auto_tuning
+ global_args = get_args()
+ assert global_args.auto_tuning_ranks >= 16, "Auto-tuning searching space should be >= 16."
+ working_dir_root = os.path.realpath(global_args.auto_tuning_work_dir)
+ if not os.path.exists(working_dir_root) and global_args.rank % torch.cuda.device_count() == 0:
+ os.makedirs(working_dir_root)
+
+ if global_args.rank % torch.cuda.device_count() == 0:
+ print("only rank 0 run auto tuning ========================================")
+ auto_tuning(global_args, working_dir=working_dir_root)
+ return
+
+ if argument.auto_parallel:
+ set_args(argument)
+ search_optimal_configuration(argument)
+ return
+
+ if argument.automated_pipeline and not argument.num_layer_list:
+ context, POLICY = autopipeline_profiling(args[1], args[2], args[3],
+ args[0], None, argument)
+ if context:
+ POLICY = solve_autopipeline(context)
+ parallel_state.destroy_global_memory_buffer()
+ parallel_state.destroy_model_parallel()
+ destroy_global_vars()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if argument.automated_pipeline_perf:
+ ORIGIN_MBS = argument.micro_batch_size
+ is_skip, exist_policy = check_skip_profiling(argument, config_file="autopipeline_perf_config.json")
+ if not is_skip:
+ global_context = []
+ mbs_time, pp_schedule_time = 0, 0
+ mbs_tries = 1
+ num_forwards_first_stage = 0
+ is_oom = False
+ forward_time_dict = {}
+ backward_time_dict = {}
+
+ while mbs_tries < ORIGIN_MBS + 2:
+ context = autopipelineperf_profiling(mbs_tries, args[1], args[2], args[3],
+ args[0], None)
+ if mbs_tries == ORIGIN_MBS:
+ schedule_context = context
+ forward_time_list = all_gather_time(argument, schedule_context['fwd_time'])
+ forward_time_dict = average_time_by_rank(forward_time_list)
+ backward_time_list = all_gather_time(argument, schedule_context['bwd_time'])
+ backward_time_dict = average_time_by_rank(backward_time_list)
+ num_forwards_first_stage = calculate_num_of_activations(schedule_context)
+
+ parallel_state.destroy_global_memory_buffer()
+ parallel_state.destroy_model_parallel()
+ destroy_global_vars()
+ gc.collect()
+ torch.cuda.empty_cache()
+ global_context.append((context['fwd_time'], context['bwd_time'], context['comm_time']))
+ DATA_PARALLEL_SIZE = context['data_parallel_size']
+ if not is_oom:
+ is_oom = check_out_of_memory(argument, context, mbs_tries)
+ is_oom = broadcast_oom_in_ranks(0, is_oom)
+ mbs_tries += 1
+ if mbs_tries <= ORIGIN_MBS and is_oom:
+ raise AssertionError(
+ 'A risk of Out of Memory could occur, please '
+ 'reset to a smaller micro batch size.')
+ if mbs_tries > ORIGIN_MBS and is_oom:
+ break
+ if len(global_context) > 0:
+ OPTIMIZED_MBS_LIST, mbs_time = solve_optimpipeline(argument, DATA_PARALLEL_SIZE, global_context)
+ PP_SCHEDULE_LIST, pp_schedule_time, OPTIMAL_LAYERS = solve_pipelineschedule(argument, DATA_PARALLEL_SIZE, num_forwards_first_stage, forward_time_dict, backward_time_dict)
+ if torch.distributed.get_rank() == 0 and mbs_time > pp_schedule_time and num_forwards_first_stage > 2:
+ ENABLE_SCHEDULER = True
+ ENABLE_SCHEDULER = broadcast_enable_schedule_in_ranks(0, ENABLE_SCHEDULER)
+ optimized_policy = (ENABLE_SCHEDULER, OPTIMIZED_MBS_LIST, PP_SCHEDULE_LIST, OPTIMAL_LAYERS)
+ save_profiling_data(optimized_policy, config_file="autopipeline_perf_config.json")
+ else:
+ ENABLE_SCHEDULER = exist_policy[0]
+ OPTIMIZED_MBS_LIST = exist_policy[1]
+ PP_SCHEDULE_LIST = exist_policy[2]
+ OPTIMAL_LAYERS = exist_policy[3]
+ pretrain(*args, **kwargs)
+ return wrapper
+
+
+def setup_model_and_optimizer_decorator(setup_model_and_optimizer):
+ @wraps(setup_model_and_optimizer)
+ def wrapper(*args, **kwargs):
+ global POLICY
+ global OPTIMIZED_MBS_LIST
+ global PP_SCHEDULE_LIST
+ global OPTIMAL_LAYERS
+ global ENABLE_SCHEDULER
+ argument = get_args()
+ if argument.automated_pipeline and POLICY:
+ if torch.distributed.get_rank() == 0:
+ broadcast_policy_in_ranks(0, POLICY)
+ else:
+ broadcast_policy_in_ranks(0)
+ if argument.automated_pipeline_perf and ENABLE_SCHEDULER:
+ broadcast_scheduler_in_ranks(0, PP_SCHEDULE_LIST)
+ broadcast_layer_in_ranks(0, OPTIMAL_LAYERS)
+ elif argument.automated_pipeline_perf and OPTIMIZED_MBS_LIST:
+ broadcast_mbs_in_ranks(0, OPTIMIZED_MBS_LIST)
+ model, optimizer, opt_param_scheduler = setup_model_and_optimizer(*args, **kwargs)
+ if argument.recompute_module_list:
+ apply_autopipeline(model)
+ if argument.profile_memory and torch.distributed.get_rank() in argument.profile_ranks:
+ profiling = Profiling(argument)
+ profiling.register_recursive_hook("", model)
+ return model, optimizer, opt_param_scheduler
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/__init__.py b/model/train/yoco_moe/mindspeed/core/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/attention.py b/model/train/yoco_moe/mindspeed/core/transformer/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..338ae8c0223b3f96da98a245b9af66ad211a9fb9
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/attention.py
@@ -0,0 +1,496 @@
+# coding=utf-8
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+from dataclasses import dataclass
+from functools import wraps
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+
+from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
+from megatron.core.transformer import TransformerConfig, ModuleSpec, build_module
+from megatron.core.transformer.attention import SelfAttentionSubmodules, CrossAttentionSubmodules, \
+ Attention
+from megatron.core.transformer.enums import AttnMaskType
+from megatron.core import mpu, parallel_state
+from megatron.core.utils import divide
+from megatron.training import get_args
+
+from mindspeed.core.context_parallel.ulysses_context_parallel import UlyssesContextAttention
+from mindspeed.core.parallel_state import get_context_parallel_group_for_hybrid_ulysses, \
+ get_tensor_model_parallel_world_size_for_nd1_dim1
+from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm, TPXOverlapCollectiveComm, \
+ TPYCollectiveComm, TPYOverlapCollectiveComm
+from mindspeed.core.tensor_parallel_y_union_cp import TensorParallelYUnionCP
+from mindspeed.core.tensor_parallel.tp_2d.parallel_linear_2d import ParallelLinear2D
+
+
+@dataclass
+class SelfAttentionSubmodules:
+ linear_qkv: Union[ModuleSpec, type] = None
+ core_attention: Union[ModuleSpec, type] = None
+ linear_proj: Union[ModuleSpec, type] = None
+ q_layernorm: Union[ModuleSpec, type] = None
+ k_layernorm: Union[ModuleSpec, type] = None
+ linear_qb: Union[ModuleSpec, type] = None
+ linear_kvb: Union[ModuleSpec, type] = None
+
+
+def attention_init(
+ self,
+ config: TransformerConfig,
+ submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules],
+ layer_number: int,
+ attn_mask_type: AttnMaskType,
+ attention_type: str,
+):
+ super(Attention, self).__init__(config=config)
+ self.config = config
+ self.layer_number = layer_number
+ self.attn_mask_type = attn_mask_type
+ self.attention_type = attention_type
+
+ # For normal attention without groups, num_query_groups == num_attention_heads,
+ # so these two will be the same
+ self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
+ self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups
+
+ args = get_args()
+ # patch for tp-2d
+ world_size = args.tp_x if args.tp_2d else parallel_state.get_tensor_model_parallel_world_size()
+ # Per attention head and per partition values.
+ self.hidden_size_per_attention_head = divide(
+ self.query_projection_size, self.config.num_attention_heads
+ )
+ self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
+ self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
+
+ self.core_attention = build_module(
+ submodules.core_attention,
+ config=self.config,
+ layer_number=self.layer_number,
+ attn_mask_type=self.attn_mask_type,
+ attention_type=self.attention_type,
+ )
+
+ self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
+
+ # Output.
+ self.linear_proj = build_module(
+ submodules.linear_proj,
+ self.query_projection_size,
+ self.config.hidden_size,
+ config=self.config,
+ init_method=self.config.output_layer_init_method,
+ bias=self.config.add_bias_linear,
+ input_is_parallel=True,
+ skip_bias_add=True,
+ is_expert=False,
+ tp_comm_buffer_name='proj',
+ )
+ cp = config.context_parallel_size
+ if args.tp_2d:
+ tp_y_cp_sz = cp * args.tp_y
+ else:
+ tp_y_cp_sz = cp
+ if tp_y_cp_sz > 1 and args.context_parallel_algo in ['ulysses_cp_algo', 'hybrid_cp_algo',
+ 'hybrid_adaptive_cp_algo']:
+ if args.tp_2d:
+ tp_y_cp = TensorParallelYUnionCP()
+ ulysses_group = tp_y_cp.group
+ else:
+ ulysses_group = mpu.get_context_parallel_group()
+ if args.context_parallel_algo in ['hybrid_cp_algo', 'hybrid_adaptive_cp_algo']:
+ ulysses_group = get_context_parallel_group_for_hybrid_ulysses()
+ self.core_attention = UlyssesContextAttention(self.core_attention, ulysses_group)
+
+
+def attention_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ fn(self, *args, **kwargs)
+ if self.config.num_query_groups is None:
+ self.config.num_query_groups = self.config.num_attention_heads
+ self.num_attention_heads_per_partition = self.config.num_attention_heads * self.num_query_groups_per_partition // self.config.num_query_groups
+
+ return wrapper
+
+
+def self_attention_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self,
+ config: TransformerConfig,
+ submodules: SelfAttentionSubmodules,
+ layer_number: int,
+ attn_mask_type=AttnMaskType.padding, ):
+ args = get_args()
+ if args.overlap_param_gather:
+ config.reset_attention_order = True
+ fn(self, config, submodules, layer_number, attn_mask_type)
+ return wrapper
+
+
+def self_attention_init_mla_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self,
+ config: TransformerConfig,
+ submodules: SelfAttentionSubmodules,
+ layer_number: int,
+ attn_mask_type=AttnMaskType.padding, ):
+
+ args = get_args()
+ fn(self, config, submodules, layer_number, attn_mask_type)
+ if args.multi_head_latent_attention:
+ self.use_flash_attn = args.use_flash_attn
+ self.shape_order = args.shape_order
+ self.qk_rope_head_dim = args.qk_rope_head_dim
+ self.qk_nope_head_dim = args.qk_nope_head_dim
+ self.q_lora_rank = args.q_lora_rank
+ self.kv_lora_rank = args.kv_lora_rank
+ self.v_head_dim = args.v_head_dim
+
+ query_projection_size = self.config.num_attention_heads * self.v_head_dim
+ self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
+
+ if self.q_lora_rank is None:
+ self.q_rank = self.config.num_attention_heads * self.q_head_dim
+ self.q_layernorm = None
+ else:
+ self.q_rank = self.q_lora_rank
+ if submodules.q_layernorm is not None:
+ self.q_layernorm = build_module(
+ submodules.q_layernorm,
+ hidden_size=self.q_lora_rank,
+ config=self.config,
+ eps=self.config.layernorm_epsilon,
+ )
+ else:
+ self.q_layernorm = None
+ self.linear_qb = build_module(
+ submodules.linear_qb,
+ self.q_lora_rank,
+ self.config.num_attention_heads * self.q_head_dim,
+ config=self.config,
+ init_method=self.config.init_method,
+ gather_output=False,
+ bias=self.config.add_bias_linear or self.config.add_qkv_bias,
+ skip_bias_add=False,
+ is_expert=False,
+ tp_comm_buffer_name='qb',
+ )
+
+ self.linear_qkv = build_module(
+ submodules.linear_qkv,
+ self.config.hidden_size,
+ self.q_rank + self.kv_lora_rank + self.qk_rope_head_dim,
+ config=self.config,
+ init_method=self.config.init_method,
+ gather_output=False,
+ bias=self.config.add_bias_linear or self.config.add_qkv_bias,
+ skip_bias_add=False,
+ is_expert=False,
+ tp_comm_buffer_name='qkv',
+ )
+
+ if submodules.k_layernorm is not None:
+ self.k_layernorm = build_module(
+ submodules.k_layernorm,
+ hidden_size=self.kv_lora_rank,
+ config=self.config,
+ eps=self.config.layernorm_epsilon,
+ )
+ else:
+ self.k_layernorm = None
+
+ self.linear_kvb = build_module(
+ submodules.linear_kvb,
+ self.kv_lora_rank,
+ self.config.num_attention_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
+ config=self.config,
+ init_method=self.config.init_method,
+ gather_output=False,
+ bias=self.config.add_bias_linear or self.config.add_qkv_bias,
+ skip_bias_add=False,
+ is_expert=False,
+ tp_comm_buffer_name='kvb',
+ )
+
+ self.linear_proj = build_module(
+ submodules.linear_proj,
+ query_projection_size,
+ self.config.hidden_size,
+ config=self.config,
+ init_method=self.config.output_layer_init_method,
+ bias=self.config.add_bias_linear,
+ input_is_parallel=True,
+ skip_bias_add=True,
+ is_expert=False,
+ tp_comm_buffer_name='proj',
+ )
+
+ return wrapper
+
+
+def self_attention_init_tp2d_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self,
+ config: TransformerConfig,
+ submodules: SelfAttentionSubmodules,
+ layer_number: int,
+ attn_mask_type=AttnMaskType.padding, ):
+
+ args = get_args()
+ fn(self, config, submodules, layer_number, attn_mask_type)
+ if args.tp_2d:
+ attn_heads_split_num = get_tensor_model_parallel_world_size_for_nd1_dim1()
+ self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, attn_heads_split_num)
+ self.num_query_groups_per_partition = divide(self.config.num_query_groups, attn_heads_split_num)
+ self.linear_qkv = ParallelLinear2D(
+ self.config.hidden_size,
+ self.query_projection_size + 2 * self.kv_projection_size,
+ config=self.config,
+ init_method=self.config.init_method,
+ add_bias=self.config.add_bias_linear,
+ skip_bias_add=True,
+ ag_comm_intf=TPXCollectiveComm,
+ ag_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
+ rs_comm_intf=TPYCollectiveComm,
+ rs_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
+ enable_overlap_ag_with_matmul=False,
+ enable_overlap_matmul_with_rs=False,
+ partition_dim=0,
+ enable_backward_overlap_ag_with_matmul=False,
+ )
+ self.linear_proj = ParallelLinear2D(
+ self.query_projection_size,
+ self.config.hidden_size,
+ config=self.config,
+ init_method=self.config.output_layer_init_method,
+ add_bias=self.config.add_bias_linear,
+ skip_bias_add=True,
+ ag_comm_intf=TPYCollectiveComm,
+ ag_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
+ rs_comm_intf=TPXCollectiveComm,
+ rs_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
+ enable_overlap_ag_with_matmul=False,
+ enable_overlap_matmul_with_rs=False,
+ partition_dim=1,
+ enable_backward_overlap_ag_with_matmul=args.enable_backward_overlap_ag_with_matmul
+ )
+
+ return wrapper
+
+
+def attention_forward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(
+ self,
+ hidden_states,
+ attention_mask,
+ key_value_states=None,
+ inference_params=None,
+ rotary_pos_emb=None,
+ packed_seq_params=None,
+ ):
+ args = get_args()
+ if args.multi_head_latent_attention:
+ # hidden_states: [sq, b, h]
+
+ # For self attention we just duplicate the rotary_pos_emb if it isn't already
+ if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
+ rotary_pos_emb = (rotary_pos_emb,) * 2
+
+ q_len, bsz, _ = hidden_states.shape
+ mixed_x_layer, _ = self.linear_qkv(hidden_states)
+
+ # [sq, b, hp] --> [sq, b, ng, hn]
+ q_a, compressed_kv, k_pe = torch.split(
+ mixed_x_layer,
+ [
+ self.q_rank, self.kv_lora_rank, self.qk_rope_head_dim,
+ ],
+ dim=-1)
+
+ if self.q_layernorm is None:
+ q = q_a
+ else:
+ q, _ = self.linear_qb(self.q_layernorm(q_a))
+
+ q = q.view(q_len, bsz, self.config.num_attention_heads, -1)
+
+ q_nope, q_pe = torch.split(
+ q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ )
+
+ k_pe = k_pe.view(q_len, bsz, 1, self.qk_rope_head_dim)
+ kv, _ = self.linear_kvb(self.k_layernorm(compressed_kv))
+ kv = kv.view(q_len, bsz, self.config.num_attention_heads, self.qk_nope_head_dim +
+ self.v_head_dim)
+ k_nope, value = torch.split(
+ kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+ )
+
+ if rotary_pos_emb is not None:
+ q_pos_emb, k_pos_emb = rotary_pos_emb
+
+ b, h, s, d = q_pe.shape
+ q_pe = q_pe.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+ b, h, s, d = k_pe.shape
+ k_pe = k_pe.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
+
+ if packed_seq_params is not None:
+ cu_seqlens_q = packed_seq_params.cu_seqlens_q
+ cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
+ else:
+ cu_seqlens_q = cu_seqlens_kv = None
+
+ q_pe = apply_rotary_pos_emb(q_pe, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q)
+ k_pe = apply_rotary_pos_emb(k_pe, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv)
+
+ query = torch.cat([q_nope, q_pe], dim=-1)
+
+ k_pe = k_pe.repeat(1, 1, query.shape[2], 1)
+ key = torch.cat([k_nope, k_pe], dim=-1)
+
+ if self.use_flash_attn and self.q_head_dim != self.v_head_dim:
+ if self.shape_order == "BNSD":
+ value = F.pad(value, [0, self.q_head_dim - self.v_head_dim])
+ else:
+ query = F.pad(query, [0, 256 - self.q_head_dim])
+ key = F.pad(key, [0, 256 - self.q_head_dim])
+ value = F.pad(value, [0, 256 - self.v_head_dim])
+
+ # ==================================
+ # core attention computation
+ # ==================================
+ attn_mask_type = AttnMaskType.causal
+ if self.checkpoint_core_attention and self.training:
+ core_attn_out = self._checkpointed_attention_forward(
+ query,
+ key,
+ value,
+ attention_mask,
+ attn_mask_type=attn_mask_type,
+ packed_seq_params=packed_seq_params,
+ )
+ else:
+ core_attn_out = self.core_attention(
+ query,
+ key,
+ value,
+ attention_mask,
+ attn_mask_type=attn_mask_type,
+ packed_seq_params=packed_seq_params,
+ )
+
+ if packed_seq_params is not None:
+ # reshape to same output shape as unpacked case
+ # (t, np, hn) -> (t, b=1, h=np*hn)
+ # t is the pack size = sum (sq_i)
+ # note that batch is a dummy dimension in the packed case
+ core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
+
+ if self.use_flash_attn:
+ core_attn_out = core_attn_out.view(q_len, bsz, self.config.num_attention_heads, -1)
+ core_attn_out = core_attn_out[:, :, :, : self.v_head_dim]
+ core_attn_out = core_attn_out.reshape(q_len, bsz, self.config.num_attention_heads * self.v_head_dim)
+
+ # =================
+ # Output. [sq, b, h]
+ # =================
+
+ output, bias = self.linear_proj(core_attn_out)
+ else:
+ output, bias = fn(
+ self,
+ hidden_states,
+ attention_mask,
+ key_value_states,
+ inference_params,
+ rotary_pos_emb,
+ packed_seq_params
+ )
+
+ return output, bias
+
+ return wrapper
+
+
+def attention_forward(
+ self,
+ hidden_states,
+ attention_mask,
+ key_value_states=None,
+ inference_params=None,
+ rotary_pos_emb=None,
+ packed_seq_params=None,
+):
+
+ # For self attention we just duplicate the rotary_pos_emb if it isn't already
+ if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
+ rotary_pos_emb = (rotary_pos_emb,) * 2
+
+ # =====================
+ # Query, Key, and Value
+ # =====================
+ # Get the query, key and value tensors based on the type of attention -
+ # self or cross attn.
+ query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
+
+ # ===================================================
+ # Adjust key, value, and rotary_pos_emb for inference
+ # ===================================================
+ key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
+ inference_params, key, value, rotary_pos_emb
+ )
+
+ # ================================================
+ # relative positional embedding (rotary embedding)
+ # ================================================
+ if rotary_pos_emb is not None:
+ q_pos_emb, k_pos_emb = rotary_pos_emb
+
+ if packed_seq_params is not None:
+ cu_seqlens_q = packed_seq_params
+ cu_seqlens_kv = packed_seq_params
+ else:
+ cu_seqlens_q = cu_seqlens_kv = None
+ query = apply_rotary_pos_emb(
+ query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q,
+ )
+ key = apply_rotary_pos_emb(
+ key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv,
+ )
+
+
+ # ==================================
+ # core attention computation
+ # ==================================
+
+ if self.checkpoint_core_attention and self.training:
+ core_attn_out = self._checkpointed_attention_forward(
+ query,
+ key,
+ value,
+ attention_mask,
+ attn_mask_type=attn_mask_type,
+ packed_seq_params=packed_seq_params,
+ )
+ else:
+ core_attn_out = self.core_attention(
+ query,
+ key,
+ value,
+ attention_mask,
+ attn_mask_type=attn_mask_type,
+ packed_seq_params=packed_seq_params,
+ )
+
+ # =================
+ # Output. [sq, b, h]
+ # =================
+
+ output, bias = self.linear_proj(core_attn_out)
+
+ return output, bias
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/custom_layers/transformer_engine.py b/model/train/yoco_moe/mindspeed/core/transformer/custom_layers/transformer_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..16e385f48c02f1e467d093b451095642bc1d1be5
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/custom_layers/transformer_engine.py
@@ -0,0 +1,53 @@
+import torch
+import torch_npu
+import torch.nn as nn
+
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.legacy.model.rms_norm import RMSNorm
+from megatron.training import get_args
+from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm
+from mindspeed.core.tensor_parallel.comm_group_api import TPYCollectiveComm
+from mindspeed.core.tensor_parallel.tp_2d.layernorm_2d import LayerNorm2D
+from mindspeed.core.tensor_parallel.tp_2d.rms_norm_2d import RMSNorm2D
+
+
+class PTNorm:
+ """
+ Conditional Initialization of Transformer-Engine’s LayerNorm or RMSNorm Instance
+ """
+
+ def __new__(
+ cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5,
+ ):
+ args = get_args()
+ if config.normalization == "LayerNorm":
+ if args.tp_2d:
+ instance = LayerNorm2D(
+ hidden_size,
+ eps=eps,
+ last_dim_split_comm_intf=TPYCollectiveComm(),
+ )
+ else:
+ instance = nn.LayerNorm(
+ normalized_shape=hidden_size,
+ eps=eps,
+ )
+ elif config.normalization == "RMSNorm":
+ if args.tp_2d:
+ instance = RMSNorm2D(
+ hidden_size,
+ eps=eps,
+ last_dim_split_comm_intf=TPYCollectiveComm(),
+ )
+ instance.use_fused_rmsnorm = False
+ else:
+ instance = RMSNorm(
+ dim=hidden_size,
+ eps=eps,
+ sequence_parallel=config.sequence_parallel,
+ )
+ instance.use_fused_rmsnorm = True
+ else:
+ raise Exception('Only LayerNorm and RMSNorm are curently supported')
+
+ return instance
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/dot_product_attention.py b/model/train/yoco_moe/mindspeed/core/transformer/dot_product_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc1b9eb3828474a825e7caf665fee6e0162dafab
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/dot_product_attention.py
@@ -0,0 +1,309 @@
+# coding=utf-8
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+import math
+from functools import wraps
+
+import torch
+import torch_npu
+from torch import Tensor
+from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
+from megatron.core.transformer import TransformerConfig
+from megatron.core.transformer.dot_product_attention import DotProductAttention
+from megatron.core.transformer.enums import AttnMaskType
+from megatron.core.transformer.utils import attention_mask_func
+from megatron.core.utils import divide
+from megatron.training import get_args
+from megatron.core import mpu, parallel_state
+from mindspeed.core.models.common.embeddings.rotary_pos_embedding import yarn_get_mscale
+from mindspeed.core.context_parallel.ring_context_parallel import ringattn_context_parallel
+from mindspeed.core.context_parallel.ulysses_context_parallel import ulyssesattn_context_parallel
+from mindspeed.core.context_parallel.context_parallel_kv_cache import get_cache_policy
+from mindspeed.ops.fusion_attention_v2 import npu_fusion_attention
+from mindspeed.model.alibi_mask import AlibiForFusionAttnSingleton
+from mindspeed.core.parallel_state import (get_context_parallel_group_for_hybrid_ring,
+ get_context_parallel_for_hybrid_ring_world_size,
+ get_context_parallel_for_hybrid_ring_rank,
+ get_context_parallel_for_hybrid_ring_global_ranks,
+ get_ring_ranks_for_intra_window,
+ get_ring_ranks_for_inter_window_kv,
+ get_ring_ranks_for_inter_window_dkv,
+ get_ring_group_for_intra_window,
+ get_ring_group_for_intra_window_send_recv_overlap)
+from mindspeed.core.tensor_parallel_y_union_cp import TensorParallelYUnionCP
+from mindspeed.model.transformer import get_attention_mask
+from mindspeed.utils import get_actual_seq_len
+from mindspeed.core.context_parallel.adaptive_context_parallel import adaptive_attn_context_parallel
+from mindspeed.core.context_parallel.utils import get_scheduling_info
+
+try:
+ from einops import rearrange
+except ImportError:
+ rearrange = None
+
+
+def dot_product_attention_init(
+ self,
+ config: TransformerConfig,
+ layer_number: int,
+ attn_mask_type: AttnMaskType,
+ attention_type: str,
+ attention_dropout: float = None,
+):
+ cp_size = config.context_parallel_size
+ config.context_parallel_size = 1
+
+ super(DotProductAttention, self).__init__(config=config)
+ assert (
+ self.config.context_parallel_size == 1
+ ), "Context parallelism is only supported by TEDotProductAttention!"
+
+ assert (
+ self.config.window_size is None
+ ), "Sliding Window Attention is only supported by TEDotProductAttention!"
+
+ self.layer_number = max(1, layer_number)
+ self.attn_mask_type = attn_mask_type
+ self.attention_type = attention_type # unused for now
+
+ projection_size = self.config.kv_channels * self.config.num_attention_heads
+ args = get_args()
+ # Per attention head and per partition values.
+ world_size = args.tp_x if args.tp_2d else parallel_state.get_tensor_model_parallel_world_size()
+ self.hidden_size_per_partition = divide(projection_size, world_size)
+ self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
+ self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
+ self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
+
+ coeff = None
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+ if self.config.apply_query_key_layer_scaling:
+ coeff = self.layer_number
+ self.norm_factor *= coeff
+
+ self.scale_mask_softmax = FusedScaleMaskSoftmax(
+ input_in_fp16=self.config.fp16,
+ input_in_bf16=self.config.bf16,
+ attn_mask_type=self.attn_mask_type,
+ scaled_masked_softmax_fusion=self.config.masked_softmax_fusion,
+ mask_func=attention_mask_func,
+ softmax_in_fp32=self.config.attention_softmax_in_fp32,
+ scale=coeff,
+ )
+
+ # Dropout. Note that for a single iteration, this layer will generate
+ # different outputs on different number of parallel partitions but
+ # on average it should not be partition dependent.
+ self.attention_dropout = torch.nn.Dropout(
+ self.config.attention_dropout if attention_dropout is None else attention_dropout
+ )
+
+ config.context_parallel_size = cp_size
+
+ # add pse
+ self.pse = None
+ self.pse_type = args.alibi_fusion_attn_type
+
+ if args.multi_head_latent_attention:
+ self.scale_mask_softmax.scale = True
+ self.hidden_size_per_partition = config.num_attention_heads * args.v_head_dim
+ self.q_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
+ self.softmax_scale = self.q_head_dim ** (-0.5)
+
+ if args.rope_scaling_type is not None:
+ mscale_all_dim = args.rope_scaling_mscale_all_dim if args.rope_scaling_mscale_all_dim else 0
+ scaling_factor = args.rope_scaling_factor
+
+ if mscale_all_dim:
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+ self.softmax_scale = self.softmax_scale * mscale * mscale
+
+ self.norm_factor = 1.0 / self.softmax_scale
+
+ if self.pse_type is None:
+ self.pse_type = 1 # not use pse
+ elif self.pse_type == 0:
+ alibi = AlibiForFusionAttnSingleton.get_alibi_tensor_for_fusion_attn(args.seq_length,
+ config.num_attention_heads,
+ config.params_dtype,
+ args.alibi_diagonal_opposite,
+ 1024)
+ self.pse = alibi
+ elif self.pse_type == 2 or self.pse_type == 3:
+ self.pse = AlibiForFusionAttnSingleton.get_alibi_slopes_for_fusion_attn(config.num_attention_heads)
+
+
+def dot_product_attention_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ fn(self, *args, **kwargs)
+ if self.config.num_query_groups is None:
+ self.config.num_query_groups = self.config.num_attention_heads
+ self.num_attention_heads_per_partition = self.config.num_attention_heads * self.num_query_groups_per_partition // self.config.num_query_groups
+ return wrapper
+
+
+def dot_product_attention_forward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, query, key, value, attention_mask, attn_mask_type, packed_seq_params):
+ if attention_mask is None and self.attn_mask_type == AttnMaskType.causal:
+ attention_mask = get_attention_mask()
+ if get_args().use_flash_attn:
+ return dot_product_attention_forward(self, query, key, value, attention_mask, attn_mask_type, packed_seq_params)
+ return fn(self, query, key, value, attention_mask, attn_mask_type, packed_seq_params)
+
+ return wrapper
+
+
+def dot_product_attention_forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ attention_mask,
+ attn_mask_type,
+ packed_seq_params,
+):
+ args = get_args()
+ seq_len, bsz, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3]
+
+ sparse_mode = args.sparse_mode
+ if attn_mask_type == AttnMaskType.no_mask:
+ sparse_mode = 0 # default mask
+
+ scale = 1.0 / math.sqrt(
+ self.hidden_size_per_attention_head) if self.scale_mask_softmax.scale is None else self.softmax_scale
+
+ cp_expanded_by_2d_tp = args.tp_2d and args.tp_y > 1
+ if cp_expanded_by_2d_tp:
+ tp_y_cp_sz = TensorParallelYUnionCP().get_parallel_group_world_size()
+ else:
+ tp_y_cp_sz = self.config.context_parallel_size
+
+ if (self.config.context_parallel_size > 1 and args.context_parallel_algo == "ulysses_cp_algo"
+ and args.context_parallel_kv_cache_policy):
+ self.ulysses_comm_para['cache_policy'] = get_cache_policy(
+ self.layer_number, args.context_parallel_kv_cache_policy, args.context_parallel_cache_interval
+ )
+ self.ulysses_comm_para['use_ulysses_allgather_kv'] = args.use_ulysses_allgather_kv
+
+ attn_para = dict()
+ attn_para['packed_seq_params'] = packed_seq_params
+ attn_para['attention_mask'] = attention_mask
+ attn_para['scale'] = scale
+ attn_para['pre_tokens'] = args.pre_tockens
+ attn_para['next_tokens'] = args.next_tockens
+ attn_para['keep_prob'] = 1 - self.attention_dropout.p
+ attn_para['sparse_mode'] = sparse_mode
+ output = ulyssesattn_context_parallel(query, key, value, attn_para, self.ulysses_comm_para)
+
+ return output
+
+ if tp_y_cp_sz > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo',
+ 'adaptive_cp_algo', 'hybrid_adaptive_cp_algo']:
+ in_hybrid_mode = False
+ if get_context_parallel_group_for_hybrid_ring(check_initialized=False) is not None:
+ in_hybrid_mode = True
+
+ if not in_hybrid_mode:
+ if cp_expanded_by_2d_tp:
+ tp_y_cp = TensorParallelYUnionCP()
+ cp_group = tp_y_cp.group
+ cp_size = tp_y_cp.get_parallel_group_world_size()
+ rank = tp_y_cp.get_parallel_rank()
+ cp_global_ranks = tp_y_cp.global_ranks
+ else:
+ cp_group = mpu.get_context_parallel_group()
+ cp_size = mpu.get_context_parallel_world_size()
+ rank = mpu.get_context_parallel_rank()
+ cp_global_ranks = mpu.get_context_parallel_global_ranks()
+ else:
+ cp_group = get_context_parallel_group_for_hybrid_ring()
+ cp_size = get_context_parallel_for_hybrid_ring_world_size()
+ rank = get_context_parallel_for_hybrid_ring_rank()
+ cp_global_ranks = get_context_parallel_for_hybrid_ring_global_ranks()
+
+ cp_para = dict()
+ cp_para['megatron_cp_in_bnsd'] = self.config.megatron_cp_in_bnsd
+ cp_para['causal'] = args.attention_mask_type == 'causal'
+ cp_para['cp_group'] = cp_group
+ cp_para['cp_size'] = cp_size
+ cp_para['rank'] = rank
+
+ query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]]
+ if args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo']:
+ cp_para['cp_global_ranks'] = cp_global_ranks
+ if args.use_cp_send_recv_overlap:
+ if cp_expanded_by_2d_tp:
+ cp_para['cp_group_for_send_recv_overlap'] = tp_y_cp.overlap_group
+ else:
+ cp_para['cp_group_for_send_recv_overlap'] = mpu.get_context_parallel_group_for_send_recv_overlap()
+ else:
+ cp_para['cp_group_for_send_recv_overlap'] = None
+ cp_para['pse'] = self.pse
+ cp_para['pse_type'] = self.pse_type
+
+ if self.config.context_parallel_size > 1 and not args.tp_2d:
+ cp_para['cp_inner_ranks'] = get_ring_ranks_for_intra_window()
+ cp_para['cp_outer_ranks'] = get_ring_ranks_for_inter_window_kv()
+ cp_para['cp_dkv_outer_ranks'] = get_ring_ranks_for_inter_window_dkv()
+ cp_para['cp_group_for_intra_window'] = get_ring_group_for_intra_window()
+ cp_para['cp_group_for_intra_window_send_recv_overlap'] = get_ring_group_for_intra_window_send_recv_overlap()
+
+ cp_para['cache_policy'] = get_cache_policy(
+ self.layer_number, args.context_parallel_kv_cache_policy, args.context_parallel_cache_interval
+ )
+
+ output = ringattn_context_parallel(query, key, value, n_head, cp_para, scale, attention_mask, self.attention_dropout.p,
+ packed_seq_params)
+ else:
+ cp_para['scheduling_info'] = get_scheduling_info()
+ output = adaptive_attn_context_parallel(query, key, value, n_head, cp_para, scale, attention_mask, self.attention_dropout.p)
+
+ else:
+ if packed_seq_params is not None: # TND
+ cp_size = mpu.get_context_parallel_world_size()
+ actual_seq_qlen = packed_seq_params.cu_seqlens_q.tolist()
+ actual_seq_kvlen = packed_seq_params.cu_seqlens_kv.tolist()
+ query, key, value = [rearrange(x, 's b h d -> (b s) h d') for x in [query, key, value]]
+ shape_order = 'TND'
+ else: # SBH
+ actual_seq_qlen = None
+ actual_seq_kvlen = None
+ query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]]
+ shape_order = 'SBH'
+ if args.use_fusion_attn_v2:
+ output = npu_fusion_attention(
+ query, key, value, n_head, shape_order,
+ pse=self.pse,
+ padding_mask=None,
+ atten_mask=attention_mask,
+ scale=scale,
+ pse_type=self.pse_type,
+ pre_tokens=args.pre_tockens,
+ next_tokens=args.next_tockens,
+ keep_prob=1 - self.attention_dropout.p,
+ inner_precise=0,
+ sparse_mode=sparse_mode,
+ actual_seq_qlen=actual_seq_qlen,
+ actual_seq_kvlen=actual_seq_kvlen
+ )[0]
+ else:
+ output = torch_npu.npu_fusion_attention(
+ query, key, value, n_head, shape_order,
+ pse=None,
+ padding_mask=None,
+ atten_mask=attention_mask,
+ scale=scale,
+ pre_tockens=args.pre_tockens,
+ next_tockens=args.next_tockens,
+ keep_prob=1 - self.attention_dropout.p,
+ inner_precise=0,
+ sparse_mode=sparse_mode,
+ actual_seq_qlen=actual_seq_qlen,
+ actual_seq_kvlen=actual_seq_kvlen
+ )[0]
+ if packed_seq_params is not None:
+ output = rearrange(output, '(b s) h d -> s b (h d)', s=seq_len, b=bsz)
+ return output
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/mlp.py b/model/train/yoco_moe/mindspeed/core/transformer/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dfc883ccd8d57f96df4c8e06bf29d766a6b8909
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/mlp.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from functools import wraps
+from megatron.core.transformer.spec_utils import build_module
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.transformer.mlp import MLPSubmodules, MLP
+from megatron.training import get_args
+from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm, TPXOverlapCollectiveComm, \
+ TPYCollectiveComm, TPYOverlapCollectiveComm
+from mindspeed.core.tensor_parallel.tp_2d.parallel_linear_2d import ParallelLinear2D
+
+
+def mlp_init(
+ self,
+ config: TransformerConfig,
+ submodules: MLPSubmodules,
+ is_expert: bool = False,
+ input_size: int = None,
+ shared_expert=False,
+):
+ super(MLP, self).__init__(config=config)
+
+ self.config: TransformerConfig = config
+
+ self.input_size = input_size if input_size is not None else self.config.hidden_size
+
+ ffn_hidden_size = self.config.ffn_hidden_size
+ if self.config.gated_linear_unit:
+ ffn_hidden_size *= 2
+ if shared_expert:
+ self.linear_fc1 = build_module(
+ submodules.linear_fc1,
+ self.input_size,
+ ffn_hidden_size,
+ config=self.config,
+ init_method=self.config.init_method,
+ gather_output=False,
+ bias=self.config.add_bias_linear,
+ skip_bias_add=True,
+ is_expert=is_expert,
+ tp_comm_buffer_name='fc1',
+ shared_expert=shared_expert
+ )
+ else:
+ self.linear_fc1 = build_module(
+ submodules.linear_fc1,
+ self.input_size,
+ ffn_hidden_size,
+ config=self.config,
+ init_method=self.config.init_method,
+ gather_output=False,
+ bias=self.config.add_bias_linear,
+ skip_bias_add=True,
+ is_expert=is_expert,
+ tp_comm_buffer_name='fc1'
+ )
+
+ self.activation_func = self.config.activation_func
+
+ if shared_expert:
+ self.linear_fc2 = build_module(
+ submodules.linear_fc2,
+ self.config.ffn_hidden_size,
+ self.config.hidden_size,
+ config=self.config,
+ init_method=self.config.output_layer_init_method,
+ bias=self.config.add_bias_linear,
+ input_is_parallel=True,
+ skip_bias_add=True,
+ is_expert=is_expert,
+ tp_comm_buffer_name='fc2',
+ shared_expert=shared_expert
+ )
+ else:
+ self.linear_fc2 = build_module(
+ submodules.linear_fc2,
+ self.config.ffn_hidden_size,
+ self.config.hidden_size,
+ config=self.config,
+ init_method=self.config.output_layer_init_method,
+ bias=self.config.add_bias_linear,
+ input_is_parallel=True,
+ skip_bias_add=True,
+ is_expert=is_expert,
+ tp_comm_buffer_name='fc2'
+ )
+
+ self.shared_expert = shared_expert
+
+
+def mlp_init_2d_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *arg, **kwargs):
+ fn(self, *arg, **kwargs)
+ args = get_args()
+ is_expert = False
+ if get_args().num_experts is not None:
+ is_expert = True
+ if args.tp_2d:
+ ffn_hidden_size = self.config.ffn_hidden_size
+ if self.config.gated_linear_unit:
+ ffn_hidden_size *= 2
+ self.linear_fc1 = ParallelLinear2D(
+ self.config.hidden_size,
+ ffn_hidden_size,
+ config=self.config,
+ init_method=self.config.init_method,
+ add_bias=self.config.add_bias_linear,
+ skip_bias_add=True,
+ is_expert=is_expert,
+ ag_comm_intf=TPXCollectiveComm,
+ ag_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
+ rs_comm_intf=TPYCollectiveComm,
+ rs_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
+ enable_overlap_ag_with_matmul=False,
+ enable_overlap_matmul_with_rs=args.enable_overlap_matmul_with_rs,
+ partition_dim=0,
+ enable_backward_overlap_ag_with_matmul=args.enable_backward_overlap_ag_with_matmul)
+ self.linear_fc2 = ParallelLinear2D(
+ self.config.ffn_hidden_size,
+ self.config.hidden_size,
+ config=self.config,
+ init_method=self.config.output_layer_init_method,
+ add_bias=self.config.add_bias_linear,
+ skip_bias_add=True,
+ is_expert=is_expert,
+ ag_comm_intf=TPYCollectiveComm,
+ ag_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
+ rs_comm_intf=TPXCollectiveComm,
+ rs_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
+ enable_overlap_ag_with_matmul=args.enable_overlap_ag_with_matmul,
+ enable_overlap_matmul_with_rs=False,
+ partition_dim=1,
+ enable_backward_overlap_ag_with_matmul=args.enable_backward_overlap_ag_with_matmul)
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/module.py b/model/train/yoco_moe/mindspeed/core/transformer/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..5986c8c8d83a5f9dfac235ec8036a133dc168917
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/module.py
@@ -0,0 +1,17 @@
+from functools import wraps
+import torch
+from megatron.core.transformer.spec_utils import build_module
+
+
+def megatron_module_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, config):
+ fn(self, config)
+ if hasattr(config, 'reset_attention_order') and config.reset_attention_order:
+ # Create linear_qkv module before self_attention.
+ self.linear_qkv = build_module(torch.nn.GELU)
+ # Free memory to avoid memory fragmentation. It will be assigned a real linear function later.
+ self.linear_qkv = None
+ config.reset_attention_order = False
+
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/__init__.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/comm_utils.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/comm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd2a4dae52dca5afb3dafbb508a59915ca5f696a
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/comm_utils.py
@@ -0,0 +1,257 @@
+# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import einops
+import torch
+import torch.distributed
+import torch.distributed as dist
+import torch_npu
+from megatron.core import parallel_state
+from megatron.core.parallel_state import get_global_memory_buffer, get_tensor_model_parallel_rank
+
+from typing import Optional, List
+
+COMM_STREAM = None
+
+
+def async_all_gather(input_, group, event=None, is_use_get_global_memory_buffer=False, last_dim=False):
+ world_size = dist.get_world_size(group)
+ if world_size == 1:
+ return input_, input_, None
+ if last_dim:
+ rank = get_tensor_model_parallel_rank()
+ ag_out = [torch.empty_like(input_) for _ in range(world_size)]
+ ag_out[rank] = input_
+ else:
+ dim_size = list(input_.size())
+ new_dim_size = dim_size[0] * world_size
+ dim_size[0] = new_dim_size
+
+ if is_use_get_global_memory_buffer:
+ ag_out = get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu")
+ else:
+ ag_out = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
+ if event:
+ # multi stream wait event
+ global COMM_STREAM
+ if COMM_STREAM is None:
+ COMM_STREAM = torch_npu.npu.Stream(device=torch.npu.current_device())
+ with torch_npu.npu.stream(COMM_STREAM):
+ event.wait()
+ if last_dim:
+ handle = torch.distributed.all_gather(ag_out, input_.contiguous(), group=group, async_op=True)
+ else:
+ handle = torch.distributed._all_gather_base(
+ ag_out, input_.contiguous(), group=group, async_op=True
+ )
+ else:
+ if last_dim:
+ handle = torch.distributed.all_gather(ag_out, input_.contiguous(), group=group, async_op=True)
+ else:
+ handle = torch.distributed._all_gather_base(
+ ag_out, input_.contiguous(), group=group, async_op=True
+ )
+ return input_, ag_out, handle
+
+
+def async_reduce_scatter(input_, group, event=None, stream=None, is_use_get_global_memory_buffer=False):
+ world_size = dist.get_world_size(group)
+ if world_size == 1:
+ return input_, input_, None
+ dim_size = list(input_.size())
+ dim_size[0] = dim_size[0] // world_size
+ if is_use_get_global_memory_buffer:
+ rs_out = get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu")
+ else:
+ rs_out = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
+ if event or stream:
+ # multi stream wait event
+ global COMM_STREAM
+ if COMM_STREAM is None:
+ COMM_STREAM = torch_npu.npu.Stream(device=torch.npu.current_device())
+ with torch_npu.npu.stream(COMM_STREAM):
+ if event:
+ event.wait()
+ if stream:
+ torch.cuda.current_stream().wait_stream(stream)
+ handle = torch.distributed._reduce_scatter_base(
+ rs_out, input_.contiguous(), group=group, async_op=True
+ )
+ else:
+ handle = torch.distributed._reduce_scatter_base(
+ rs_out, input_.contiguous(), group=group, async_op=True
+ )
+ return input_, rs_out, handle
+
+
+def async_all_to_all(input_, output_split_sizes, input_split_sizes, group, event=None):
+ world_size = dist.get_world_size(group)
+ if world_size == 1:
+ return input_, input_, None
+ if output_split_sizes is None:
+ # Equal split (all2all)
+ a2a_out = torch.empty_like(input_)
+ else:
+ # Unequal split (all2all-v)
+ a2a_out = input_.new_empty(
+ size=[sum(output_split_sizes)] + list(input_.size()[1:]),
+ dtype=input_.dtype,
+ device=torch.cuda.current_device(),
+ )
+
+ if event:
+ # multi stream wait event
+ global COMM_STREAM
+ if COMM_STREAM is None:
+ COMM_STREAM = torch_npu.npu.Stream(device=torch.npu.current_device())
+ with torch_npu.npu.stream(COMM_STREAM):
+ event.wait()
+ handle = dist.all_to_all_single(
+ a2a_out,
+ input_.contiguous(),
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=True
+ )
+ else:
+ handle = dist.all_to_all_single(
+ a2a_out,
+ input_.contiguous(),
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=True
+ )
+ return input_, a2a_out, handle
+
+
+def transfer_tensor_last_dim_to_first(input_x):
+ num_dims = input_x.dim()
+ return einops.rearrange(input_x, "... lst -> lst ...").contiguous(), num_dims
+
+
+def transfer_tensor_first_dim_to_last(input_x, num_dims):
+ return einops.rearrange(input_x, "first ... -> ... first").contiguous()
+
+
+def _gather_no_grad(input_: torch.Tensor, output_split_sizes=None, group=None):
+ if group is None:
+ group = parallel_state.get_tensor_model_parallel_group()
+ world_size = torch.distributed.get_world_size(group)
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return input_
+
+ dim_size = list(input_.size())
+ if output_split_sizes is None:
+ dim_size[0] = dim_size[0] * world_size
+ output = torch.empty(dim_size, dtype=input_.dtype, device=input_.device)
+ torch.distributed._all_gather_base(output, input_.contiguous(), group=group)
+ else:
+ dim_size[0] = sum(output_split_sizes)
+ output = torch.empty(dim_size, dtype=input_.dtype, device=input_.device)
+ output_tensor_list = list(torch.split(output, output_split_sizes, dim=0))
+ torch.distributed.all_gather(output_tensor_list, input_, group=group)
+
+ return output
+
+
+def _reduce_scatter_no_grad(input_: torch.Tensor, input_split_sizes=None, group=None):
+ if group is None:
+ group = parallel_state.get_tensor_model_parallel_group()
+ world_size = torch.distributed.get_world_size(group)
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return input_
+
+ if input_split_sizes is None:
+ dim_size = list(input_.size())
+ if dim_size[0] % world_size != 0:
+ raise ValueError("First dimension of the tensor should be divisible by tensor parallel size")
+ dim_size[0] = dim_size[0] // world_size
+
+ output = torch.empty(dim_size, dtype=input_.dtype, device=input_.device)
+ torch.distributed._reduce_scatter_base(output, input_.contiguous(), group=group)
+ else:
+ rank = torch.distributed.get_rank(group)
+ input_tensor_list = list(torch.split(input_, input_split_sizes, dim=0))
+ output = torch.empty_like(input_tensor_list[rank])
+ torch.distributed.reduce_scatter(output, input_tensor_list, group=group)
+ return output
+
+
+class _Gather(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input_, output_split_sizes=None, group=None):
+ """Forward function."""
+ ctx.output_split_sizes = output_split_sizes
+ ctx.group = group
+ return _gather_no_grad(input_, output_split_sizes, group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ """Backward function."""
+ output_split_sizes = ctx.output_split_sizes
+ group = ctx.group
+ return _reduce_scatter_no_grad(grad_output, output_split_sizes, group), None, None
+
+
+class _ReduceScatter(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input_, input_split_sizes=None, group=None):
+ """Forward function."""
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ return _reduce_scatter_no_grad(input_, input_split_sizes, group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ """Backward function."""
+ input_split_sizes = ctx.input_split_sizes
+ group = ctx.group
+ return _gather_no_grad(grad_output, input_split_sizes, group), None, None
+
+
+# Backport from latest Megatron-LM
+def all_gather(
+ input_: torch.Tensor,
+ output_split_sizes: Optional[List[int]] = None,
+ group: torch.distributed.ProcessGroup = None,
+) -> torch.Tensor:
+ """Gather tensors and concatenate along the first dimension.
+
+ Args:
+ input_tensor (torch.Tensor):
+ A tensor to be gathered.
+ output_split_sizes (List[int], optional):
+ A list specifying the sizes of the output splits along the first dimension.
+ If None, equal splitting is assumed. Default: None.
+ group (torch.distributed.ProcessGroup):
+ The process group to perform communication.
+ Default: None, will use tensor model parallel group.
+
+ Returns:
+ torch.Tensor: Gathered tensor.
+ """
+ return _Gather.apply(input_, output_split_sizes, group)
+
+
+def reduce_scatter(
+ input_: torch.Tensor,
+ input_split_sizes: Optional[List[int]] = None,
+ group: torch.distributed.ProcessGroup = None,
+) -> torch.Tensor:
+ """Reduce-scatter the input tensor across model parallel group.
+
+ Args:
+ input_ (torch.Tensor): The input tensor to be reduce-scattered.
+ input_split_sizes (List[int], optional): A list specifying the sizes of
+ the input splits along the first dimension for each rank. If None,
+ equal splitting is assumed. Default: None.
+ group (torch.distributed.ProcessGroup):
+ The process group to perform communication.
+ Default: None, will use tensor model parallel group.
+ """
+ return _ReduceScatter.apply(input_, input_split_sizes, group)
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/experts.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/experts.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdccc93ae813cdba7693ed5dcae363da36740c70
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/experts.py
@@ -0,0 +1,197 @@
+# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+from functools import wraps
+import torch
+import torch.nn.functional as F
+from megatron.core import parallel_state, tensor_parallel
+from megatron.training import get_args
+from mindspeed.model.transformer import should_recompute_activation
+from mindspeed.core.fusions.fused_bias_swiglu import fused_swiglu
+from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput
+from mindspeed.core.transformer.moe.grouped_gemm_util import fused_alltoall_gather_bmm, fused_bmm_reducescatter_alltoall
+from mindspeed.core.transformer.moe.grouped_mlp_with_comp_and_comm_overlap_all2all import grouped_mlp_with_comp_and_comm_overlap_all2all
+from mindspeed.core.transformer.moe.grouped_mlp_with_comp_and_comm_overlap_allgather import grouped_mlp_with_comp_and_comm_overlap_allgather
+from mindspeed.core.transformer.moe import grouped_gemm_util as gg
+
+
+def get_zeros_with_tp(input_):
+ world_size = parallel_state.get_tensor_model_parallel_world_size()
+ zeros_shape = input_.shape[:-1] + (input_.shape[-1] * world_size,)
+ return torch.zeros(zeros_shape, dtype=input_.dtype, layout=input_.layout, device=input_.device)
+
+
+def sequential_mlp_forward(self, permuted_local_hidden_states, tokens_per_expert):
+ output_local = get_zeros_with_tp(permuted_local_hidden_states)
+ output_bias_local = None
+ if self.add_bias:
+ output_bias_local = get_zeros_with_tp(permuted_local_hidden_states)
+
+ cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
+ # Insert zero at the begining for offset index's convenience
+ zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
+ cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
+
+ if parallel_state.get_tensor_model_parallel_world_size() > 1:
+ if not hasattr(self, 'comm_stream'):
+ self.comm_stream = torch.cuda.Stream()
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
+
+ for expert_num, expert in enumerate(self.local_experts):
+ start = cumsum_num_tokens[expert_num]
+ end = cumsum_num_tokens[expert_num + 1]
+ hidden = permuted_local_hidden_states[start:end]
+
+ if parallel_state.get_tensor_model_parallel_world_size() > 1:
+ with torch.cuda.stream(self.comm_stream):
+ hidden = tensor_parallel.all_gather_last_dim_from_tensor_parallel_region(hidden)
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
+
+ output, output_bias = expert(hidden)
+
+ output_local[start:end] = output
+ if self.add_bias:
+ output_bias = output_bias.expand_as(output)
+ output_bias_local[start:end, :] = output_bias
+
+ return output_local, output_bias_local
+
+
+def group_mlp_forward(self, permuted_local_hidden_states, tokens_per_expert, ctx=None):
+ if permuted_local_hidden_states.nelement() != 0:
+ w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
+ w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
+ else:
+ w1 = self.weight1.view(self.config.hidden_size, -1)
+ w2 = self.weight2.view(-1, self.config.hidden_size)
+ group_list = torch.cumsum(tokens_per_expert, dim=0)
+ if get_args().moe_alltoall_overlap_comm:
+ return grouped_mlp_with_comp_and_comm_overlap_all2all(permuted_local_hidden_states, w1, w2,
+ (self.weight1, self.weight2, self.activation_func,
+ group_list, ctx.layer_number),
+ ctx=ctx)
+ else:
+ return grouped_mlp_with_comp_and_comm_overlap_allgather(permuted_local_hidden_states, w1, w2,
+ (self.weight1, self.weight2, self.activation_func,
+ group_list, self.layer_number))
+
+
+def groupedmlp_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ args_ = get_args()
+ tp_size = parallel_state.get_tensor_model_parallel_world_size()
+ # set tp size to 1 before GMM init to aviod weight sharding
+ if args_.moe_tp_extend_ep:
+ parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = 1
+ fn(self, *args, **kwargs)
+ if args_.moe_tp_extend_ep:
+ parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = tp_size
+ if self.config.gated_linear_unit:
+ assert (self.config.activation_func == F.silu
+ ), 'Activation function must be silu when using fused_swiglu.'
+ self.activation_func = fused_swiglu
+ self.layer_number = None
+ self.set_recompute_activation_func = False
+
+ return wrapper
+
+
+def groupedmlp_forward(self, permuted_local_hidden_states, tokens_per_expert):
+ args = get_args()
+ is_recompute_activation = should_recompute_activation(
+ self.layer_number) and not args.moe_alltoall_overlap_comm and not args.moe_allgather_overlap_comm
+
+ gemm_fusion = args.gemm_gradient_accumulation_fusion
+ tp_group = parallel_state.get_tensor_model_parallel_group()
+ ep_group = parallel_state.get_expert_model_parallel_group()
+
+ if not is_recompute_activation:
+ if permuted_local_hidden_states.nelement() != 0:
+ # Reshape the weights for the grouped GEMMs.
+ w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
+ w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
+
+ if args.moe_bmm_mc2:
+ # input to alltoall_gather_bmm op input: [E*C, H/TP] -> [E, C, H/TP]
+ permuted_local_hidden_states = permuted_local_hidden_states.view(self.config.num_moe_experts,
+ permuted_local_hidden_states.shape[
+ 0] // self.config.num_moe_experts,
+ -1)
+ bmm_param = {'group_ep': ep_group, 'group_tp': tp_group, 'shard_type': 0,
+ 'need_recompute': False}
+ fc1_output = fused_alltoall_gather_bmm(permuted_local_hidden_states, w1, None, bmm_param)
+ intermediate_parallel = self.activation_func(fc1_output)
+ fc2_output = fused_bmm_reducescatter_alltoall(intermediate_parallel, w2, None, bmm_param)
+ # revert the output shape: [E, C, H/TP] -> [E*C, H/TP]
+ fc2_output = fc2_output.view(-1, fc2_output.shape[2])
+ else:
+ fc1_output = gg.ops.gmm(permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False,
+ gemm_fusion=gemm_fusion, original_weight=self.weight1)
+ intermediate_parallel = self.activation_func(fc1_output)
+ fc2_output = gg.ops.gmm(intermediate_parallel, w2, tokens_per_expert, trans_b=False,
+ gemm_fusion=gemm_fusion, original_weight=self.weight2)
+ else:
+ # No token is allocated for local experts.
+ assert torch.count_nonzero(tokens_per_expert) == 0
+
+ # Make sure parameters still have gradients when no tokens are routed to this set of experts.
+ w1 = self.weight1.view(self.config.hidden_size, -1)
+ w2 = self.weight2.view(-1, self.config.hidden_size)
+ h = torch.matmul(permuted_local_hidden_states, w1)
+ h = self.activation_func(h)
+ h = torch.matmul(h, w2)
+ fc2_output = h
+ else:
+ if permuted_local_hidden_states.nelement() != 0:
+ w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
+ w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
+
+ bmm_param = {'group_ep': ep_group, 'group_tp': tp_group, 'shard_type': 0,
+ 'need_recompute': False}
+
+ if args.moe_bmm_mc2:
+ # input to alltoall_gather_bmm op input: [E*C, H/TP] -> [E, C, H/TP]
+ permuted_local_hidden_states = permuted_local_hidden_states.view(self.config.num_moe_experts,
+ permuted_local_hidden_states.shape[
+ 0] // self.config.num_moe_experts,
+ -1)
+
+ fc1_output = fused_alltoall_gather_bmm(permuted_local_hidden_states, w1, None, bmm_param)
+ else:
+ fc1_output = gg.ops.gmm(
+ permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False, gemm_fusion=gemm_fusion,
+ original_weight=self.weight1
+ )
+
+ self.activation_checkpoint_manager = CheckpointWithoutOutput()
+ intermediate_parallel = self.activation_checkpoint_manager.checkpoint(self.activation_func,
+ False,
+ fc1_output)
+ if args.moe_bmm_mc2:
+ fc2_output = fused_bmm_reducescatter_alltoall(intermediate_parallel, w2, None, bmm_param)
+ # revert the output shape: [E, C, H/TP] -> [E*C, H/TP]
+ fc2_output = fc2_output.view(-1, fc2_output.shape[2])
+ else:
+ fc2_output = gg.ops.gmm(intermediate_parallel, w2, tokens_per_expert, trans_b=False,
+ gemm_fusion=gemm_fusion, original_weight=self.weight2)
+ else:
+ assert torch.count_nonzero(tokens_per_expert) == 0
+ w1 = self.weight1.view(self.config.hidden_size, -1)
+ w2 = self.weight2.view(-1, self.config.hidden_size)
+ h = torch.matmul(permuted_local_hidden_states, w1)
+ self.activation_checkpoint_manager = CheckpointWithoutOutput()
+ intermediate_parallel = self.activation_checkpoint_manager.checkpoint(self.activation_func,
+ False,
+ h)
+ h = torch.matmul(intermediate_parallel, w2)
+ fc2_output = h
+
+ # discard the output of the activation function,
+ # which will be restored by recomputation during backward.
+ self.activation_checkpoint_manager.discard_output()
+
+ # when backward to output of dense_4h_to_h,
+ # recompute and restore the output of activation function.
+ if fc2_output.requires_grad:
+ fc2_output.register_hook(self.activation_checkpoint_manager.recompute)
+ return fc2_output, None
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/grouped_gemm_util.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/grouped_gemm_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c0ed476cc621d3f3d7e9c13028af3fcb9e78450
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/grouped_gemm_util.py
@@ -0,0 +1,195 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+import torch
+from mindspeed.ops.npu_all_to_all_all_gather_bmm import npu_alltoall_allgather_bmm
+from mindspeed.ops.npu_bmm_reduce_scatter_all_to_all import npu_bmm_reducescatter_alltoall
+
+
+def grouped_gemm_is_available():
+ try:
+ from mindspeed.ops.gmm import npu_gmm
+ return True
+ except ImportError:
+ return False
+
+
+def assert_grouped_gemm_is_available():
+ if not grouped_gemm_is_available():
+ raise ImportError("from mindspeed.ops.gmm import npu_gmm failed.")
+
+
+class Ops:
+ @staticmethod
+ def gmm(a, b, batch_sizes, trans_b=False, gemm_fusion=False, original_weight=None):
+ from mindspeed.ops.gmm import npu_gmm
+
+ if trans_b:
+ b = b.t()
+ group_list = torch.cumsum(batch_sizes, dim=0)
+ return npu_gmm(a, b, bias=None, group_list=group_list, group_type=0, gemm_fusion=gemm_fusion, original_weight=original_weight)
+
+
+ops = Ops
+
+
+def get_device_capability():
+ return 9, 0
+
+
+def get_hcomm_info_world(comm_group):
+ rank = torch.distributed.get_rank()
+ hcomm_info = None
+
+ if torch.__version__ > "2.0.1":
+ hcomm_info = comm_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
+ else:
+ hcomm_info = comm_group.get_hccl_comm_name(rank)
+ return hcomm_info
+
+
+class FusedAllgatherBmmFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias, bmm_param):
+
+ group_ep = bmm_param['group_ep']
+ group_tp = bmm_param['group_tp']
+ need_recompute = bmm_param['need_recompute']
+ shard_type = bmm_param['shard_type']
+
+ ep_size = torch.distributed.get_world_size(group=group_ep)
+ tp_size = torch.distributed.get_world_size(group=group_tp)
+
+ tp_group_hcomm = get_hcomm_info_world(group_tp)
+ ep_group_hcomm = get_hcomm_info_world(group_ep)
+
+ out = npu_alltoall_allgather_bmm(
+ input_, weight, ep_group_hcomm, ep_size, tp_group_hcomm, tp_size, bias=bias, shard_type=shard_type,
+ act_type="None", need_allgather_out=True, need_activation_feature=False
+ )
+ bmm_out = out[0]
+ allgather_out = out[1]
+
+ if need_recompute:
+ ctx.save_for_backward(input_, weight)
+ else:
+ ctx.save_for_backward(allgather_out, weight)
+
+ ctx.bias = bias
+ ctx.need_recompute = need_recompute
+ ctx.group_ep = ep_group_hcomm
+ ctx.group_tp = tp_group_hcomm
+ ctx.ep_size = ep_size
+ ctx.tp_size = tp_size
+ ctx.shard_type = shard_type
+ return bmm_out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+
+ need_recompute = ctx.need_recompute
+ bias = ctx.bias
+ group_ep = ctx.group_ep
+ group_tp = ctx.group_tp
+ ep_size = ctx.ep_size
+ tp_size = ctx.tp_size
+ shard_type = ctx.shard_type
+
+ allgather_out = None
+ input_ = None
+
+ if need_recompute:
+ input_, weight = ctx.saved_tensors
+ else:
+ allgather_out, weight = ctx.saved_tensors
+
+ if need_recompute:
+ out = npu_alltoall_allgather_bmm(
+ input_, weight, group_ep, ep_size, group_tp, tp_size, bias=bias, shard_type=shard_type,
+ act_type="None", need_allgather_out=True, need_activation_feature=False
+ )
+ allgather_out = out[1]
+
+ # b,m,k @ b,k,n -> b,m,n
+ # dx: b,m,n @ (b,k,n).t() -> b,m,k
+ out = npu_bmm_reducescatter_alltoall(
+ grad_output, weight.transpose(-1, -2), group_ep, ep_size, group_tp, tp_size,
+ bias=None, shard_type=shard_type
+ )
+
+ # b,m,k @ b,k,n -> b,m,n
+ # dw: (b,m,k).t() @ (b,m,n).t() -> b,k,n
+ grad_bmm_w = torch.bmm(allgather_out.transpose(-1, -2), grad_output)
+ grad_bias = None
+ if bias is not None:
+ grad_bias = torch.sum(grad_output, dim=-1)
+
+ return out, grad_bmm_w, grad_bias, None
+
+
+class FusedBmmReduceScatterFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias, bmm_param):
+
+ group_ep = bmm_param['group_ep']
+ group_tp = bmm_param['group_tp']
+ shard_type = bmm_param['shard_type']
+
+ ep_size = torch.distributed.get_world_size(group=group_ep)
+ tp_size = torch.distributed.get_world_size(group=group_tp)
+
+ tp_group_hcomm = get_hcomm_info_world(group_tp)
+ ep_group_hcomm = get_hcomm_info_world(group_ep)
+
+ out = npu_bmm_reducescatter_alltoall(
+ input_, weight, ep_group_hcomm, ep_size, tp_group_hcomm, tp_size,
+ bias=bias, shard_type=shard_type
+ )
+
+ ctx.save_for_backward(input_, weight)
+
+ ctx.bias = bias
+ ctx.group_ep = ep_group_hcomm
+ ctx.group_tp = tp_group_hcomm
+ ctx.ep_size = ep_size
+ ctx.tp_size = tp_size
+ ctx.shard_type = shard_type
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+
+ bias = ctx.bias
+ group_ep = ctx.group_ep
+ group_tp = ctx.group_tp
+ ep_size = ctx.ep_size
+ tp_size = ctx.tp_size
+ shard_type = ctx.shard_type
+
+ input_, weight = ctx.saved_tensors
+
+ # b,m,k @ b,k,n -> b,m,n
+ # dx: b,m,n @ (b,k,n).t() -> b,m,k
+ out = npu_alltoall_allgather_bmm(
+ grad_output, weight.transpose(-1, -2), group_ep, ep_size, group_tp, tp_size,
+ bias=bias, shard_type=shard_type, act_type="None", need_allgather_out=True, need_activation_feature=False
+ )
+ bmm_out_grad = out[0]
+ allgather_out_grad = out[1]
+
+ # b,m,k @ b,k,n -> b,m,n
+ # dw: (b,m,k).t() @ (b,m,n).t() -> b,k,n
+ grad_bmm_w = torch.bmm(input_.transpose(-1, -2), allgather_out_grad)
+ grad_bias = None
+ if bias is not None:
+ grad_bias = torch.sum(grad_output, dim=-1)
+
+ return bmm_out_grad, grad_bmm_w, grad_bias, None
+
+
+def fused_alltoall_gather_bmm(input_, weight, bias, bmm_param):
+ return FusedAllgatherBmmFunction.apply(input_, weight, bias, bmm_param)
+
+
+def fused_bmm_reducescatter_alltoall(input_, weight, bias, bmm_param):
+ return FusedBmmReduceScatterFunction.apply(input_, weight, bias, bmm_param)
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/grouped_mlp_with_comp_and_comm_overlap_all2all.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/grouped_mlp_with_comp_and_comm_overlap_all2all.py
new file mode 100644
index 0000000000000000000000000000000000000000..18a487f6d4890be20f5ac98bbe5a4ff43523b9b8
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/grouped_mlp_with_comp_and_comm_overlap_all2all.py
@@ -0,0 +1,298 @@
+# Copyright (c) 2024, Huawei Technologies.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from einops import rearrange
+from megatron.training import get_args
+from megatron.core import parallel_state
+from megatron.core.parallel_state import get_expert_model_parallel_group, get_tensor_and_expert_parallel_group, get_tensor_model_parallel_group
+from megatron.core.transformer.moe.moe_utils import permute
+from mindspeed.model.transformer import should_recompute_activation
+from mindspeed.core.transformer.moe.moe_layer_overlap_all2all import gmm_op
+from mindspeed.core.transformer.moe.comm_utils import (async_all_to_all, async_reduce_scatter, async_all_gather,
+ transfer_tensor_last_dim_to_first)
+from mindspeed.core.transformer.moe.moe_utils import (only_recompute_activation, forward_func, backward_func,
+ get_gemm_backward_need_tensors,
+ set_all2all_experts_output,
+ permute_with_ep, get_all2all_experts_output,
+ get_permute_with_ep_local_input_tokens)
+from mindspeed.ops.npu_groupmatmul_add import npu_groupmatmul_add_fp32
+
+
+class GroupedMlpWithCompAndCommOverlapAll2All(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, inputs, weights1, weights2, args, moe_layer_ctx):
+ original_weight1, original_weight2, activation_func, group_list, layer_number = args
+ global_args = get_args()
+ moe_zero_memory = global_args.moe_zero_memory
+ moe_experts_pipeline_degree = global_args.moe_experts_pipeline_degree
+ ctx.layer_number = layer_number
+ ctx.moe_zero_memory = moe_zero_memory
+ ctx.moe_experts_pipeline_degree = moe_experts_pipeline_degree
+ use_gmm = (inputs.nelement() != 0)
+ ctx.use_gmm = use_gmm
+ if use_gmm:
+ mm1_out = gmm_op(inputs, weights1, [], group_list, 0)[0]
+ else:
+ mm1_out = torch.matmul(inputs, weights1)
+ if moe_zero_memory != "disable" or moe_experts_pipeline_degree:
+ inputs.untyped_storage().resize_(0)
+ act_out, detached_act_inputs = forward_func(activation_func, mm1_out)
+
+ is_only_recompute_activation = only_recompute_activation(layer_number)
+ if moe_zero_memory == "level1" and not is_only_recompute_activation:
+ mm1_out.untyped_storage().resize_(0)
+ if use_gmm:
+ mm2_out = gmm_op(act_out, weights2, [], group_list, 0)[0]
+ else:
+ mm2_out = torch.matmul(act_out, weights2)
+
+ if moe_zero_memory == "level1" and not is_only_recompute_activation:
+ act_out.untyped_storage().resize_(0)
+ moe_layer_ctx.recompute_tensors = (inputs, mm1_out, act_out)
+ is_recompute_activation = moe_zero_memory == "level0" or should_recompute_activation(layer_number) or (
+ moe_zero_memory == "level1" and is_only_recompute_activation)
+ if is_recompute_activation:
+ act_out.untyped_storage().resize_(0)
+ ctx.activation_func = activation_func
+ if moe_zero_memory != "level0" and not (moe_zero_memory == "level1" and is_only_recompute_activation):
+ ctx.save_for_backward(inputs, detached_act_inputs, act_out, weights1, weights2, original_weight1,
+ original_weight2, group_list)
+ else:
+ ctx.save_for_backward(detached_act_inputs, act_out, weights1, weights2, original_weight1, original_weight2,
+ group_list)
+
+ return mm2_out, None
+
+ @staticmethod
+ def backward(ctx, *grad_outs):
+ grad_outs = grad_outs[0]
+ global_args = get_args()
+ moe_hierarchical_alltoallv = global_args.moe_hierarchical_alltoallv
+ layer_number = ctx.layer_number
+ moe_zero_memory = ctx.moe_zero_memory
+ moe_experts_pipeline_degree = ctx.moe_experts_pipeline_degree
+ is_only_recompute_activation = only_recompute_activation(layer_number)
+ if moe_zero_memory != "level0" and not (moe_zero_memory == "level1" and is_only_recompute_activation):
+ mm1_inputs, act_inputs, mm2_inputs, weights1, weights2, original_weight1, original_weight2, group_list = ctx.saved_tensors
+ else:
+ act_inputs, mm2_inputs, weights1, weights2, original_weight1, original_weight2, group_list = ctx.saved_tensors
+ if moe_experts_pipeline_degree:
+ inputs_save = get_gemm_backward_need_tensors()
+ _, inputs, ag_handle_i = async_all_gather(inputs_save, get_tensor_model_parallel_group(()), last_dim=True)
+ else:
+ ((detach_input, indices, scores_ep, router_topk, global_input_tokens_local_experts_indices),
+ permute2_input_detach, permute2_graph, output_splits, input_splits,
+ input_splits_tp_ep) = get_gemm_backward_need_tensors()
+
+ # grad of mm2 dx
+ if ctx.use_gmm:
+ weights2 = rearrange(weights2, 'n h f -> n f h')
+ grad_mm2_inputs = gmm_op(grad_outs, weights2, [], group_list, 0)[0]
+ else:
+ grad_mm2_inputs = torch.matmul(grad_outs, weights2.t())
+ act_graph = mm2_inputs
+ is_recompute_activation = moe_zero_memory == "level0" or should_recompute_activation(layer_number) or (
+ moe_zero_memory == "level1" and is_only_recompute_activation)
+ if is_recompute_activation:
+ activation_func = ctx.activation_func
+ mm2_inputs = activation_func(act_inputs)
+
+ if moe_hierarchical_alltoallv:
+ ep_group = parallel_state.get_expert_model_parallel_group()
+ tp_group = parallel_state.get_tensor_model_parallel_group()
+ permute1_graph, scores_ep, hidden_states_ep = get_all2all_experts_output()
+ if moe_zero_memory == "disable":
+ _, detach_scores_grad, detach_scores_handle = async_reduce_scatter(scores_ep.grad, group=ep_group)
+ else:
+ detach_scores_grad = None
+ detach_scores_handle = None
+
+ # grad of activation_func
+ act_graph.backward(grad_mm2_inputs)
+ if moe_zero_memory == "level0" or (moe_zero_memory == "level1" and is_only_recompute_activation):
+ permutated_local_input_tokens = get_permute_with_ep_local_input_tokens()
+ _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
+ permutated_local_input_tokens,
+ output_splits,
+ input_splits,
+ tp_group,
+ )
+
+ # gmm1 dx
+ if ctx.use_gmm:
+ weights1 = rearrange(weights1, 'n h f -> n f h')
+ mm1_inputs_grad = \
+ gmm_op(act_inputs.grad, weights1, [], group_list, 0)[0]
+ else:
+ mm1_inputs_grad = torch.matmul(act_inputs.grad, weights1.t())
+
+ backward_func(permute2_graph, mm1_inputs_grad)
+ mm1_inputs_grad.untyped_storage().resize_(0)
+
+ if moe_zero_memory == "level0" or (moe_zero_memory == "level1" and is_only_recompute_activation):
+ permute1_ep_all_to_all_handle.wait()
+ permutated_local_input_tokens.untyped_storage().resize_(0)
+ _, permute1_backward_input, bw_permute1_ep_all2all_handle = async_all_to_all(
+ permute2_input_detach.grad,
+ input_splits,
+ output_splits,
+ tp_group,
+ )
+
+ # gmm2 dw
+ if ctx.use_gmm:
+ if get_args().gemm_gradient_accumulation_fusion:
+
+ npu_groupmatmul_add_fp32(mm2_inputs, grad_outs, group_list, original_weight2.main_grad)
+
+ if hasattr(original_weight2, 'grad_added_to_main_grad'):
+ if getattr(weights2, 'zero_out_wgrad', False):
+ grad_weights2 = torch.zeros(
+ weights2.transpose(-1, -2).shape,
+ dtype=mm2_inputs.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weights2 = torch.empty(
+ weights2.transpose(-1, -2).shape,
+ dtype=mm2_inputs.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ original_weight2.grad_added_to_main_grad = True
+ else:
+ grad_weights2 = None
+ else:
+ grad_weights2 = gmm_op(mm2_inputs.t(), grad_outs, [], group_list, 2)[0]
+ else:
+ grad_weights2 = torch.matmul(mm2_inputs.t(), grad_outs)
+
+ # grad of activation_func
+ grad_outs.untyped_storage().resize_(0)
+ mm2_inputs.untyped_storage().resize_(0)
+ if moe_hierarchical_alltoallv:
+ grad_mm2_inputs.untyped_storage().resize_(0)
+ act_inputs.untyped_storage().resize_(0)
+ bw_permute1_ep_all2all_handle.wait()
+
+ backward_func(permute1_graph, permute1_backward_input)
+ permute1_backward_input.untyped_storage().resize_(0)
+ if moe_zero_memory == "disable":
+ detach_scores_handle.wait()
+
+ ep_group = parallel_state.get_expert_model_parallel_group()
+ _, detach_input_grad, detach_input_handle = async_reduce_scatter(hidden_states_ep.grad, group=ep_group)
+ set_all2all_experts_output((detach_scores_grad, detach_input_grad, detach_input_handle))
+ else:
+ act_graph.backward(grad_mm2_inputs)
+ grad_mm2_inputs.untyped_storage().resize_(0)
+ act_inputs.untyped_storage().resize_(0)
+ if moe_zero_memory == "level0" or (moe_zero_memory == "level1" and is_only_recompute_activation):
+ def alltoall_token_permutation1(hidden_states, indices):
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ permutated_local_input_tokens, _ = permute(
+ hidden_states, indices
+ )
+ return permutated_local_input_tokens
+
+ permutated_local_input_tokens = alltoall_token_permutation1(detach_input, indices)
+
+ ep_group = get_expert_model_parallel_group()
+ if global_args.moe_tp_extend_ep:
+ ep_group = get_tensor_and_expert_parallel_group()
+ _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
+ permutated_local_input_tokens,
+ output_splits,
+ input_splits,
+ ep_group,
+ )
+ if ctx.use_gmm:
+ weights1 = rearrange(weights1, 'n h f -> n f h')
+ mm1_inputs_grad = gmm_op(act_inputs.grad, weights1, [], group_list, 0)[0]
+ else:
+ mm1_inputs_grad = torch.matmul(act_inputs.grad, weights1.t())
+
+ # 峰值
+ if moe_experts_pipeline_degree:
+ ag_handle_i.wait()
+ mm1_inputs = torch.cat(inputs, dim=inputs_save.dim() - 1).contiguous()
+ else:
+ backward_func(permute2_graph, mm1_inputs_grad)
+ mm1_inputs_grad.untyped_storage().resize_(0)
+ ep_group = get_expert_model_parallel_group()
+ if global_args.moe_tp_extend_ep:
+ ep_group = get_tensor_and_expert_parallel_group()
+
+ if moe_zero_memory == "level0" or (moe_zero_memory == "level1" and is_only_recompute_activation):
+ permute1_ep_all_to_all_handle.wait()
+ permutated_local_input_tokens.untyped_storage().resize_(0)
+
+ if moe_experts_pipeline_degree:
+ mm1_inputs_grad, num_dim = transfer_tensor_last_dim_to_first(mm1_inputs_grad)
+ rs_input_i, expert_output, rs_handle_i = async_reduce_scatter(mm1_inputs_grad,
+ get_tensor_model_parallel_group())
+ set_all2all_experts_output((rs_input_i, expert_output, rs_handle_i, mm1_inputs_grad, num_dim))
+ else:
+ _, permute1_backward_input, bw_permute1_ep_all2all_handle = async_all_to_all(
+ permute2_input_detach.grad,
+ input_splits,
+ output_splits,
+ ep_group,
+ )
+ set_all2all_experts_output((permute1_backward_input, bw_permute1_ep_all2all_handle))
+
+ if moe_zero_memory == "level0" or (moe_zero_memory == "level1" and is_only_recompute_activation):
+ mm1_inputs, _ = permute(
+ global_input_tokens, global_input_tokens_local_experts_indices
+ )
+
+ global_input_tokens.untyped_storage().resize_(0)
+
+ if ctx.use_gmm:
+ if get_args().gemm_gradient_accumulation_fusion:
+ npu_groupmatmul_add_fp32(mm1_inputs, act_inputs.grad, group_list, original_weight1.main_grad)
+ if hasattr(original_weight1, 'grad_added_to_main_grad'):
+ if getattr(weights1, 'zero_out_wgrad', False):
+ mm1_weights_grad = torch.zeros(
+ weights1.transpose(-1, -2).shape,
+ dtype=mm1_inputs.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ mm1_weights_grad = torch.empty(
+ weights1.transpose(-1, -2).shape,
+ dtype=mm1_inputs.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ original_weight1.grad_added_to_main_grad = True
+ else:
+ mm1_weights_grad = None
+ else:
+ mm1_weights_grad = gmm_op(mm1_inputs.t(), act_inputs.grad, [], group_list, 2)[0]
+ else:
+ mm1_weights_grad = torch.matmul(mm1_inputs.t(), act_inputs.grad)
+ act_inputs.grad.untyped_storage().resize_(0)
+ if moe_experts_pipeline_degree:
+ return None, mm1_weights_grad, grad_weights2, None, None
+ else:
+ return mm1_inputs_grad, mm1_weights_grad, grad_weights2, None, None
+
+
+def grouped_mlp_with_comp_and_comm_overlap_all2all(inputs, weights1, weights2, args, ctx):
+ return GroupedMlpWithCompAndCommOverlapAll2All.apply(inputs, weights1, weights2, args, ctx)
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/grouped_mlp_with_comp_and_comm_overlap_allgather.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/grouped_mlp_with_comp_and_comm_overlap_allgather.py
new file mode 100644
index 0000000000000000000000000000000000000000..42b5a2e79a40aadd152c208b3d002811f41a6ca1
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/grouped_mlp_with_comp_and_comm_overlap_allgather.py
@@ -0,0 +1,187 @@
+# Copyright (c) 2024, Huawei Technologies.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import acl
+from einops import rearrange
+from megatron.core.parallel_state import get_expert_model_parallel_group, get_tensor_and_expert_parallel_group, get_tensor_and_expert_parallel_world_size, get_expert_model_parallel_world_size
+from megatron.training import get_args
+from mindspeed.ops.gmm import GMMFunction
+from mindspeed.model.transformer import should_recompute_activation
+from mindspeed.core.transformer.moe.moe_utils import (get_gemm_backward_need_tensors, get_ag_tp_hidden_status,
+ set_rs_global_hidden_states_grad_with_handle)
+from mindspeed.core.transformer.moe.moe_utils import forward_func, backward_func
+from mindspeed.core.transformer.moe.comm_utils import async_all_gather, async_reduce_scatter
+from mindspeed.core.transformer.moe.token_dispatcher import cann_version_check
+from mindspeed.ops.npu_groupmatmul_add import npu_groupmatmul_add_fp32
+from .moe_layer_overlap_all2all import gmm_op
+
+
+class GroupedMlpWithCompAndCommOverlapAllGather(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, inputs, weights1, weights2, args):
+ original_weight1, original_weight2, activation_func, group_list, layer_number = args
+ use_gmm = (inputs.nelement() != 0)
+ ctx.use_gmm = use_gmm
+ if use_gmm:
+ mm1_out = gmm_op(inputs, weights1, [], group_list, 0)[0]
+ else:
+ mm1_out = torch.matmul(inputs, weights1)
+ inputs.untyped_storage().resize_(0)
+ act_out, detached_act_inputs = forward_func(activation_func, mm1_out)
+ if use_gmm:
+ mm2_out = gmm_op(act_out, weights2, [], group_list, 0)[0]
+ else:
+ mm2_out = torch.matmul(act_out, weights2)
+ if should_recompute_activation(layer_number):
+ act_out.untyped_storage().resize_(0)
+ ctx.activation_func = activation_func
+ ctx.layer_number = layer_number
+ ctx.save_for_backward(detached_act_inputs, act_out, weights1, weights2, original_weight1, original_weight2, group_list)
+ return mm2_out, None
+
+ @staticmethod
+ def backward(ctx, *grad_outs):
+ grad_outs = grad_outs[0]
+ layer_number = ctx.layer_number
+ act_inputs, act_graph, weights1, weights2, original_weight1, original_weight2, group_list = ctx.saved_tensors
+ token_unpermutation_graph, global_hidden_states_detach, indices, global_local_map = get_gemm_backward_need_tensors()
+
+ # grad of mm2
+ if ctx.use_gmm:
+ weights2 = rearrange(weights2, 'n h f -> n f h')
+ grad_mm2_inputs = gmm_op(grad_outs, weights2, [], group_list, 0)[0]
+ else:
+ grad_mm2_inputs = torch.matmul(grad_outs, weights2.t())
+ if should_recompute_activation(layer_number):
+ activation_func = ctx.activation_func
+ act_out = activation_func(act_inputs)
+ mm2_inputs = act_out
+ else:
+ mm2_inputs = act_graph
+
+ if ctx.use_gmm:
+ if get_args().gemm_gradient_accumulation_fusion:
+
+ npu_groupmatmul_add_fp32(mm2_inputs, grad_outs, group_list, original_weight2.main_grad)
+
+ if hasattr(original_weight2, 'grad_added_to_main_grad'):
+ if getattr(weights2, 'zero_out_wgrad', False):
+ grad_weights2 = torch.zeros(
+ weights2.transpose(-1, -2).shape,
+ dtype=mm2_inputs.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weights2 = torch.empty(
+ weights2.transpose(-1, -2).shape,
+ dtype=mm2_inputs.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ original_weight2.grad_added_to_main_grad = True
+ else:
+ grad_weights2 = None
+ else:
+ grad_weights2 = gmm_op(mm2_inputs.t(), grad_outs, [], group_list, 2)[0]
+ else:
+ grad_weights2 = torch.matmul(mm2_inputs.t(), grad_outs)
+
+ grad_outs.untyped_storage().resize_(0)
+ mm2_inputs.untyped_storage().resize_(0)
+
+ # grad of activation_func
+ act_graph.backward(grad_mm2_inputs)
+ grad_mm2_inputs.untyped_storage().resize_(0)
+ act_inputs.untyped_storage().resize_(0)
+ mm1_outs_grad = act_inputs.grad
+
+ # re-gather mm1 forward inputs
+ ag_inputs_tp = get_ag_tp_hidden_status()
+ ag_inputs_tp = ag_inputs_tp.view(-1, ag_inputs_tp.shape[-1])
+ ag_group = get_expert_model_parallel_group()
+ if '910B' in acl.get_soc_name() or not get_args().n_shared_experts:
+ ag_group = get_tensor_and_expert_parallel_group()
+ _, ag_inputs_tp_ep, ag_handle = async_all_gather(ag_inputs_tp, ag_group)
+ if ctx.use_gmm:
+ # grad of mm1-inputs
+ weights1 = rearrange(weights1, 'n h f -> n f h')
+ mm1_inputs_grad = gmm_op(act_inputs.grad, weights1, [], group_list, 0)[0]
+ else:
+ mm1_inputs_grad = torch.matmul(act_inputs.grad, weights1.t())
+
+ # token 反重排的反向
+ backward_func(token_unpermutation_graph, mm1_inputs_grad)
+ mm1_inputs_grad.untyped_storage().resize_(0)
+ _, rs_global_hidden_states_grad, rs_handle = async_reduce_scatter(global_hidden_states_detach.grad,
+ get_tensor_and_expert_parallel_group())
+ rs_global_hidden_states_grad_with_handle = (rs_global_hidden_states_grad, rs_handle)
+ ag_handle.wait()
+
+ # token 重排计算
+ global_args = get_args()
+ num_local_experts = global_args.num_experts // get_expert_model_parallel_world_size()
+ if global_args.moe_tp_extend_ep:
+ num_local_experts = global_args.num_experts // get_tensor_and_expert_parallel_world_size()
+ if cann_version_check:
+ mm1_inputs = ag_inputs_tp_ep[global_local_map, :]
+ if num_local_experts > 1:
+ mm1_inputs = mm1_inputs[indices, :]
+ else:
+ mm1_inputs = torch.gather(ag_inputs_tp_ep, 0, global_local_map)
+ if num_local_experts > 1:
+ mm1_inputs = torch.gather(mm1_inputs, 0, indices)
+
+ global_local_map.untyped_storage().resize_(0)
+ indices.untyped_storage().resize_(0)
+ ag_inputs_tp_ep.untyped_storage().resize_(0)
+
+ if ctx.use_gmm:
+ if get_args().gemm_gradient_accumulation_fusion:
+
+ npu_groupmatmul_add_fp32(mm1_inputs, act_inputs.grad, group_list, original_weight1.main_grad)
+
+ if hasattr(original_weight1, 'grad_added_to_main_grad'):
+ if getattr(weights1, 'zero_out_wgrad', False):
+ mm1_weights_grad = torch.zeros(
+ weights1.transpose(-1, -2).shape,
+ dtype=mm1_inputs.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ mm1_weights_grad = torch.empty(
+ weights1.transpose(-1, -2).shape,
+ dtype=mm1_inputs.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ original_weight1.grad_added_to_main_grad = True
+ else:
+ mm1_weights_grad = None
+ else:
+ mm1_weights_grad = gmm_op(mm1_inputs.t(), act_inputs.grad, [], group_list, 2)[0]
+ else:
+ mm1_weights_grad = torch.matmul(mm1_inputs.t(), act_inputs.grad)
+
+ mm1_outs_grad.untyped_storage().resize_(0)
+
+ set_rs_global_hidden_states_grad_with_handle(rs_global_hidden_states_grad_with_handle)
+ return mm1_inputs_grad, mm1_weights_grad, grad_weights2, None
+
+
+def grouped_mlp_with_comp_and_comm_overlap_allgather(inputs, weights1, weights2, args):
+ return GroupedMlpWithCompAndCommOverlapAllGather.apply(inputs, weights1, weights2, args)
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/layers.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ca1f8b3485fea7acc0ff5ce1cf2dd2e81173185
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/layers.py
@@ -0,0 +1,638 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+import os
+import warnings
+from typing import Any, Callable, List, Optional
+
+import acl
+import torch
+import torch_npu
+import torch.distributed
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.nn.parameter import Parameter
+from megatron.core.tensor_parallel.layers import (
+ _initialize_affine_weight_cpu,
+ _initialize_affine_weight_gpu,
+ linear_with_grad_accumulation_and_async_allreduce,
+ linear_with_frozen_weight
+)
+from megatron.core.tensor_parallel.mappings import (
+ copy_to_tensor_model_parallel_region,
+ gather_from_tensor_model_parallel_region,
+ reduce_from_tensor_model_parallel_region,
+ reduce_scatter_to_sequence_parallel_region,
+ scatter_to_tensor_model_parallel_region,
+ _reduce_scatter_along_first_dim,
+ _gather_along_first_dim
+)
+from megatron.core.tensor_parallel.utils import VocabUtility, divide, split_tensor_along_last_dim
+from megatron.core.utils import (
+ make_tp_sharded_tensor_for_checkpoint,
+ prepare_input_tensors_for_wgrad_compute
+)
+from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
+from megatron.core.model_parallel_config import ModelParallelConfig
+from megatron.core.parallel_state import (
+ get_global_memory_buffer,
+ get_tensor_model_parallel_group,
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+)
+
+
+class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
+ """See linear_with_grad_accumulation_and_async_allreduce"""
+
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ async_grad_allreduce,
+ sequence_parallel,
+ grad_output_buffer,
+ shared_expert,
+ ):
+ ctx.save_for_backward(input, weight)
+ ctx.use_bias = bias is not None
+ ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+ ctx.async_grad_allreduce = async_grad_allreduce
+ ctx.sequence_parallel = sequence_parallel
+ ctx.grad_output_buffer = grad_output_buffer
+ ctx.shared_expert = shared_expert
+ ctx.need_save = True
+ if sequence_parallel:
+ if shared_expert:
+ from mindspeed.core.transformer.moe.moe_utils import AG_SHARED_EXPERTS_INPUTS
+ ag_shared_experts_inputs = AG_SHARED_EXPERTS_INPUTS.pop(0)
+ if isinstance(ag_shared_experts_inputs, tuple):
+ ag_shared_experts_inputs, handle = ag_shared_experts_inputs
+ handle.wait()
+ ctx.need_save = False
+ total_input = ag_shared_experts_inputs
+ else:
+ world_size = get_tensor_model_parallel_world_size()
+ dim_size = list(input.size())
+ dim_size[0] = dim_size[0] * world_size
+
+ all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+ torch.distributed._all_gather_base(
+ all_gather_buffer, input, group=get_tensor_model_parallel_group()
+ )
+ total_input = all_gather_buffer
+ else:
+ total_input = input
+
+ output = torch.matmul(total_input, weight.t())
+
+ if bias is not None:
+ output = output + bias
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ use_bias = ctx.use_bias
+ grad_output_buffer = ctx.grad_output_buffer
+
+ wgrad_compute = True
+ if grad_output_buffer is not None:
+ grad_output_buffer.append(grad_output)
+ wgrad_compute = False
+
+ if wgrad_compute:
+ from mindspeed.core.transformer.moe.moe_utils import set_ag_tp_hidden_status
+ if ctx.sequence_parallel:
+ world_size = get_tensor_model_parallel_world_size()
+ dim_size = list(input.size())
+ dim_size[0] = dim_size[0] * world_size
+
+ all_gather_buffer = get_global_memory_buffer().get_tensor(
+ dim_size, input.dtype, "mpu"
+ )
+ handle = torch.distributed._all_gather_base(
+ all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # gather is scheduled before the input gradient computation
+ total_input = all_gather_buffer
+ else:
+ total_input = input
+ if ctx.need_save and '910B' not in acl.get_soc_name():
+ set_ag_tp_hidden_status(total_input)
+ grad_input = grad_output.matmul(weight)
+
+ if ctx.sequence_parallel and wgrad_compute:
+ handle.wait()
+
+ if wgrad_compute:
+ grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
+ grad_output, total_input
+ )
+
+ if ctx.async_grad_allreduce:
+ # Asynchronous all-reduce
+ handle = torch.distributed.all_reduce(
+ grad_input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # all-reduce is scheduled before the weight gradient computation
+
+ if ctx.sequence_parallel:
+ assert not ctx.async_grad_allreduce
+ dim_size = list(input.size())
+ sub_grad_input = torch.empty(
+ dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
+ )
+ # reduce_scatter
+ handle = torch.distributed._reduce_scatter_base(
+ sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
+ )
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # reduce scatter is scheduled before the weight gradient computation
+
+ if ctx.gradient_accumulation_fusion:
+ if wgrad_compute:
+ import fused_weight_gradient_mlp_cuda
+ if weight.main_grad.dtype == torch.float32:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+ total_input, grad_output, weight.main_grad
+ )
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+ total_input, grad_output, weight.main_grad
+ )
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+
+ if hasattr(weight, 'grad_added_to_main_grad'):
+ # When overlap_grad_reduce is True, need to ensure that backward hooks
+ # are all run on the main backprop thread to prevent deadlocks. Setup
+ # dummy grad_weight tensor to prevent backward hooks from being run
+ # in a background thread.
+ if getattr(weight, 'zero_out_wgrad', False):
+ grad_weight = torch.zeros(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ else:
+ grad_weight = torch.empty(
+ weight.main_grad.shape,
+ dtype=input.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ weight.grad_added_to_main_grad = True
+ else:
+ grad_weight = None
+ else:
+ grad_weight = grad_output.t().matmul(total_input)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+ if ctx.sequence_parallel:
+ handle.wait()
+ # Need to return None's as gradient has to flow for all the input arguments
+ # provided during forward
+ return sub_grad_input, grad_weight, grad_bias, None, None, None, None, None
+
+ if ctx.async_grad_allreduce:
+ handle.wait()
+
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None
+
+
+def linear_with_grad_accumulation_and_async_allreduce(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ bias: Optional[torch.Tensor],
+ gradient_accumulation_fusion: bool,
+ async_grad_allreduce: bool,
+ sequence_parallel: bool,
+ grad_output_buffer: Optional[List[torch.Tensor]] = None,
+ shared_expert: bool = False
+) -> torch.Tensor:
+ args = [
+ input,
+ weight,
+ bias,
+ gradient_accumulation_fusion,
+ async_grad_allreduce,
+ sequence_parallel,
+ grad_output_buffer,
+ shared_expert,
+ ]
+
+ if not linear_with_grad_accumulation_and_async_allreduce.warned:
+ if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
+ if sequence_parallel:
+ warnings.warn(
+ "When using sequence parallelism it is recommended to set the "
+ "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
+ "maximum speedup"
+ )
+ linear_with_grad_accumulation_and_async_allreduce.warned = True
+
+ if async_grad_allreduce:
+ warnings.warn(
+ "When using async grad allreduce it is recommended to set the "
+ "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
+ "maximum speedup"
+ )
+ linear_with_grad_accumulation_and_async_allreduce.warned = True
+
+ return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
+
+
+linear_with_grad_accumulation_and_async_allreduce.warned = False
+
+
+class ColumnParallelLinear(torch.nn.Module):
+
+ def __init__(
+ self,
+ input_size,
+ output_size,
+ *,
+ config: ModelParallelConfig,
+ init_method: Callable,
+ bias=True,
+ gather_output=False,
+ stride=1,
+ keep_master_weight_for_test=False,
+ skip_bias_add=False,
+ skip_weight_param_allocation: bool = False,
+ embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
+ grad_output_buffer: Optional[List[torch.Tensor]] = None,
+ is_expert: bool = False,
+ tp_comm_buffer_name: str = None, # Not used
+ shared_expert: bool = False
+ ):
+ super(ColumnParallelLinear, self).__init__()
+
+ # Keep input parameters
+ self.input_size = input_size
+ self.output_size = output_size
+ self.gather_output = gather_output
+ # Divide the weight matrix along the last dimension.
+ world_size = get_tensor_model_parallel_world_size()
+ self.output_size_per_partition = divide(output_size, world_size)
+ self.skip_bias_add = skip_bias_add
+ self.is_expert = is_expert
+ self.expert_parallel = config.expert_model_parallel_size > 1
+ self.embedding_activation_buffer = embedding_activation_buffer
+ self.grad_output_buffer = grad_output_buffer
+ self.config = config
+ self.shared_expert = shared_expert
+
+ # Parameters.
+ # Note: torch.nn.functional.linear performs XA^T + b and as a result
+ # we allocate the transpose.
+ # Initialize weight.
+ if not skip_weight_param_allocation:
+ if config.use_cpu_initialization:
+ self.weight = Parameter(
+ torch.empty(
+ self.output_size_per_partition, self.input_size, dtype=config.params_dtype
+ )
+ )
+ if config.perform_initialization:
+ self.master_weight = _initialize_affine_weight_cpu(
+ self.weight,
+ self.output_size,
+ self.input_size,
+ self.output_size_per_partition,
+ 0,
+ init_method,
+ stride=stride,
+ return_master_weight=keep_master_weight_for_test,
+ )
+ else:
+ self.weight = Parameter(
+ torch.empty(
+ self.output_size_per_partition,
+ self.input_size,
+ device=torch.cuda.current_device(),
+ dtype=config.params_dtype,
+ )
+ )
+ if config.perform_initialization:
+ _initialize_affine_weight_gpu(
+ self.weight,
+ init_method,
+ partition_dim=0,
+ stride=stride,
+ expert_parallel=(self.is_expert and self.expert_parallel),
+ )
+
+ setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
+ else:
+ self.weight = None
+
+ self.register_parameter('bias', None)
+
+ self.async_tensor_model_parallel_allreduce = (
+ config.async_tensor_model_parallel_allreduce and world_size > 1
+ )
+
+ self.sequence_parallel = config.sequence_parallel
+ if self.sequence_parallel and world_size <= 1:
+ self.sequence_parallel = False
+
+ self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
+
+ if self.async_tensor_model_parallel_allreduce and self.sequence_parallel:
+ raise RuntimeError(
+ "`async_tensor_model_parallel_allreduce` and `sequence_parallel` "
+ "cannot be enabled at the same time."
+ )
+
+ self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
+ self.explicit_expert_comm = self.is_expert and (
+ self.sequence_parallel or self.expert_parallel
+ )
+
+ # Hook adding a default empty _extra_state for state dict
+ self._register_load_state_dict_pre_hook(
+ lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
+ f'{prefix}_extra_state'
+ )
+ )
+
+ def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
+ """Forward of ColumnParallelLinear
+
+ Args:
+ input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
+
+ weight (optional): weight tensor to use, compulsory when
+ skip_weight_param_allocation is True.
+
+ Returns:
+ - output
+ - bias
+
+ """
+ if weight is None:
+ if self.weight is None:
+ raise RuntimeError(
+ "weight was not supplied to ColumnParallelLinear forward pass "
+ "and skip_weight_param_allocation is True."
+ )
+ weight = self.weight
+ else:
+ # Check the weight passed in is the correct shape
+ expected_shape = (self.output_size_per_partition, self.input_size)
+ if weight.shape != expected_shape:
+ raise RuntimeError(
+ f"supplied weight's shape is {tuple(weight.shape)}, "
+ f"not {expected_shape} as expected"
+ )
+
+ if self.config._cpu_offloading_context is not None:
+ if self.config._cpu_offloading_context.inside_context == True:
+ assert (
+ self.config.cpu_offloading == False
+ ), "CPU Offloading cannot be enabled while using non-TE modules"
+
+ bias = self.bias if not self.skip_bias_add else None
+
+ if (
+ self.async_tensor_model_parallel_allreduce
+ or self.sequence_parallel
+ or self.explicit_expert_comm
+ ):
+ input_parallel = input_
+ else:
+ input_parallel = copy_to_tensor_model_parallel_region(input_)
+
+ if self.config.defer_embedding_wgrad_compute:
+ self.embedding_activation_buffer.append(input_parallel)
+
+ # Matrix multiply.
+ if not weight.requires_grad:
+ self._forward_impl = linear_with_frozen_weight
+ else:
+ self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
+
+ output_parallel = self._forward_impl(
+ input=input_parallel,
+ weight=weight,
+ bias=bias,
+ gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+ async_grad_allreduce=False
+ if self.explicit_expert_comm
+ else self.async_tensor_model_parallel_allreduce,
+ sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
+ grad_output_buffer=self.grad_output_buffer
+ if self.config.defer_embedding_wgrad_compute
+ else None,
+ shared_expert=self.shared_expert
+ )
+ if self.gather_output:
+ # All-gather across the partitions.
+ assert not self.sequence_parallel
+ output = gather_from_tensor_model_parallel_region(output_parallel)
+ else:
+ output = output_parallel
+ output_bias = self.bias if self.skip_bias_add else None
+ return output, output_bias
+
+ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
+ """ Sharding along axis 0, bias sharded """
+ state_dict = self.state_dict(prefix='', keep_vars=True)
+ return make_sharded_tensors_for_checkpoint(
+ state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
+ )
+
+ def set_extra_state(self, state: Any):
+ """ Extra state is ignored """
+
+ def get_extra_state(self) -> None:
+ """ Keep compatibility with TE state dict. """
+ return None
+
+
+class RowParallelLinear(torch.nn.Module):
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ *,
+ config: ModelParallelConfig,
+ init_method: Callable,
+ bias: bool,
+ input_is_parallel: bool,
+ skip_bias_add: bool,
+ stride: int = 1,
+ keep_master_weight_for_test: bool = False,
+ is_expert: bool = False,
+ tp_comm_buffer_name: str = None, # Not used
+ shared_expert: bool = False
+ ):
+ super(RowParallelLinear, self).__init__()
+
+ # Keep input parameters
+ self.input_size = input_size
+ self.output_size = output_size
+ self.input_is_parallel = input_is_parallel
+ # Divide the weight matrix along the last dimension.
+ world_size = get_tensor_model_parallel_world_size()
+ self.input_size_per_partition = divide(input_size, world_size)
+ self.skip_bias_add = skip_bias_add
+ self.config = config
+ self.is_expert = is_expert
+ self.expert_parallel = config.expert_model_parallel_size > 1
+ self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
+ self.sequence_parallel = config.sequence_parallel
+ self.shared_expert = shared_expert
+ if self.sequence_parallel and not self.input_is_parallel:
+ raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")
+
+ # Parameters.
+ # Note: torch.nn.functional.linear performs XA^T + b and as a result
+ # we allocate the transpose.
+ # Initialize weight.
+ if config.use_cpu_initialization:
+ self.weight = Parameter(
+ torch.empty(
+ self.output_size, self.input_size_per_partition, dtype=config.params_dtype
+ )
+ )
+ if config.perform_initialization:
+ self.master_weight = _initialize_affine_weight_cpu(
+ self.weight,
+ self.output_size,
+ self.input_size,
+ self.input_size_per_partition,
+ 1,
+ init_method,
+ stride=stride,
+ return_master_weight=keep_master_weight_for_test,
+ params_dtype=config.params_dtype,
+ )
+ else:
+ self.weight = Parameter(
+ torch.empty(
+ self.output_size,
+ self.input_size_per_partition,
+ device=torch.cuda.current_device(),
+ dtype=config.params_dtype,
+ )
+ )
+ if config.perform_initialization:
+ _initialize_affine_weight_gpu(
+ self.weight,
+ init_method,
+ partition_dim=1,
+ stride=stride,
+ expert_parallel=(self.is_expert and self.expert_parallel),
+ )
+ setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
+
+ if bias:
+ if config.use_cpu_initialization:
+ self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype))
+ else:
+ self.bias = Parameter(
+ torch.empty(
+ self.output_size,
+ device=torch.cuda.current_device(),
+ dtype=config.params_dtype,
+ )
+ )
+
+ if config.perform_initialization:
+ # Always initialize bias to zero.
+ with torch.no_grad():
+ self.bias.zero_()
+ setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel))
+ setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
+ else:
+ self.register_parameter('bias', None)
+
+ self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
+ self.explicit_expert_comm = self.is_expert and (
+ self.sequence_parallel or self.expert_parallel
+ )
+
+ # Hook adding a default empty _extra_state for state dict
+ self._register_load_state_dict_pre_hook(
+ lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
+ f'{prefix}_extra_state'
+ )
+ )
+
+ def forward(self, input_):
+ """Forward of RowParallelLinear
+
+ Args:
+ input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
+
+ Returns:
+ - output
+ - bias
+ """
+
+ if self.config._cpu_offloading_context is not None:
+ if self.config._cpu_offloading_context.inside_context == True:
+ assert (
+ self.config.cpu_offloading == False
+ ), "CPU Offloading cannot be enabled while using non-TE modules"
+
+ # Set up backprop all-reduce.
+ if self.input_is_parallel:
+ input_parallel = input_
+ else:
+ assert not self.sequence_parallel
+ input_parallel = scatter_to_tensor_model_parallel_region(input_)
+ # Matrix multiply.
+ if not self.weight.requires_grad:
+ self._forward_impl = linear_with_frozen_weight
+ else:
+ self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
+ output_parallel = self._forward_impl(
+ input=input_parallel,
+ weight=self.weight,
+ bias=None,
+ gradient_accumulation_fusion=self.gradient_accumulation_fusion,
+ async_grad_allreduce=False,
+ sequence_parallel=False,
+ )
+
+ # All-reduce across all the partitions.
+ if self.explicit_expert_comm or self.shared_expert:
+ assert self.skip_bias_add
+ output_ = output_parallel
+ elif self.sequence_parallel:
+ output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
+ else:
+ output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+ if not self.skip_bias_add:
+ output = (output_ + self.bias) if self.bias is not None else output_
+ output_bias = None
+ else:
+ output = output_
+ output_bias = self.bias
+ return output, output_bias
+
+ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
+ """ Sharding along axis 1, bias not sharded """
+ state_dict = self.state_dict(prefix='', keep_vars=True)
+ return make_sharded_tensors_for_checkpoint(
+ state_dict, prefix, {'weight': 1}, sharded_offsets
+ )
+
+ def set_extra_state(self, state: Any):
+ """ Extra state is ignored """
+
+ def get_extra_state(self) -> None:
+ """ Keep compatibility with TE state dict. """
+ return None
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/moe_layer.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/moe_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ac9b45a005f192eac43b39438eae2bd7ffb2215
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/moe_layer.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import types
+import copy
+from copy import deepcopy
+from functools import wraps
+import torch
+from megatron.training import get_args
+from megatron.core import parallel_state, tensor_parallel
+from megatron.core.transformer.mlp import MLPSubmodules, MLP
+from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.router import TopKRouter
+from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP
+from megatron.core.transformer.moe.token_dispatcher import (
+ MoEAllGatherTokenDispatcher,
+ MoEAlltoAllTokenDispatcher,
+)
+from mindspeed.core.transformer.moe.moe_layer_overlap_all2all import MoELayerOverlapAll2All
+from mindspeed.core.transformer.moe.moe_layer_overlap_allgather import MoELayerOverlapAllGather
+
+
+def base_moe_init_wrapper(init_func):
+ @wraps(init_func)
+ def base_moe_init(*args, **kwargs):
+ init_func(*args, **kwargs)
+ self = args[0]
+ global_args = get_args()
+ if global_args.moe_tp_extend_ep:
+ tp_size = parallel_state.get_tensor_model_parallel_world_size()
+ assert self.config.num_moe_experts % (self.expert_parallel_size * tp_size) == 0
+ self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size // tp_size
+ local_expert_indices_offset = (
+ parallel_state.get_expert_model_parallel_rank() * self.num_local_experts * tp_size + \
+ parallel_state.get_tensor_model_parallel_rank() * self.num_local_experts
+ )
+ self.local_expert_indices = [
+ local_expert_indices_offset + i for i in range(self.num_local_experts)
+ ]
+ assert all(map(lambda x: x < self.config.num_moe_experts, self.local_expert_indices))
+
+ return base_moe_init
+
+
+def moe_layer_init(self, config, submodules=None, layer_number=None):
+ self.submodules = submodules
+ super(MoELayer, self).__init__(config=config, layer_number=layer_number)
+ self.router = TopKRouter(config=self.config)
+ moe_experts_pipeline_degree = get_args().moe_experts_pipeline_degree
+ if self.config.moe_grouped_gemm:
+ if moe_experts_pipeline_degree == 0:
+ self.experts = GroupedMLP(self.num_local_experts, self.config)
+ else:
+ expert = GroupedMLP(self.num_local_experts // moe_experts_pipeline_degree, self.config)
+ self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(moe_experts_pipeline_degree)])
+ else:
+ if not isinstance(self.submodules, MLPSubmodules):
+ raise TypeError("submodules should be instance of MLPSubmodules")
+ self.experts = SequentialMLP(self.num_local_experts, self.config, self.submodules)
+ if config.moe_token_dispatcher_type == "allgather":
+ self.token_dispatcher = MoEAllGatherTokenDispatcher(
+ self.num_local_experts, self.local_expert_indices, config=self.config
+ )
+ elif config.moe_token_dispatcher_type == "alltoall":
+ self.token_dispatcher = MoEAlltoAllTokenDispatcher(
+ self.num_local_experts, self.local_expert_indices, config=self.config
+ )
+ else:
+ raise ValueError(
+ f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
+ )
+
+ return moe_layer_init
+
+
+def moe_layer_init_wrapper(init_func):
+ @wraps(init_func)
+ def wrapper(*args, **kwargs):
+ init_func(*args, **kwargs)
+ self = args[0]
+ global_args = get_args()
+ self.moe_alltoall_overlap_comm = global_args.moe_alltoall_overlap_comm
+ self.moe_allgather_overlap_comm = global_args.moe_allgather_overlap_comm
+
+ if global_args.n_shared_experts:
+ config = deepcopy(self.config)
+ config.ffn_hidden_size = global_args.n_shared_experts * self.config.ffn_hidden_size
+ if self.moe_allgather_overlap_comm or self.moe_alltoall_overlap_comm:
+ from mindspeed.core.transformer.moe.layers import ColumnParallelLinear, RowParallelLinear
+ self.shared_experts = MLP(config, MLPSubmodules(linear_fc1=ColumnParallelLinear,
+ linear_fc2=RowParallelLinear,),
+ shared_expert=True)
+ else:
+ from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
+ self.shared_experts = MLP(config, MLPSubmodules(linear_fc1=ColumnParallelLinear,
+ linear_fc2=RowParallelLinear,))
+
+ self.moe_adaptive_recompute_activation = global_args.moe_adaptive_recompute_activation
+ self.recompute_threshold = 0
+ if hasattr(self.config, 'moe_token_dispatcher_type') and self.config.moe_token_dispatcher_type == 'allgather':
+ self.moe_adaptive_recompute_activation_scale = global_args.moe_adaptive_recompute_activation_scale
+ self.recompute_threshold = parallel_state.get_tensor_model_parallel_world_size() * parallel_state.get_data_parallel_world_size() * \
+ self.config.moe_router_topk * self.moe_adaptive_recompute_activation_scale / self.config.num_moe_experts
+ self.token_dispatcher.all_tokens_per_expert = None
+ self.forward = types.MethodType(moe_adaptive_forward, self)
+
+ return wrapper
+
+
+def moe_adaptive_forward(self, hidden_states: torch.Tensor):
+ if self.moe_alltoall_overlap_comm:
+ return MoELayerOverlapAll2All.apply(hidden_states, self)
+ if self.moe_allgather_overlap_comm:
+ return MoELayerOverlapAllGather.apply(hidden_states, self)
+
+ def custom_forward(hidden_states):
+ args = get_args()
+ scores, indices = self.router(hidden_states)
+ if args.n_shared_experts:
+ if not hasattr(self, 'comm_stream'):
+ self.comm_stream = torch.cuda.Stream()
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(self.comm_stream):
+ share_experts_output, share_experts_bias = self.shared_experts(hidden_states)
+ (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
+ hidden_states, scores, indices
+ )
+ expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
+ output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)
+ if args.n_shared_experts:
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
+ output = output + share_experts_output
+ if self.token_dispatcher.add_bias:
+ mlp_bias = mlp_bias + share_experts_bias
+ return output, mlp_bias
+
+ threshold = hidden_states.shape[0] * hidden_states.shape[1] * self.recompute_threshold
+ moe_adaptive_recompute_activation_bool = self.moe_adaptive_recompute_activation and \
+ (self.token_dispatcher.all_tokens_per_expert is None or torch.max(self.token_dispatcher.all_tokens_per_expert) > threshold)
+ if self.moe_layer_recompute or moe_adaptive_recompute_activation_bool:
+ output, mlp_bias = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
+ else:
+ output, mlp_bias = custom_forward(hidden_states)
+ return output, mlp_bias
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/moe_layer_overlap_all2all.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/moe_layer_overlap_all2all.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf446323e1d34d23e0b88c94024f5e44f095b31c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/moe_layer_overlap_all2all.py
@@ -0,0 +1,605 @@
+# Copyright (c) 2024, Huawei Technologies.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size
+from megatron.core import tensor_parallel, parallel_state
+from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.training import get_args
+from megatron.core.transformer.moe.moe_utils import permute, save_to_aux_losses_tracker
+from mindspeed.moe.utils import MoEAuxLossAutoScaler
+from mindspeed.core.transformer.moe.comm_utils import (async_all_to_all, async_all_gather, async_reduce_scatter,
+ transfer_tensor_last_dim_to_first, transfer_tensor_first_dim_to_last)
+from mindspeed.core.transformer.moe.moe_utils import (forward_func, backward_func, permute_with_ep)
+from mindspeed.ops.gmm import GMMFunction
+from mindspeed.core.transformer.moe.moe_utils import (AG_SHARED_EXPERTS_INPUTS, only_recompute_activation,
+ set_gemm_backward_need_tensors,
+ set_all2all_experts_output, get_all2all_experts_output,
+ get_prob_backward_need_tensors,
+ set_permute_with_ep_local_input_tokens)
+
+
+def gmm_op(x, weight, bias, group_list, group_type):
+ if isinstance(group_list, torch.Tensor) and group_list.device.type == 'cpu':
+ group_list = group_list.tolist()
+ return GMMFunction.builder.load().npu_gmm([x], [weight], bias, group_list, group_type, 0)
+
+
+def moe_experts_pipeline_forward_func(tokens_per_expert, moe_layer, dispatched_input, ctx, save_tensors):
+ input_list = []
+ expert_graphs = []
+ expert_outputs = []
+ tokens_per_expert_list = []
+ moe_experts_pipeline_degree = ctx.moe_experts_pipeline_degree
+
+ # 1. 划分子集
+ # 赋值self.input_list和self.tokens_per_expert_list
+ tokens_per_expert = tokens_per_expert.cpu()
+ group_list = torch.cumsum(tokens_per_expert, dim=0)
+ num_experts_overlap = moe_layer.num_local_experts // moe_experts_pipeline_degree
+
+ for i in range(moe_experts_pipeline_degree):
+ start_id = i * num_experts_overlap
+ start = 0
+ if i != 0:
+ start = group_list[start_id - 1]
+ end_id = (i + 1) * num_experts_overlap
+ end = group_list[end_id - 1]
+ input_i = dispatched_input[start : end]
+ tokens_per_expert_i = tokens_per_expert[start_id : end_id]
+ input_list.append(input_i)
+ tokens_per_expert_list.append(tokens_per_expert_i)
+ ctx.input_list = input_list
+
+ # 2. 对每个专家子集的输入数据进行模型计算,并将计算结果保存在expert_outputs中
+ ag_handle_i_next = None
+ rs_handle_i = None
+ input_i_next = None
+ num_dim = None
+ rs_input_i = None
+
+ for i in range(moe_experts_pipeline_degree):
+ if i == 0:
+ _, input_i, ag_handle_i = async_all_gather(input_list[i], get_tensor_model_parallel_group(), last_dim=True)
+ _, input_i_next, ag_handle_i_next = async_all_gather(input_list[i + 1], get_tensor_model_parallel_group(), last_dim=True)
+ elif i != (moe_experts_pipeline_degree - 1):
+ input_i = input_i_next
+ ag_handle_i = ag_handle_i_next
+ _, input_i_next, ag_handle_i_next = async_all_gather(input_list[i + 1], get_tensor_model_parallel_group(),
+ last_dim=True)
+ else:
+ input_i = input_i_next
+ ag_handle_i = ag_handle_i_next
+
+ ag_handle_i.wait()
+ input_i = torch.cat(input_i, dim=input_list[i].dim() - 1).contiguous()
+ input_i = input_i.detach()
+ input_i.requires_grad = True
+ (expert_output, mlp_bias), *_ = forward_func(moe_layer.experts[i], (input_i, tokens_per_expert_list[i], ctx))
+ if rs_handle_i is not None:
+ rs_handle_i.wait()
+ rs_input_i.untyped_storage().resize_(0)
+ expert_graphs[i - 1].untyped_storage().resize_(0)
+ expert_outputs[i - 1] = transfer_tensor_first_dim_to_last(expert_outputs[i - 1], num_dim)
+ expert_outputs[i - 1].requires_grad = True
+ # sub expert graph
+ expert_graphs.append(expert_output)
+
+ expert_output, num_dim = transfer_tensor_last_dim_to_first(expert_output)
+ rs_input_i, rs_expert_output, rs_handle_i = async_reduce_scatter(expert_output, get_tensor_model_parallel_group())
+
+ expert_outputs.append(rs_expert_output)
+
+ if i == (moe_experts_pipeline_degree - 1):
+ rs_handle_i.wait()
+ rs_input_i.untyped_storage().resize_(0)
+ expert_graphs[i].untyped_storage().resize_(0)
+ expert_outputs[i] = transfer_tensor_first_dim_to_last(expert_outputs[i], num_dim)
+ expert_outputs[i].requires_grad = True
+
+ ctx.expert_graphs = expert_graphs
+ ctx.expert_outputs = expert_outputs
+
+ # 3. 将所有子集的计算结果拼接在一起,保存在`expert_output`中
+ with torch.enable_grad():
+ expert_output = torch.cat(expert_outputs, dim=0)
+
+ for temp in expert_outputs:
+ temp.untyped_storage().resize_(0)
+
+ return expert_output, mlp_bias
+
+
+def moe_experts_pipeline_backward_func(ctx, input_list):
+ expert_grad_outputs = []
+
+ ag_handle_i_next = None
+ rs_handle_i = None
+ input_i_next = None
+ num_dim = None
+ mm1_inputs_grad = None
+ ag_input_i = None
+ ag_input_i_next = None
+ rs_input_i = None
+ ag_input_list = []
+
+ moe_experts_pipeline_degree = ctx.moe_experts_pipeline_degree
+ expert_graphs = ctx. expert_graphs
+ expert_outputs = ctx.expert_outputs
+
+ for i in range(moe_experts_pipeline_degree):
+ if i == 0:
+ ag_input_i, input_i, ag_handle_i = async_all_gather(expert_outputs[i].grad, get_tensor_model_parallel_group(),
+ last_dim=True)
+ ag_input_i_next, input_i_next, ag_handle_i_next = async_all_gather(expert_outputs[i + 1].grad,
+ get_tensor_model_parallel_group(),
+ last_dim=True)
+ elif i != (moe_experts_pipeline_degree - 1):
+ input_i = input_i_next
+ ag_handle_i = ag_handle_i_next
+ ag_input_i = ag_input_i_next
+ ag_input_i_next, input_i_next, ag_handle_i_next = async_all_gather(expert_outputs[i + 1].grad,
+ get_tensor_model_parallel_group(),
+ last_dim=True)
+ else:
+ input_i = input_i_next
+ ag_handle_i = ag_handle_i_next
+ ag_input_i = ag_input_i_next
+
+ ag_handle_i.wait()
+ ag_input_list.append(ag_input_i)
+ input_i = torch.cat(input_i, dim=expert_outputs[i].grad.dim() - 1).contiguous()
+
+ set_gemm_backward_need_tensors(input_list[i])
+
+ backward_func(expert_graphs[i], input_i)
+
+ if rs_handle_i is not None:
+ rs_handle_i.wait()
+ rs_input_i.untyped_storage().resize_(0)
+ mm1_inputs_grad.untyped_storage().resize_(0)
+ expert_grad_outputs[i - 1] = transfer_tensor_first_dim_to_last(expert_grad_outputs[i - 1], num_dim)
+
+ rs_input_i, expert_output, rs_handle_i, mm1_inputs_grad, num_dim = get_all2all_experts_output()
+ expert_grad_outputs.append(expert_output)
+
+ if i == (moe_experts_pipeline_degree - 1):
+ rs_handle_i.wait()
+ rs_input_i.untyped_storage().resize_(0)
+ mm1_inputs_grad.untyped_storage().resize_(0)
+ expert_grad_outputs[i] = transfer_tensor_first_dim_to_last(expert_grad_outputs[i], num_dim)
+
+ for ag_input in ag_input_list:
+ ag_input.untyped_storage().resize_(0)
+
+ expert_grad_output = torch.cat(expert_grad_outputs, dim=0)
+ return expert_grad_output
+
+
+class MoELayerOverlapAll2All(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, hidden_states, moe_layer: MoELayer):
+ args = get_args()
+ moe_hierarchical_alltoallv = args.moe_hierarchical_alltoallv
+ moe_experts_pipeline_degree = args.moe_experts_pipeline_degree
+ ctx.moe_experts_pipeline_degree = moe_experts_pipeline_degree
+ save_tensors = []
+ ctx.input_shape = hidden_states.shape
+ hidden_states = hidden_states.detach()
+ hidden_states.requires_grad = True
+ ctx.is_only_recompute_activation = only_recompute_activation(moe_layer.layer_number)
+ ctx.layer_number = moe_layer.layer_number
+ if not moe_hierarchical_alltoallv and args.n_shared_experts:
+ if get_tensor_model_parallel_world_size() > 1:
+ _, shared_experts_input, shared_experts_allgather_handle = async_all_gather(
+ hidden_states, get_tensor_model_parallel_group(), is_use_get_global_memory_buffer=True
+ )
+ AG_SHARED_EXPERTS_INPUTS.append((shared_experts_input, shared_experts_allgather_handle))
+
+ # router
+ with torch.enable_grad():
+ scores, indices = moe_layer.router(hidden_states)
+
+ save_tensors.append(scores)
+ scores = scores.detach()
+ scores.requires_grad = True
+ save_tensors.append(scores)
+ moe_zero_memory = args.moe_zero_memory
+ n_shared_experts = args.n_shared_experts
+ ctx.n_shared_experts = n_shared_experts
+ ctx.moe_zero_memory = moe_zero_memory
+ shared_expert_gate = hasattr(args, 'shared_expert_gate') and args.shared_expert_gate
+ group_limited_greedy = hasattr(args, 'moe_router_load_balancing_type') and args.moe_router_load_balancing_type == "group_limited_greedy"
+ ctx.shared_expert_gate = shared_expert_gate
+
+ if moe_zero_memory == "level1" and not ctx.is_only_recompute_activation:
+ ctx.activation_func = moe_layer.experts.activation_func
+ ctx.hidden_size = moe_layer.experts.config.hidden_size
+ ctx.num_local_experts = moe_layer.experts.num_local_experts
+ ctx.weight1 = moe_layer.experts.weight1
+ ctx.moe_grouped_gemm = moe_layer.token_dispatcher.config.moe_grouped_gemm
+ ctx.num_local_experts = moe_layer.token_dispatcher.num_local_experts
+
+ save_tensors.append(indices)
+
+ if n_shared_experts:
+ ctx.shared_experts = moe_layer.shared_experts
+ else:
+ ctx.shared_experts = None
+
+ if shared_expert_gate:
+ shared_expert_gate = moe_layer.shared_expert_gate
+ else:
+ shared_expert_gate = None
+
+ (share_experts_output, dispatched_input, tokens_per_expert) = moe_layer.token_dispatcher.token_permutation(
+ hidden_states, scores, indices, ctx.shared_experts, save_tensors, shared_expert_gate, ctx
+ )
+ if moe_experts_pipeline_degree:
+ save_tensors.append(None)
+ save_tensors.append(None)
+ expert_output, mlp_bias = moe_experts_pipeline_forward_func(tokens_per_expert, moe_layer, dispatched_input, ctx, save_tensors)
+ output, mlp_bias = moe_layer.token_dispatcher.token_unpermutation(expert_output, mlp_bias, save_tensors)
+
+
+ if isinstance(share_experts_output, tuple):
+ share_experts_output, rs_share_experts_output, rs_shared_experts_handle = share_experts_output
+ else:
+ rs_share_experts_output = share_experts_output
+ rs_shared_experts_handle = None
+
+ expert_output.untyped_storage().resize_(0)
+ else:
+
+ if isinstance(share_experts_output, tuple):
+ share_experts_output, rs_share_experts_output, rs_shared_experts_handle = share_experts_output
+ else:
+ rs_share_experts_output = share_experts_output
+ rs_shared_experts_handle = None
+
+ (expert_output, mlp_bias), *_ = forward_func(moe_layer.experts, (dispatched_input, tokens_per_expert, ctx))
+ save_tensors.append(expert_output)
+
+ output, mlp_bias = moe_layer.token_dispatcher.token_unpermutation(expert_output, mlp_bias, save_tensors)
+
+ if group_limited_greedy:
+ save_tensors.append(moe_layer.router.l_aux)
+ moe_layer.router.l_aux = moe_layer.router.l_aux.detach()
+ moe_layer.router.l_aux.requires_grad = True
+ save_tensors.append(moe_layer.router.l_aux)
+ with torch.enable_grad():
+ save_to_aux_losses_tracker(
+ "load_balancing_loss",
+ moe_layer.router.l_aux,
+ moe_layer.layer_number,
+ moe_layer.config.num_layers,
+ )
+ save_to_aux_losses_tracker(
+ "load_balancing_expert_level_loss",
+ moe_layer.router.l_expert_aux / args.moe_aux_loss_coeff,
+ moe_layer.layer_number,
+ moe_layer.config.num_layers,
+ )
+ if hasattr(moe_layer.router, 'l_device_aux'):
+ save_to_aux_losses_tracker(
+ "load_balancing_device_level_loss",
+ moe_layer.router.l_device_aux / args.moe_device_level_aux_loss_coeff,
+ moe_layer.layer_number,
+ moe_layer.config.num_layers,
+ )
+ if hasattr(moe_layer.router, 'l_comm_aux'):
+ save_to_aux_losses_tracker(
+ "load_balancing_comm_level_loss",
+ moe_layer.router.l_comm_aux / args.moe_comm_aux_loss_coeff,
+ moe_layer.layer_number,
+ moe_layer.config.num_layers,
+ )
+ output = MoEAuxLossAutoScaler.apply(output, moe_layer.router.l_aux)
+ else:
+ save_tensors.append(None)
+ save_tensors.append(None)
+
+ save_tensors.append(hidden_states)
+
+ if moe_zero_memory == "level1" and not ctx.is_only_recompute_activation:
+ ctx.tokens_per_expert = tokens_per_expert
+
+ ctx.output_splits = moe_layer.token_dispatcher.output_splits
+ ctx.input_splits = moe_layer.token_dispatcher.input_splits
+ ctx.router_topk = moe_layer.token_dispatcher.router_topk
+ ctx.input_splits_tp_ep = getattr(moe_layer.token_dispatcher, 'input_splits_tp_ep', None)
+ if n_shared_experts:
+ if rs_shared_experts_handle is not None:
+ rs_shared_experts_handle.wait()
+ output_sum = output + rs_share_experts_output
+ output.untyped_storage().resize_(0)
+ share_experts_output.untyped_storage().resize_(0)
+ else:
+ output_sum = output.detach()
+
+ save_tensors.append(share_experts_output)
+ if hasattr(moe_layer.token_dispatcher, 'global_input_tokens_local_experts_indices'):
+ save_tensors.append(moe_layer.token_dispatcher.global_input_tokens_local_experts_indices)
+ else:
+ save_tensors.append(None)
+ ctx.save_for_backward(*save_tensors)
+ return output_sum, mlp_bias
+
+ @staticmethod
+ def backward(ctx, *args):
+ global_args = get_args()
+
+ output_splits = ctx.output_splits
+ input_splits = ctx.input_splits
+ router_topk = ctx.router_topk
+ n_shared_experts = ctx.n_shared_experts
+ moe_zero_memory = ctx.moe_zero_memory
+ moe_experts_pipeline_degree = ctx.moe_experts_pipeline_degree
+ moe_tp_extend_ep = global_args.moe_tp_extend_ep
+ moe_hierarchical_alltoallv = global_args.moe_hierarchical_alltoallv
+ shared_expert_gate = ctx.shared_expert_gate
+ input_splits_tp_ep = ctx.input_splits_tp_ep
+
+ (route_graph, detach_scores,
+ indices, indices_ep,
+ hidden_states_ep, scores_ep,
+ permute1_graph,
+ permute2_input_detach, permute2_graph,
+ experts_graph,
+ unpermute1_input_detach, unpermute1_graph,
+ unpermute2_input_detach, unpermute2_graph, l_aux_graph, l_aux_detach,
+ detach_input, share_experts_graph,
+ global_input_tokens_local_experts_indices,
+ ) = ctx.saved_tensors
+ if moe_hierarchical_alltoallv:
+ set_gemm_backward_need_tensors(
+ ((hidden_states_ep, indices_ep, scores_ep, router_topk, global_input_tokens_local_experts_indices),
+ permute2_input_detach, permute2_graph,
+ output_splits, input_splits, input_splits_tp_ep))
+ elif moe_experts_pipeline_degree:
+ input_list = ctx.input_list
+ else:
+ set_gemm_backward_need_tensors(
+ ((detach_input, indices, scores_ep, router_topk, global_input_tokens_local_experts_indices),
+ permute2_input_detach, permute2_graph,
+ output_splits, input_splits, input_splits_tp_ep))
+
+ if n_shared_experts:
+ if get_tensor_model_parallel_world_size() > 1 and not shared_expert_gate:
+ _, backward_ag_shared, backward_ag_shared_handle = async_all_gather(
+ args[0], get_tensor_model_parallel_group()
+ )
+ else:
+ backward_ag_shared = args[0]
+ backward_ag_shared_handle = None
+
+ if moe_hierarchical_alltoallv:
+ ep_group = parallel_state.get_expert_model_parallel_group()
+ unpermute2_graph_backward_input = args[0].view(-1, args[0].shape[-1])
+ _, unpermute2_graph_backward_input, output_backward_handle = \
+ async_all_gather(unpermute2_graph_backward_input, group=ep_group)
+ if moe_zero_memory == "level0":
+ def alltoall_token_permutation1(hidden_states, indices, router_topk):
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ permutated_local_input_tokens, _, _ = permute_with_ep(
+ hidden_states, indices, probs=scores_ep, topk=router_topk, gb_inputs_splits=input_splits_tp_ep
+ )
+ return permutated_local_input_tokens
+
+ permutated_local_input_tokens = alltoall_token_permutation1(hidden_states_ep, indices_ep, router_topk)
+ set_permute_with_ep_local_input_tokens(permutated_local_input_tokens)
+
+ if moe_zero_memory == "level1" and not ctx.is_only_recompute_activation:
+ with torch.no_grad():
+ if get_tensor_model_parallel_world_size() > 1 and n_shared_experts:
+ _, shared_experts_input, shared_experts_allgather_handle = async_all_gather(
+ detach_input, get_tensor_model_parallel_group(), is_use_get_global_memory_buffer=True
+ )
+ AG_SHARED_EXPERTS_INPUTS.append((shared_experts_input, shared_experts_allgather_handle))
+
+ # Recompute token rearrange in permutation1
+ if moe_hierarchical_alltoallv:
+ permutated_local_input_tokens, _, _ = permute_with_ep(
+ hidden_states_ep.view(-1, hidden_states_ep.shape[-1]), indices_ep, probs=scores_ep, topk=ctx.router_topk,
+ gb_inputs_splits=ctx.input_splits_tp_ep
+ )
+ else:
+ permutated_local_input_tokens, _ = permute(
+ detach_input.view(-1, detach_input.shape[-1]), indices
+ )
+
+ # Recompute expert parallel AlltoAll communication
+ ep_group = parallel_state.get_expert_model_parallel_group()
+ if moe_tp_extend_ep:
+ ep_group = parallel_state.get_tensor_and_expert_parallel_group()
+ if moe_hierarchical_alltoallv:
+ tp_group = parallel_state.get_tensor_model_parallel_group()
+ _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
+ permutated_local_input_tokens,
+ ctx.output_splits,
+ ctx.input_splits,
+ tp_group,
+ )
+ else:
+ _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
+ permutated_local_input_tokens,
+ ctx.output_splits,
+ ctx.input_splits,
+ ep_group,
+ )
+ if moe_hierarchical_alltoallv:
+ output_backward_handle.wait()
+ unpermute2_graph.backward(unpermute2_graph_backward_input)
+ else:
+ unpermute2_graph.backward(args[0])
+ unpermute2_graph = None
+ if moe_zero_memory == "level1" and not ctx.is_only_recompute_activation:
+ if n_shared_experts:
+ with torch.no_grad():
+ # Recompute mm1 and act of shared experts
+ shared_fc1_out, bias_parallel = ctx.shared_experts.linear_fc1(detach_input)
+ shared_act_out = ctx.shared_experts.activation_function(shared_fc1_out, bias_parallel)
+ shared_act_out_size = shared_act_out.untyped_storage().size()
+ ctx.shared_act_out.untyped_storage().resize_(shared_act_out_size)
+ ctx.shared_act_out.untyped_storage().copy_(shared_act_out.untyped_storage())
+ shared_act_out.untyped_storage().resize_(0)
+ shared_fc1_out_size = shared_fc1_out.untyped_storage().size()
+ ctx.shared_fc1_out.untyped_storage().resize_(shared_fc1_out_size)
+ ctx.shared_fc1_out.untyped_storage().copy_(shared_fc1_out.untyped_storage())
+ shared_fc1_out.untyped_storage().resize_(0)
+ if backward_ag_shared_handle is not None:
+ backward_ag_shared_handle.wait()
+ share_experts_graph.backward(backward_ag_shared)
+ share_experts_graph = None
+ if backward_ag_shared_handle is not None:
+ backward_ag_shared.untyped_storage().resize_(0)
+ ctx.shared_act_out.untyped_storage().resize_(0)
+ ctx.shared_fc1_out.untyped_storage().resize_(0)
+
+ permute1_ep_all_to_all_handle.wait()
+ permutated_local_input_tokens.untyped_storage().resize_(0)
+
+ ep_group = parallel_state.get_expert_model_parallel_group()
+ if moe_tp_extend_ep:
+ ep_group = parallel_state.get_tensor_and_expert_parallel_group()
+ if moe_hierarchical_alltoallv:
+ tp_group = parallel_state.get_tensor_model_parallel_group()
+ _, unpermute1_backward_input, handle = async_all_to_all(
+ unpermute2_input_detach.grad,
+ output_splits,
+ input_splits,
+ tp_group,
+ )
+ else:
+ _, unpermute1_backward_input, handle = async_all_to_all(
+ unpermute2_input_detach.grad,
+ output_splits,
+ input_splits,
+ ep_group,
+ )
+
+ if moe_zero_memory == "level1" and not ctx.is_only_recompute_activation:
+ with torch.no_grad():
+ if ctx.num_local_experts > 1:
+ # Recompute permutation2
+ global_input_tokens, _ = permute(
+ global_input_tokens, global_input_tokens_local_experts_indices
+ )
+ if not moe_tp_extend_ep and get_tensor_model_parallel_world_size() > 1 and ctx.moe_grouped_gemm:
+ global_input_tokens = tensor_parallel.all_gather_last_dim_from_tensor_parallel_region(
+ global_input_tokens
+ )
+ # Recompute mm1 and act
+ input_, mm1_out, act_out = ctx.recompute_tensors
+ ctx.recompute_tensors = None
+ if global_input_tokens.nelement() != 0:
+ group_list = torch.cumsum(ctx.tokens_per_expert, dim=0)
+ w1 = ctx.weight1.view(ctx.num_local_experts, ctx.hidden_size, -1)
+ mm1_out_ = gmm_op(global_input_tokens, w1, [], group_list, 0)[0]
+ group_list.untyped_storage().resize_(0)
+ else:
+ w1 = ctx.weight1.view(ctx.hidden_size, -1)
+ mm1_out_ = torch.matmul(global_input_tokens, w1)
+
+ act_out_ = ctx.activation_func(mm1_out_)
+ act_out_size = act_out_.untyped_storage().size()
+ act_out.untyped_storage().resize_(act_out_size)
+ act_out.untyped_storage().copy_(act_out_.untyped_storage())
+ act_out = None
+ act_out_.untyped_storage().resize_(0)
+ mm1_out_size = mm1_out_.untyped_storage().size()
+ mm1_out.untyped_storage().resize_(mm1_out_size)
+ mm1_out.untyped_storage().copy_(mm1_out_.untyped_storage())
+ mm1_out = None
+ mm1_out_.untyped_storage().resize_(0)
+ input_size = global_input_tokens.untyped_storage().size()
+ input_.untyped_storage().resize_(input_size)
+ input_.untyped_storage().copy_(global_input_tokens.untyped_storage())
+ input_ = None
+ global_input_tokens.untyped_storage().resize_(0)
+ ctx.activation_func = None
+ ctx.hidden_size = None
+ ctx.num_local_experts = None
+ ctx.weight1 = None
+ ctx.moe_grouped_gemm = None
+ ctx.num_local_experts = None
+ ctx.input_splits = None
+ ctx.output_splits = None
+ if moe_hierarchical_alltoallv:
+ ctx.input_splits_tp_ep = None
+ elif share_experts_graph is not None:
+ if backward_ag_shared_handle is not None:
+ backward_ag_shared_handle.wait()
+ share_experts_graph.backward(backward_ag_shared)
+ share_experts_graph = None
+ if backward_ag_shared_handle is not None:
+ backward_ag_shared.untyped_storage().resize_(0)
+ if handle is not None:
+ handle.wait()
+ unpermute2_input_detach.grad.untyped_storage().resize_(0)
+
+ backward_func(unpermute1_graph, unpermute1_backward_input)
+
+ unpermute1_backward_input.untyped_storage().resize_(0)
+ if moe_hierarchical_alltoallv:
+ set_all2all_experts_output((permute1_graph, scores_ep, hidden_states_ep))
+ backward_func(experts_graph, unpermute1_input_detach.grad)
+ unpermute1_input_detach.grad.untyped_storage().resize_(0)
+ permute2_input_detach.grad.untyped_storage().resize_(0)
+ detach_scores_grad, detach_input_grad, detach_input_handle = get_all2all_experts_output()
+ elif moe_experts_pipeline_degree:
+ expert_grad_output = moe_experts_pipeline_backward_func(ctx, ctx.input_list)
+ for input_tensor in input_list:
+ input_tensor.untyped_storage().resize_(0)
+ permute2_graph.backward(expert_grad_output)
+ backward_func(permute1_graph, permute2_input_detach.grad)
+ permute2_input_detach.grad.untyped_storage().resize_(0)
+ else:
+ backward_func(experts_graph, unpermute1_input_detach.grad)
+ unpermute1_input_detach.grad.untyped_storage().resize_(0)
+ permute1_backward_input, bw_permute1_ep_all2all_handle = get_all2all_experts_output()
+ bw_permute1_ep_all2all_handle.wait()
+ permute2_input_detach.grad.untyped_storage().resize_(0)
+ backward_func(permute1_graph, permute1_backward_input)
+ permute1_backward_input.untyped_storage().resize_(0)
+ if l_aux_graph is not None:
+ l_aux_graph.backward(l_aux_detach.grad, retain_graph=True)
+ if moe_zero_memory != "disable":
+ if ctx.router_topk > 1:
+ stream, matmul_output_grad, unpermuted_tokens = get_prob_backward_need_tensors()
+ torch.npu.current_stream().wait_stream(stream)
+ probs_grad = (matmul_output_grad * unpermuted_tokens).sum(-1).squeeze(-1)
+ if moe_hierarchical_alltoallv:
+ ep_group = parallel_state.get_expert_model_parallel_group()
+ _, probs_grad, probs_grad_handle = async_reduce_scatter(probs_grad, group=ep_group)
+ probs_grad_handle.wait()
+ route_graph.backward(probs_grad)
+ detach_input_handle.wait()
+ else:
+ route_graph.backward(probs_grad)
+ ctx.router_topk = None
+ else:
+ if moe_hierarchical_alltoallv:
+ route_graph.backward(detach_scores_grad)
+ detach_input_handle.wait()
+ else:
+ route_graph.backward(detach_scores.grad)
+ route_graph = None
+ if moe_hierarchical_alltoallv:
+ grad_output = detach_input.grad + detach_input_grad
+ else:
+ grad_output = detach_input.grad
+ return grad_output, None
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/moe_layer_overlap_allgather.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/moe_layer_overlap_allgather.py
new file mode 100644
index 0000000000000000000000000000000000000000..5317b0694842634746ab70e35add530310d73cc8
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/moe_layer_overlap_allgather.py
@@ -0,0 +1,229 @@
+import acl
+import torch
+import torch_npu
+
+from megatron.core.parallel_state import (get_expert_model_parallel_group, get_tensor_and_expert_parallel_group,
+ get_tensor_model_parallel_group, get_tensor_model_parallel_world_size)
+from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.training import get_args
+from mindspeed.core.transformer.moe.token_dispatcher import cann_version_check
+from mindspeed.core.transformer.moe.moe_utils import AG_SHARED_EXPERTS_INPUTS
+from mindspeed.core.transformer.moe.comm_utils import async_all_gather, async_reduce_scatter
+from mindspeed.core.transformer.moe.moe_utils import (forward_func, backward_func, set_gemm_backward_need_tensors,
+ get_rs_global_hidden_states_grad_with_handle)
+
+
+class MoELayerOverlapAllGather(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, hidden_states, moe_layer: MoELayer):
+ args = get_args()
+ save_tensors = []
+ ctx.input_shape = hidden_states.shape
+ moe_layer.experts.layer_number = moe_layer.layer_number
+ # input detach graph, leaf node
+ hidden_states = hidden_states.detach()
+ hidden_states.requires_grad = True
+
+ # 共享专家 tp group allgather hidden_states
+ # 默认 tp 开启 sp
+ if args.n_shared_experts and get_tensor_model_parallel_world_size() > 1:
+ hidden_states, shared_experts_input, shared_experts_allgather_handle = async_all_gather(
+ hidden_states, get_tensor_model_parallel_group(), is_use_get_global_memory_buffer=True
+ )
+ AG_SHARED_EXPERTS_INPUTS.append(shared_experts_input)
+ else:
+ shared_experts_input = hidden_states
+ shared_experts_allgather_handle = None
+
+ # router
+ (scores, indices), _ = forward_func(moe_layer.router, hidden_states)
+
+ # after router, do 2 allgather
+ global_indices_tuple = None
+ global_probs_tuple = None
+ if moe_layer.config.sequence_parallel or (moe_layer.config.expert_model_parallel_size > 1):
+ if isinstance(indices, tuple):
+ global_indices, gi_handle = indices
+ else:
+ _, global_indices, gi_handle = async_all_gather(indices, get_tensor_and_expert_parallel_group())
+
+ global_indices_tuple = (global_indices, gi_handle)
+
+ _, global_probs, gp_handle = async_all_gather(
+ scores, get_tensor_and_expert_parallel_group()
+ )
+
+ global_probs_tuple = (global_probs, gp_handle)
+
+ # 专家 ep group allgather hidden_states
+ global_hidden_states_tuple = None
+ if moe_layer.config.sequence_parallel or moe_layer.config.expert_model_parallel_size > 1:
+ if '910B' in acl.get_soc_name():
+ _, global_hidden_states, ghs_handle = async_all_gather(
+ hidden_states,
+ get_tensor_and_expert_parallel_group()
+ )
+ else:
+ _, global_hidden_states, ghs_handle = async_all_gather(
+ shared_experts_input,
+ get_expert_model_parallel_group()
+ if shared_experts_allgather_handle
+ else get_tensor_and_expert_parallel_group(),
+ shared_experts_allgather_handle
+ )
+ global_hidden_states = global_hidden_states.view(-1, global_hidden_states.shape[-1])
+ global_hidden_states_tuple = (global_hidden_states, ghs_handle)
+
+ # shared experts
+ shared_experts_rs_handle = None
+ share_experts_output = None
+ rs_share_experts_output = None
+ share_experts_bias = None
+ if args.n_shared_experts:
+ if shared_experts_allgather_handle is not None:
+ shared_experts_allgather_handle.wait()
+ (share_experts_output, share_experts_bias), _ = forward_func(
+ moe_layer.shared_experts, hidden_states
+ )
+
+ if get_tensor_model_parallel_world_size() > 1:
+ # reduce scatter
+ _, rs_share_experts_output, shared_experts_rs_handle = async_reduce_scatter(
+ share_experts_output, get_tensor_model_parallel_group()
+ )
+ else:
+ rs_share_experts_output = share_experts_output
+ shared_experts_rs_handle = None
+
+ token_permutation_input = (
+ global_indices_tuple,
+ global_probs_tuple,
+ global_hidden_states_tuple
+ )
+
+ # dispatch input
+ save_tensors.append(scores)
+
+ moe_layer.token_dispatcher.hidden_shape = hidden_states.shape
+ (dispatched_input, tokens_per_expert, global_local_map, indices), *token_permutation_input = forward_func(
+ moe_layer.token_dispatcher.token_permutation, token_permutation_input
+ )
+
+ save_tensors.append(global_local_map)
+ save_tensors.append(indices)
+
+ # token_permutation_input : (global_indices, handle), (global_probs, handle), (global_hidden_states, handle)
+ global_probs_detach, global_hidden_states_detach = token_permutation_input[1][0], token_permutation_input[2][0]
+
+ global_hidden_states_detach.untyped_storage().resize_(0)
+ if cann_version_check:
+ global_probs_detach.untyped_storage().resize_(0)
+ save_tensors.append(global_probs_detach)
+ save_tensors.append(global_hidden_states_detach)
+
+ expert_input = (dispatched_input, tokens_per_expert)
+
+ def func(dispatched_input, tokens_per_expert):
+ expert_output, mlp_bias = moe_layer.experts(dispatched_input, tokens_per_expert)
+ output, mlp_bias = moe_layer.token_dispatcher.token_unpermutation(
+ expert_output, mlp_bias
+ )
+ return output, mlp_bias
+
+ (output, mlp_bias), *_ = forward_func(func, expert_input)
+
+ save_tensors.append(dispatched_input)
+
+ _, output_rs, token_unpermutation_rs_handle = async_reduce_scatter(
+ output, get_tensor_and_expert_parallel_group()
+ )
+
+ ctx.token_unpermutation_output_shape = output.shape
+
+ token_unpermutation_rs_handle.wait()
+ output.untyped_storage().resize_(0)
+ output_rs = output_rs.view(moe_layer.token_dispatcher.hidden_shape)
+
+ save_tensors.append(hidden_states)
+ save_tensors.append(output)
+ save_tensors.append(share_experts_output)
+ ctx.save_for_backward(*save_tensors)
+
+ if args.n_shared_experts:
+ if shared_experts_rs_handle is not None:
+ shared_experts_rs_handle.wait()
+
+ output_rs = output_rs + rs_share_experts_output
+ if moe_layer.token_dispatcher.add_bias:
+ mlp_bias = mlp_bias + share_experts_bias
+ share_experts_output.untyped_storage().resize_(0)
+ return output_rs, mlp_bias
+
+ return output_rs.detach(), mlp_bias
+
+ @staticmethod
+ def backward(ctx, *args):
+ (scores, global_local_map, indices,
+ global_probs_detach, global_hidden_states_detach, dispatched_input,
+ input_, output, share_experts_graph) = ctx.saved_tensors
+
+ token_unpermutation_output_shape = ctx.token_unpermutation_output_shape
+ # tp group ag grad_out
+ if share_experts_graph is not None and get_tensor_model_parallel_world_size() > 1:
+ _, ag_share_experts_grad_input, ag_share_experts_handle = async_all_gather(
+ args[0], get_tensor_model_parallel_group()
+ )
+ else:
+ ag_share_experts_grad_input = args[0]
+ ag_share_experts_handle = None
+
+ if '910B' not in acl.get_soc_name() and share_experts_graph:
+ _, ag_experts_grad_input, ag_experts_handle = async_all_gather(
+ ag_share_experts_grad_input,
+ get_expert_model_parallel_group(),
+ ag_share_experts_handle
+ )
+ else:
+ _, ag_experts_grad_input, ag_experts_handle = async_all_gather(
+ args[0],
+ get_tensor_and_expert_parallel_group(),
+ )
+
+ args = None
+ if ag_share_experts_handle is not None:
+ ag_share_experts_handle.wait()
+
+ if share_experts_graph is not None:
+ # 反向 —— 共享专家
+ share_experts_graph.backward(ag_share_experts_grad_input)
+ if '910B' in acl.get_soc_name() or share_experts_graph is None:
+ from mindspeed.core.transformer.moe.moe_utils import set_ag_tp_hidden_status
+ set_ag_tp_hidden_status(input_)
+
+ ag_experts_handle.wait()
+ ag_share_experts_grad_input = None
+ ag_experts_grad_input = ag_experts_grad_input.view(token_unpermutation_output_shape)
+
+ # token 重排反向 function set
+ set_gemm_backward_need_tensors((dispatched_input, global_hidden_states_detach, indices, global_local_map))
+
+ # 反向 —— token 反重排 expert
+ output.backward(ag_experts_grad_input)
+
+ global_probs_grad = global_probs_detach.grad
+
+ _, rs_global_probs_grad, rs_global_probs_grad_handle = async_reduce_scatter(
+ global_probs_grad, get_tensor_and_expert_parallel_group()
+ )
+ rs_global_probs_grad_handle.wait()
+ global_probs_grad.untyped_storage().resize_(0)
+
+ # 反向 —— router
+ backward_func(scores, rs_global_probs_grad)
+
+ rs_global_hidden_states_grad, rs_handle = get_rs_global_hidden_states_grad_with_handle()
+ rs_handle.wait()
+ rs_global_hidden_states_grad = rs_global_hidden_states_grad.view(ctx.input_shape)
+ # expert grad + shared expert grad
+ rs_global_hidden_states_grad += input_.grad
+ return rs_global_hidden_states_grad, None
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/moe_utils.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/moe_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..461b8739cb93f6a013dd0558162e10013f67d8c7
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/moe_utils.py
@@ -0,0 +1,376 @@
+# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+import torch_npu
+from megatron.core.transformer.moe.moe_utils import permute_with_padded_tokens, unpermute_with_padded_tokens
+from megatron.training import get_args
+from megatron.core import mpu, parallel_state
+from megatron.core.transformer.moe.moe_utils import (reduce_aux_losses_tracker_across_ranks,
+ clear_aux_losses_tracker)
+
+
+AG_TP_HIDDEN_STATUS = None
+AG_SHARED_EXPERTS_INPUTS = []
+GEMM_BACKWARD_NEED_TENSORS = None
+RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE = None
+SWAP_STREAM = None
+SWAP_STREAM2 = None
+SWAP_TENSOR = None
+MATMUL_OUTPUT_GRAD = None
+UNPERMUTED_TOKENS = None
+PERMUTE_WITH_EP_LOCAL_INPUT_TOKENS = None
+
+
+def get_swap_stream():
+ global SWAP_STREAM2
+ if SWAP_STREAM2 is None:
+ _ = torch_npu.npu.Stream(device=torch.npu.current_device())
+ SWAP_STREAM2 = torch_npu.npu.Stream(device=torch.npu.current_device())
+ stream = SWAP_STREAM2
+ return stream
+
+
+def set_swap_status(tensor):
+ global SWAP_TENSOR
+ SWAP_TENSOR = tensor
+
+
+def get_swap_status():
+ global SWAP_STREAM
+ if SWAP_STREAM is None:
+ SWAP_STREAM = torch_npu.npu.Stream(device=torch.npu.current_device())
+ global SWAP_TENSOR
+ stream = SWAP_STREAM
+ tensor = SWAP_TENSOR
+ SWAP_TENSOR = None
+ return stream, tensor
+
+
+def set_prob_backward_need_tensors(matmul_output_grad, unpermuted_tokens):
+ global MATMUL_OUTPUT_GRAD
+ MATMUL_OUTPUT_GRAD = matmul_output_grad
+ global UNPERMUTED_TOKENS
+ UNPERMUTED_TOKENS = unpermuted_tokens
+
+
+def get_prob_backward_need_tensors():
+ global SWAP_STREAM2
+ if SWAP_STREAM2 is None:
+ _ = torch_npu.npu.Stream(device=torch.npu.current_device())
+ SWAP_STREAM2 = torch_npu.npu.Stream(device=torch.npu.current_device())
+ global MATMUL_OUTPUT_GRAD
+ global UNPERMUTED_TOKENS
+ stream = SWAP_STREAM2
+ matmul_output_grad = MATMUL_OUTPUT_GRAD
+ unpermuted_tokens = UNPERMUTED_TOKENS
+ MATMUL_OUTPUT_GRAD = None
+ UNPERMUTED_TOKENS = None
+ return stream, matmul_output_grad, unpermuted_tokens
+
+
+def set_ag_tp_hidden_status(_inputs):
+ global AG_TP_HIDDEN_STATUS
+ AG_TP_HIDDEN_STATUS = _inputs
+
+
+def get_ag_tp_hidden_status():
+ global AG_TP_HIDDEN_STATUS
+ result = AG_TP_HIDDEN_STATUS
+ AG_TP_HIDDEN_STATUS = None
+ return result
+
+
+def set_gemm_backward_need_tensors(_inputs):
+ global GEMM_BACKWARD_NEED_TENSORS
+ GEMM_BACKWARD_NEED_TENSORS = _inputs
+
+
+def get_gemm_backward_need_tensors():
+ global GEMM_BACKWARD_NEED_TENSORS
+ result = GEMM_BACKWARD_NEED_TENSORS
+ GEMM_BACKWARD_NEED_TENSORS = None
+ return result
+
+
+def set_permute_with_ep_local_input_tokens(_inputs):
+ global PERMUTE_WITH_EP_LOCAL_INPUT_TOKENS
+ PERMUTE_WITH_EP_LOCAL_INPUT_TOKENS = _inputs
+
+
+def get_permute_with_ep_local_input_tokens():
+ global PERMUTE_WITH_EP_LOCAL_INPUT_TOKENS
+ result = PERMUTE_WITH_EP_LOCAL_INPUT_TOKENS
+ PERMUTE_WITH_EP_LOCAL_INPUT_TOKENS = None
+ return result
+
+
+def set_rs_global_hidden_states_grad_with_handle(_inputs):
+ global RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE
+ RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE = _inputs
+
+
+def get_rs_global_hidden_states_grad_with_handle():
+ global RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE
+ result = RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE
+ RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE = None
+ return result
+
+
+ALL2ALL_EXPERTS_OUTPUT = None
+
+
+def set_all2all_experts_output(_input):
+ global ALL2ALL_EXPERTS_OUTPUT
+ ALL2ALL_EXPERTS_OUTPUT = _input
+
+
+def get_all2all_experts_output():
+ global ALL2ALL_EXPERTS_OUTPUT
+ result = ALL2ALL_EXPERTS_OUTPUT
+ ALL2ALL_EXPERTS_OUTPUT = None
+ return result
+
+
+def only_recompute_activation(layer_number):
+ args = get_args()
+ vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
+ vpp_size = args.virtual_pipeline_model_parallel_size
+ pp_size = args.transformer_pipeline_model_parallel_size
+
+ if vpp_size is not None:
+ layer_per_chunk = args.num_layers_per_virtual_pipeline_stage
+ elif pp_size is not None:
+ layer_per_chunk = args.num_layers // pp_size
+ else:
+ layer_per_chunk = args.num_layers
+
+ if vpp_rank is None:
+ vpp_rank = 0
+ if vpp_size is None:
+ vpp_size = 1
+ recompute_priority = ((layer_number - 1) % layer_per_chunk) * vpp_size + vpp_rank
+ moe_zero_memory_num_layers = args.moe_zero_memory_num_layers
+
+ if moe_zero_memory_num_layers:
+ if recompute_priority < moe_zero_memory_num_layers:
+ return False
+ else:
+ return True
+ else:
+ return False
+
+
+def forward_func(func, inputs):
+ def detach_tensor(input_):
+ if input_.requires_grad and input_.grad_fn is None:
+ return input_
+ else:
+ new_input = input_.detach()
+ new_input.requires_grad = True
+ return new_input
+
+ detach_inputs = []
+ if isinstance(inputs, tuple):
+ for input_ in inputs:
+ if isinstance(input_, tuple):
+ detach_input = []
+ for i in input_:
+ if isinstance(i, torch.Tensor) and torch.is_floating_point(i):
+ detach_input.append(detach_tensor(i))
+ else:
+ detach_input.append(i)
+ detach_inputs.append(tuple(detach_input))
+ else:
+ if isinstance(input_, torch.Tensor) and torch.is_floating_point(input_):
+ detach_input = detach_tensor(input_)
+ else:
+ detach_input = input_
+ detach_inputs.append(detach_input)
+ elif isinstance(inputs, torch.Tensor):
+ detach_inputs.append(detach_tensor(inputs))
+
+ with torch.enable_grad():
+ output = func(*detach_inputs)
+
+ return output, *detach_inputs
+
+
+def backward_func(func_tensor, gradinputs):
+ if gradinputs is None or func_tensor.grad_fn is None:
+ return
+ if isinstance(gradinputs, torch.Tensor):
+ func_tensor.backward(gradinputs)
+ elif isinstance(gradinputs, tuple):
+ func_tensor.backward(*gradinputs)
+
+
+def permute(tokens, indices, num_out_tokens: int = None, padded_mode: bool = False):
+ if padded_mode:
+ return permute_with_padded_tokens(tokens, indices)
+
+ if indices.dim() == 1:
+ topk = 1
+ else:
+ topk = indices.size(1)
+ flatten_indices = indices.view(-1)
+ # previous use argsort, argsort int64 will be run on host cpu
+ sorted_indices = torch.sort(flatten_indices.float(), stable=True)[1]
+ if num_out_tokens is not None:
+ sorted_indices = sorted_indices[:num_out_tokens]
+ permuted_tokens = tokens.index_select(0, sorted_indices // topk)
+ return permuted_tokens, sorted_indices
+
+
+def permute_with_ep(tokens: torch.Tensor,
+ indices: torch.Tensor,
+ probs: torch.Tensor,
+ topk: int = 1,
+ gb_inputs_splits=None):
+ if topk > 1:
+ if indices.size(1) != topk:
+ raise RuntimeError("indices.size(1) should be equal to topk")
+ flatten_indices = indices.view(-1)
+ sorted_indices = torch.sort(flatten_indices.float(), stable=True)[1]
+ ep_rank = mpu.get_expert_model_parallel_rank()
+ import numpy as np
+ gb_inputs_splits_sum = np.cumsum(gb_inputs_splits)
+ start = 0
+ if ep_rank > 0:
+ start = gb_inputs_splits_sum[ep_rank - 1]
+ end = gb_inputs_splits_sum[ep_rank]
+ result_indices = sorted_indices[start : end]
+ permuted_tokens = tokens.index_select(0, result_indices // topk)
+ flatten_probs = probs.view(-1)
+ permuted_probs = flatten_probs.index_select(0, result_indices)
+ return permuted_tokens, permuted_probs, result_indices
+
+
+def unpermute_with_ep(
+ unpermute_with_ep_input_tensors_list,
+ probs: torch.Tensor = None,
+ padded_mode: bool = False,
+ restore_shape: torch.Size = None,
+ topk: int = 1,
+):
+ permuted_tokens, sorted_indices, permuted_probs = unpermute_with_ep_input_tensors_list
+ if padded_mode:
+ return unpermute_with_padded_tokens(
+ permuted_tokens, sorted_indices, probs, restore_shape=restore_shape
+ )
+
+ assert sorted_indices.numel() == permuted_tokens.size(0)
+ if permuted_probs is not None:
+ permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)
+ unpermuted_tokens = torch.zeros(restore_shape[0], permuted_tokens.size(-1),
+ dtype=permuted_tokens.dtype, device=permuted_tokens.device)
+ sorted_indices = sorted_indices // topk
+ unpermuted_tokens = unpermuted_tokens.scatter_add_(0,
+ sorted_indices.unsqueeze(1).expand(-1, permuted_tokens.shape[1]),
+ permuted_tokens)
+ return unpermuted_tokens
+
+
+def unpermute(
+ permuted_tokens: torch.Tensor,
+ sorted_indices: torch.Tensor,
+ probs: torch.Tensor = None,
+ padded_mode: bool = False,
+ restore_shape: torch.Size = None,
+):
+ if padded_mode:
+ return unpermute_with_padded_tokens(
+ permuted_tokens, sorted_indices, probs, restore_shape=restore_shape
+ )
+
+ assert sorted_indices.numel() == permuted_tokens.size(0)
+ if probs is not None:
+ # Unpermute and merge the tokens with their probabilities
+ num_unpermuted_tokens = probs.numel()
+ topk = probs.size(1)
+ else:
+ # Unpermute the tokens without merge
+ num_unpermuted_tokens = permuted_tokens.size(0)
+ topk = 1
+
+ unpermuted_tokens = torch.zeros(
+ [num_unpermuted_tokens, permuted_tokens.shape[-1]],
+ dtype=permuted_tokens.dtype,
+ device=permuted_tokens.device,
+ )
+ unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)
+ unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
+ if probs is not None:
+ unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
+ unpermuted_tokens = unpermuted_tokens.sum(dim=1)
+
+ return unpermuted_tokens
+
+
+def get_mean(tensor):
+ """
+ Calculate the mean of a tensor, excluding specified 'noop_layers'.
+
+ Parameters:
+ tensor (torch.Tensor): A one-dimensional tensor.
+
+ Returns:
+ float: The mean of the tensor, excluding the 'noop_layers' if specified.
+
+ Notes:
+ - If `args.noop_layers` is a set and is not empty, the mean is calculated by excluding these layers.
+ - If `args.noop_layers` is empty or None, the mean is calculated directly from the tensor.
+ - `args.num_layers` represents the total number of layers, used to adjust the mean calculation when
+ excluding 'noop_layers'.
+ """
+ args = get_args()
+ if hasattr(args, 'noop_layers') and isinstance(args.noop_layers, set) and len(args.noop_layers) > 0:
+ return tensor.sum() / (args.num_layers - len(args.noop_layers))
+ return tensor.mean()
+
+
+def track_moe_metrics(
+ loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None, per_layer_logging=False
+):
+ # Aux loss logging
+
+ reduce_aux_losses_tracker_across_ranks()
+ tracker = parallel_state.get_moe_layer_wise_logging_tracker()
+ if writer is not None:
+ aux_losses = {k: v['values'].float() * loss_scale for k, v in tracker.items()}
+ for name, loss_list in aux_losses.items():
+ # adaptation for
+ loss_list_mean = get_mean(loss_list)
+ if total_loss_dict is not None:
+ if name not in total_loss_dict:
+ # adaptation for loss_list.mean()
+ total_loss_dict[name] = loss_list_mean
+ else:
+ # adaptation for loss_list.mean()
+ total_loss_dict[name] += loss_list_mean
+
+ # currently when using add_scalars,
+ # torch.utils.add_scalars makes each timer its own run, which
+ # polutes the runs list, so we just add each as a scalar
+ # adaptation for loss_list.mean()
+ writer.add_scalar(name, loss_list_mean, iteration)
+ if per_layer_logging:
+ for i, loss in enumerate(loss_list.tolist()):
+ writer.add_scalar(f"moe/{name}_layer_{i}", loss, iteration)
+
+ # W&B logging lacks support for logging multiple scalars simultaneously.
+ # As a workaround, we log each scalar individually first, then we can create
+ # a custom panel to manually group them to a single plot.
+ if wandb_writer:
+ # adaptation for loss_list.mean()
+ wandb_writer.log({f"{name}": loss_list_mean}, iteration)
+ if per_layer_logging:
+ wandb_writer.log(
+ {
+ f"moe/{name}_layer_{i}": loss
+ for i, loss in enumerate(loss_list.tolist())
+ },
+ iteration,
+ )
+
+ clear_aux_losses_tracker()
+
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/router.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..21e640133d0dd47defabd36123784bbeb4b5e714
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/router.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2022; NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+from megatron.training import get_args
+from megatron.core.parallel_state import get_tensor_and_expert_parallel_group
+from megatron.core.tensor_parallel.mappings import _reduce_scatter_along_first_dim_moe
+from megatron.core.transformer.moe.moe_utils import topk_softmax_with_capacity
+
+
+def _gather_along_first_dim_moe_async(input_, async_op):
+ """Gather tensors and concatenate along the first dimension."""
+ group = get_tensor_and_expert_parallel_group()
+ world_size = torch.distributed.get_world_size(group=group)
+ # Bypass the function if we are using only 1 GPU.
+ if world_size == 1:
+ return input_
+
+ dim_size = list(input_.size())
+ dim_size[0] = dim_size[0] * world_size
+
+ output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
+ handle = torch.distributed._all_gather_base(output, input_.contiguous(), group=group, async_op=async_op)
+
+ return output, handle
+
+
+class _GatherFromSequenceParallelRegionToMOEAsync(torch.autograd.Function):
+ @staticmethod
+ def symbolic(graph, input_):
+ return _gather_along_first_dim_moe_async(input_, async_op=True)
+
+ @staticmethod
+ def forward(ctx, input_):
+ return _gather_along_first_dim_moe_async(input_, async_op=True)
+
+ @staticmethod
+ def backward(ctx, grad_output, grad_handle):
+ return _reduce_scatter_along_first_dim_moe(grad_output)
+
+
+def gather_from_sequence_parallel_region_to_moe_async(input_):
+ return _GatherFromSequenceParallelRegionToMOEAsync.apply(input_)
+
+
+def aux_loss_load_balancing(self, logits: torch.Tensor):
+ probs, indices, tokens_per_expert = topk_softmax_with_capacity(
+ logits,
+ self.topk,
+ capacity_factor=self.config.moe_expert_capacity_factor,
+ pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
+ drop_policy=self.config.moe_token_drop_policy,
+ use_pre_softmax=self.config.moe_router_pre_softmax,
+ )
+ global_indices = indices
+ if self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1 and not get_args().tp_2d):
+ with torch.no_grad():
+ global_indices = gather_from_sequence_parallel_region_to_moe_async(indices)
+
+ # Apply load balancing loss
+ if self.training:
+ scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
+ probs = self.apply_load_balancing_loss(scores, tokens_per_expert, activation=probs)
+ return probs, global_indices
+
+
+def routing_tp_extend_ep(self, logits: torch.Tensor):
+ """Top-k routing function
+
+ Args:
+ logits (torch.Tensor): Logits tensor after gating.
+
+ Returns:
+ probs (torch.Tensor): the probabilities tensor after load balancing.
+ indices (torch.Tensor): the indices tensor after top-k selection.
+ """
+ logits = logits.view(-1, self.config.num_moe_experts)
+
+ # Apply Z-Loss
+ logits = self.apply_z_loss(logits)
+
+ if self.routing_type == "sinkhorn":
+ scores, indices = self.sinkhorn_load_balancing(logits)
+ elif self.routing_type == "aux_loss":
+ scores, indices = self.aux_loss_load_balancing(logits)
+ elif self.routing_type == "none":
+ # A naive top-k routing without load balancing
+ scores, indices, _ = topk_softmax_with_capacity(
+ logits,
+ self.topk,
+ capacity_factor=self.config.moe_expert_capacity_factor,
+ pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
+ drop_policy=self.config.moe_token_drop_policy,
+ )
+ else:
+ raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")
+
+ return scores, indices
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/token_dispatcher.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/token_dispatcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..3645d5cd6f782204eb48ef6e6221fe327d2e9145
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/token_dispatcher.py
@@ -0,0 +1,1262 @@
+# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+from torch_npu.utils.collect_env import get_cann_version
+from megatron.training import get_args
+from megatron.core import parallel_state, tensor_parallel, mpu
+from megatron.core.transformer.moe.moe_utils import moe_gather, moe_scatter, permute, unpermute
+from megatron.core.tensor_parallel.mappings import _gather_along_first_dim_expert_parallel, reduce_scatter_to_sequence_parallel_region
+from mindspeed.core.transformer.moe.router import gather_from_sequence_parallel_region_to_moe_async
+from mindspeed.core.transformer.moe.comm_utils import (async_all_to_all, async_reduce_scatter,
+ async_all_gather, reduce_scatter, all_gather)
+from mindspeed.core.transformer.moe.moe_layer_overlap_all2all import forward_func
+from mindspeed.core.transformer.moe.unpermute_without_activation import UnpermuteWithoutActivation
+from mindspeed.core.transformer.moe.moe_utils import (AG_SHARED_EXPERTS_INPUTS,
+ permute_with_ep, unpermute_with_ep)
+
+
+def is_less_or_equal_rc2_cann_version():
+ cann_starts_with = ('8.0.RC1', '8.0.RC2')
+ cann_all = ('not known', '8.0.T1', '8.0.T2', '8.0.T3', '8.0.T37', '8.0.T5', '8.0.T6', '8.0.T7',
+ '8.0.T8', '8.0.T10', '8.0.T13', '8.0.T16', '8.0.T50', '8.0.T51', '8.0.T52')
+ cann_version = get_cann_version()
+ return cann_version in cann_all or cann_version.startswith(cann_starts_with)
+
+
+cann_version_check = is_less_or_equal_rc2_cann_version()
+
+
+def allgather_token_permutation(self, hidden_states: torch.Tensor, max_prob: torch.Tensor, max_ind):
+ args = get_args()
+ self.hidden_shape = hidden_states.shape
+ # [S/TP, B, H] -> [S*B/TP, H]
+ hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
+
+ # Permute the tokens across the expert parallel devices.
+ if (self.config.tensor_model_parallel_size > 1) or (
+ self.config.expert_model_parallel_size > 1
+ ):
+ # [S*B/TP, H] -> [S*B, H]
+ with torch.no_grad():
+ global_indices, gi_handle = max_ind if isinstance(max_ind,
+ tuple) else gather_from_sequence_parallel_region_to_moe_async(
+ max_ind)
+ global_probs, gp_handle = gather_from_sequence_parallel_region_to_moe_async(max_prob)
+ global_hidden_states, ghs_handle = gather_from_sequence_parallel_region_to_moe_async(hidden_states)
+
+ with torch.no_grad():
+ gi_handle.wait()
+ global_local_mask = (global_indices >= self.local_expert_indices[0]) & \
+ (global_indices <= self.local_expert_indices[-1])
+ local_indices = global_indices.masked_select(global_local_mask)
+ self.indices = torch.argsort(local_indices.float(), dim=0)
+ num_global_experts = self.num_local_experts * parallel_state.get_expert_model_parallel_world_size()
+ if args.moe_tp_extend_ep:
+ num_global_experts *= parallel_state.get_tensor_model_parallel_world_size()
+ all_tokens_per_expert = torch.histc(
+ global_indices,
+ bins=num_global_experts,
+ min=0,
+ max=num_global_experts - 1,
+ )
+ self.all_tokens_per_expert = all_tokens_per_expert.to(torch.long)
+ tokens_per_expert = self.all_tokens_per_expert[self.local_expert_indices[0]: self.local_expert_indices[-1] + 1]
+ self.global_local_map = global_local_mask.nonzero()[:, 0]
+
+ if self.router_topk > 1: # k > 1
+ gp_handle.wait()
+ self.local_probs = global_probs.masked_select(global_local_mask)
+ else:
+ self.local_probs = max_prob
+
+ ghs_handle.wait()
+ if cann_version_check:
+ local_hidden_states = global_hidden_states[self.global_local_map, :]
+ else:
+ self.global_local_map = self.global_local_map.view(-1, 1).expand(-1, hidden_states.shape[-1])
+ local_hidden_states = moe_gather.apply(global_hidden_states, self.global_local_map)
+ else:
+ if self.router_topk > 1:
+ global_local_mask = torch.ones_like(max_ind).bool()
+ local_indices = max_ind.masked_select(global_local_mask)
+ self.local_probs = max_prob.masked_select(global_local_mask)
+ self.global_local_map = global_local_mask.nonzero()[:, 0]
+ if cann_version_check:
+ local_hidden_states = hidden_states[self.global_local_map, :]
+ else:
+ self.global_local_map = self.global_local_map.view(-1, 1).expand(
+ -1, hidden_states.shape[-1]
+ )
+ local_hidden_states = torch.gather(hidden_states, 0, self.global_local_map)
+ else:
+ local_indices = max_ind
+ self.local_probs = max_prob
+ local_hidden_states = hidden_states
+ self.global_local_map = None
+
+ with torch.no_grad():
+ # The indices of local_indices that give its sorted order along dim 0.
+ self.indices = torch.argsort(local_indices, dim=0)
+ # use 0.7.0 implement for better performance
+ tokens_per_expert = torch.histc(
+ local_indices,
+ bins=self.num_local_experts,
+ min=self.local_expert_indices[0],
+ max=self.local_expert_indices[-1],
+ )
+ tokens_per_expert = tokens_per_expert.to(torch.long)
+ self.all_tokens_per_expert = tokens_per_expert
+
+ if self.num_local_experts > 1:
+ if cann_version_check:
+ permuted_local_hidden_states = local_hidden_states[self.indices, :]
+ else:
+ self.indices = self.indices.view(-1, 1).expand(-1, hidden_states.shape[-1])
+ permuted_local_hidden_states = moe_gather.apply(local_hidden_states, self.indices)
+ else:
+ permuted_local_hidden_states = local_hidden_states
+ return (
+ permuted_local_hidden_states,
+ tokens_per_expert,
+ )
+
+
+class NewIndePut(torch.autograd.Function):
+ @staticmethod
+ def forward(self, tensor, map_, value_):
+ self.map_ = map_
+ ori_dtype = None
+ if value_.dtype != torch.float32:
+ ori_dtype = value_.dtype
+ value_ = value_.float()
+ output = tensor.index_put_(map_, value_, accumulate=True)
+ if ori_dtype:
+ return output.to(ori_dtype)
+ return output
+
+ def backward(self, grad_input):
+ map_ = self.map_
+ grad_output = grad_input.index_select(0, map_[0])
+ return None, None, grad_output
+
+
+def allgather_token_unpermutation(self, hidden_states: torch.Tensor, bias: torch.Tensor = None, ):
+ # Stage1: unpermute the tokens and bias locally respectively.w
+ scores = self.local_probs.to(dtype=hidden_states.dtype)
+ if self.num_local_experts > 1:
+ if cann_version_check:
+ unpermuted_local_hidden = torch.zeros_like(hidden_states)
+ unpermuted_local_hidden.index_put_((self.indices,), hidden_states[:self.indices.shape[0], :],
+ accumulate=False)
+ else:
+ assert self.indices.shape == hidden_states.shape
+ unpermuted_local_hidden = moe_scatter.apply(hidden_states, self.indices)
+ else:
+ unpermuted_local_hidden = hidden_states
+
+ # Scale the expert output prior to reduction and subsequent to local unpermutation if k > 1.
+ if self.router_topk > 1:
+ unpermuted_local_hidden = unpermuted_local_hidden * scores.view(-1, 1)
+
+ unpermuted_local_bias = None
+ if self.add_bias:
+ assert bias is not None
+ unpermuted_local_bias = torch.zeros_like(hidden_states)
+ if cann_version_check:
+ unpermuted_local_bias.index_put_((self.indices,), bias[:self.indices.shape[0], :], accumulate=False)
+ else:
+ assert self.indices.shape == bias.shape
+ unpermuted_local_bias = unpermuted_local_bias.scatter(0, self.indices, bias)
+ if self.router_topk > 1:
+ unpermuted_local_bias = unpermuted_local_bias * scores.view(-1, 1)
+
+ output_total = unpermuted_local_hidden
+ output_bias_total = unpermuted_local_bias
+
+ # Unpermute the tokens across expert parallel devices.
+ if (self.config.tensor_model_parallel_size > 1) or (
+ self.config.expert_model_parallel_size > 1
+ ):
+ assert (
+ self.global_local_map is not None
+ ), "global_local_map is necessary for `AllGather`."
+ ep_group_size = parallel_state.get_tensor_and_expert_parallel_world_size()
+ # hidden_shape: [SeqLen/TP, MBS, HiddenSize], glboal_num_tokens = SeqLen/TP*MBS*(TP*EP)
+ global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1] * ep_group_size
+ global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
+ if cann_version_check:
+ unpermuted_global_hidden = torch.zeros(global_hidden_shape, dtype=torch.float,
+ device=torch.cuda.current_device())
+ unpermuted_global_hidden = NewIndePut.apply(unpermuted_global_hidden, (self.global_local_map,),
+ unpermuted_local_hidden[:self.global_local_map.shape[0], :])
+ else:
+ assert self.global_local_map.shape == unpermuted_local_hidden.shape
+ unpermuted_global_hidden = moe_scatter.apply(
+ unpermuted_local_hidden, self.global_local_map, global_hidden_shape
+ )
+
+ output_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(unpermuted_global_hidden)
+ if self.add_bias:
+ # Unpermute the bias across expert parallel devices.
+ unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
+ if cann_version_check:
+ unpermuted_global_bias.index_put_((self.global_local_map,),
+ unpermuted_local_bias[:self.global_local_map.shape[0], :],
+ accumulate=True)
+ else:
+ unpermuted_global_bias = unpermuted_global_bias.scatter_add(
+ 0, self.global_local_map, unpermuted_local_bias
+ )
+
+ output_bias_total = (
+ tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
+ unpermuted_global_bias
+ )
+ )
+ # bias is duplicated across tensor parallelism ranks;
+ # reduce scatter reduces bias across tensor parallel_ranks
+ output_bias_total = (output_bias_total / parallel_state.get_tensor_model_parallel_world_size())
+ else:
+ if self.router_topk > 1:
+ global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1]
+ global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
+ unpermuted_global_hidden = torch.zeros(
+ global_hidden_shape,
+ dtype=hidden_states.dtype,
+ device=torch.cuda.current_device(),
+ )
+ if cann_version_check:
+ output_total = unpermuted_global_hidden.index_put((self.global_local_map,),
+ unpermuted_local_hidden[
+ :self.global_local_map.shape[0], :],
+ accumulate=True)
+ else:
+ output_total = unpermuted_global_hidden.scatter_add(
+ 0, self.global_local_map, unpermuted_local_hidden
+ )
+ if self.add_bias:
+ unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
+ if cann_version_check:
+ output_bias_total = unpermuted_global_bias.index_put((self.global_local_map,),
+ unpermuted_local_bias[
+ :self.global_local_map.shape[0], :],
+ accumulate=True)
+ else:
+ output_bias_total = unpermuted_global_bias.scatter_add(
+ 0, self.global_local_map, unpermuted_local_bias
+ )
+
+ if self.router_topk == 1:
+ output_total = output_total * scores
+ output_total = output_total.view(self.hidden_shape)
+ if self.add_bias:
+ assert output_bias_total is not None
+ if self.router_topk == 1:
+ output_bias_total = output_bias_total * scores
+ output_bias_total = output_bias_total.view(self.hidden_shape)
+ else:
+ output_bias_total = None
+
+ return output_total, output_bias_total
+
+
+def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
+ # use 0.7.0 implement for better performance
+ num_local_tokens_per_expert = torch.histc(
+ indices, bins=self.num_experts, min=0, max=self.num_experts
+ )
+ # num_local_tokens_per_expert: [num_experts]
+
+ ep_size = self.config.expert_model_parallel_size
+ if self.drop_and_pad:
+ # probs: [num_experts, capacity]
+ self.capacity = self.probs.size(1)
+ num_tokens_per_local_expert = torch.full(
+ (self.num_local_experts,), self.capacity * self.ep_size, dtype=torch.long,
+ device=torch.cuda.current_device()
+ )
+ return num_tokens_per_local_expert
+ elif self.config.moe_expert_capacity_factor is not None:
+ # Token drop but no pad. A synchronization is needed before the first
+ # permutation to get the `num_out_tokens` CPU value.
+ self.num_out_tokens = num_local_tokens_per_expert.sum().to(
+ torch.device("cpu"), non_blocking=True
+ )
+ self.cuda_sync_point = "before_permutation_1"
+ elif ep_size > 1:
+ # Token dropless and enable ep. A synchronization is needed before expert parallel
+ # AlltoAll communication to get the `input_splits` and `output_splits` CPU values.
+ self.cuda_sync_point = "before_ep_alltoall"
+ else:
+ # Token dropless and no ep. A synchronization is needed before the token_permutation()
+ # function returns to get the `tokens_per_expert` CPU value.
+ self.cuda_sync_point = "before_finish"
+
+ if ep_size > 1:
+ # ===================================================
+ # Calculate input_splits, output_splits for alltoall-v.
+ # ===================================================
+ self.input_splits = (
+ num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
+ .sum(axis=1)
+ .to(torch.device("cpu"), non_blocking=True)
+ .numpy()
+ )
+ num_global_tokens_per_expert = _gather_along_first_dim_expert_parallel(
+ num_local_tokens_per_expert
+ ).reshape(ep_size, self.num_experts)
+ self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[
+ :, self.local_expert_indices[0]: self.local_expert_indices[-1] + 1
+ ]
+ self.output_splits = (
+ self.num_global_tokens_per_local_expert.sum(axis=-1).to(torch.device("cpu")).numpy()
+ )
+ num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(axis=0)
+ # ===================================================
+ # num_global_tokens_per_expert: [ep_size, num_experts]
+ # num_global_tokens_per_local_expert: [ep_size, num_local_experts]
+ # num_tokens_per_local_expert: [num_local_experts]
+ # ===================================================
+ else:
+ self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
+ -1, self.num_experts
+ )
+ num_tokens_per_local_expert = num_local_tokens_per_expert
+
+ if self.num_local_experts > 1:
+ if not hasattr(self, 'comm_stream'):
+ self.comm_stream = torch.cuda.Stream()
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(self.comm_stream):
+ # No further synchronization is needed because torch.repeat_interleave() calls stream
+ # synchronization internally when the `output_size` parameter is not provided.
+ self.cuda_sync_point = "no_sync"
+ self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
+ self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()
+ )
+
+ return num_tokens_per_local_expert
+
+
+def alltoall_token_permutation(
+ self, hidden_states: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor,
+):
+ self.hidden_shape = hidden_states.shape
+ self.probs = probs
+ assert probs.dim() == 2, "Expected 2D tensor for probs"
+ assert indices.dim() == 2, "Expected 2D tensor for indices"
+ tokens_per_expert = self.preprocess(indices)
+
+ # Flatten the input tensor
+ # hidden_states: [S/TP, B, H] -> [S*B/TP, H]
+ hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
+
+ # Perform tensor parallel AlltoAll communication
+ # hidden_states: [S*B/TP, H] -> [S*B, H/TP]
+ if parallel_state.get_tensor_model_parallel_world_size() > 1:
+ hidden_states = tensor_parallel.all_to_all_sp2hp(hidden_states)
+
+ # Permutation 1: input to AlltoAll input
+ self.hiddden_shape_before_permute = hidden_states.shape
+ if self.cuda_sync_point == "before_permutation_1":
+ torch.cuda.current_stream().synchronize()
+ permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
+ hidden_states,
+ indices,
+ num_out_tokens=self.num_out_tokens,
+ padded_mode=self.drop_and_pad,
+ )
+
+ if get_args().moe_bmm_mc2:
+ return permutated_local_input_tokens, tokens_per_expert
+
+ # Perform expert parallel AlltoAll communication
+ if self.cuda_sync_point == "before_ep_alltoall":
+ torch.cuda.current_stream().synchronize()
+ global_input_tokens = tensor_parallel.all_to_all(
+ parallel_state.get_expert_model_parallel_group(),
+ permutated_local_input_tokens,
+ self.output_splits,
+ self.input_splits,
+ )
+
+ # Permutation 2: AlltoAll output to expert input if num_local_experts > 1
+ if self.num_local_experts > 1:
+ if not self.drop_and_pad:
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
+ global_input_tokens, self.reversed_global_input_permutation_mapping = permute(
+ global_input_tokens, self.global_input_tokens_local_experts_indices
+ )
+ else:
+ global_input_tokens = global_input_tokens.reshape(
+ self.ep_size, self.num_local_experts, self.capacity, -1
+ )
+ global_input_tokens = (
+ global_input_tokens.transpose(0, 1)
+ .reshape(self.num_local_experts * self.ep_size * self.capacity, -1)
+ .contiguous()
+ )
+
+ # Perform tensor parallel All-Gather on the hidden dimension to obtain the input tokens.
+ # global_input_tokens: [SEQL, H/TP] -> [SEQL, H]
+ if parallel_state.get_tensor_model_parallel_world_size() > 1 and self.config.moe_grouped_gemm:
+ global_input_tokens = tensor_parallel.all_gather_last_dim_from_tensor_parallel_region(
+ global_input_tokens
+ )
+ if self.cuda_sync_point == "before_finish":
+ torch.cuda.current_stream().synchronize()
+
+ return global_input_tokens, tokens_per_expert
+
+
+def alltoall_token_unpermutation_with_bmm(
+ self, hidden_states: torch.Tensor, bias: torch.Tensor = None,
+):
+ # if use op bmm_reducescatter_alltoall to skip reducescatter and alltoall
+ output = unpermute(
+ hidden_states,
+ self.reversed_local_input_permutation_mapping,
+ probs=self.probs,
+ padded_mode=self.drop_and_pad,
+ restore_shape=self.hiddden_shape_before_permute,
+ )
+
+ if parallel_state.get_tensor_model_parallel_world_size() > 1:
+ output = tensor_parallel.all_to_all_hp2sp(output)
+
+ output = output.view(self.hidden_shape)
+ return output, None
+
+
+def alltoall_token_permutation_with_bmm(
+ self, hidden_states: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor,
+):
+ # if use op alltoall_allgather_bmm to skip alltoall and allgather
+ self.hidden_states = hidden_states.shape
+ self.probs = probs
+ assert probs.dim() == 2, "Experted 2D tensor for probs"
+ assert indices.dim() == 2, "Experted 2D tensor for indices"
+ hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
+ tokens_per_expert = self.preprocess(indices)
+
+ if parallel_state.get_tensor_model_parallel_world_size() > 1:
+ hidden_states = tensor_parallel.all_to_all_sp2hp(hidden_states)
+
+ self.hidden_shape_before_permute = hidden_states.shape
+ permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
+ hidden_states,
+ indices,
+ num_out_tokens=self.num_out_tokens,
+ padded_mode=self.drop_and_pad,
+ )
+ return permutated_local_input_tokens, tokens_per_expert
+
+
+def preprocess_tp_extend_ep(self, indices: torch.Tensor, *args) -> torch.Tensor:
+ moe_hierarchical_alltoallv = get_args().moe_hierarchical_alltoallv
+ num_local_tokens_per_expert = torch.histc(
+ indices, bins=self.num_experts, min=0, max=self.num_experts
+ )
+ # num_local_tokens_per_expert: [num_experts]
+
+ ep_size = self.config.expert_model_parallel_size
+ if self.drop_and_pad:
+ # probs: [num_experts, capacity]
+ self.capacity = self.probs.size(1)
+ num_tokens_per_local_expert = torch.full(
+ (self.num_local_experts,), self.capacity * self.ep_size, dtype=torch.long,
+ device=torch.cuda.current_device()
+ )
+ return num_tokens_per_local_expert
+ elif self.config.moe_expert_capacity_factor is not None:
+ self.num_out_tokens = num_local_tokens_per_expert.sum().cpu()
+ tp_size = parallel_state.get_tensor_model_parallel_world_size()
+ tp_extended_ep_size = ep_size * tp_size
+ if tp_extended_ep_size > 1:
+ # ===================================================
+ # Calculate input_splits, output_splits for alltoall-v.
+ # ===================================================
+ if moe_hierarchical_alltoallv:
+ tp_group = parallel_state.get_tensor_model_parallel_group()
+ self.input_splits_tp_ep = (
+ num_local_tokens_per_expert.reshape(tp_extended_ep_size, self.num_local_experts)
+ .sum(axis=1)
+ .to(torch.device("cpu"))
+ .numpy()
+ )
+ expert_parallel_rank = mpu.get_expert_model_parallel_rank()
+ tp_size = parallel_state.get_tensor_model_parallel_world_size()
+ offset = expert_parallel_rank * tp_size
+ self.input_splits = [self.input_splits_tp_ep[i + offset] for i in range(tp_size)]
+ self.input_splits_tp_ep = self.input_splits_tp_ep.reshape(ep_size, tp_size).sum(axis=1)
+ num_global_tokens_per_expert = \
+ all_gather(num_local_tokens_per_expert, group=tp_group).reshape(tp_size, self.num_experts)
+ # shared_experts allgather with tp
+ if get_args().n_shared_experts and parallel_state.get_tensor_model_parallel_world_size() > 1:
+ _, shared_experts_input, shared_experts_allgather_handle = async_all_gather(
+ args[0], parallel_state.get_tensor_model_parallel_group(), is_use_get_global_memory_buffer=True
+ )
+ AG_SHARED_EXPERTS_INPUTS.append((shared_experts_input, shared_experts_allgather_handle))
+ else:
+ self.input_splits_tp_ep = None
+ self.input_splits = (
+ num_local_tokens_per_expert.reshape(tp_extended_ep_size, self.num_local_experts)
+ .sum(axis=1)
+ .to(torch.device("cpu"))
+ .numpy()
+ )
+ num_global_tokens_per_expert = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
+ num_local_tokens_per_expert
+ ).reshape(tp_extended_ep_size, self.num_experts)
+ self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[
+ :, self.local_expert_indices
+ ]
+ self.output_splits = (
+ self.num_global_tokens_per_local_expert.sum(axis=-1).to(torch.device("cpu")).numpy()
+ )
+ num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(axis=0)
+ # ===================================================
+ # num_global_tokens_per_expert: [ep_size, num_experts]
+ # num_global_tokens_per_local_expert: [ep_size, num_local_experts]
+ # num_tokens_per_local_expert: [num_local_experts]
+ # ===================================================
+ else:
+ self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
+ -1, self.num_experts
+ )
+ num_tokens_per_local_expert = num_local_tokens_per_expert
+
+ if self.num_local_experts > 1:
+ if not hasattr(self, 'comm_stream'):
+ self.comm_stream = torch.cuda.Stream()
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(self.comm_stream):
+ if moe_hierarchical_alltoallv:
+ expert_ids_per_ep_rank = torch.tensor(
+ [i % self.num_local_experts for i in range(self.config.num_moe_experts // ep_size)],
+ dtype=torch.int32,
+ device=torch.cuda.current_device(),
+ )
+ else:
+ expert_ids_per_ep_rank = torch.tensor(
+ [i % self.num_local_experts for i in range(self.config.num_moe_experts)],
+ dtype=torch.int32,
+ device=torch.cuda.current_device(),
+ )
+ self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
+ expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()
+ )
+
+ return num_tokens_per_local_expert
+
+
+def alltoall_token_permutation_tp_extend_ep(
+ self, hidden_states: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor,
+):
+ self.hidden_shape = hidden_states.shape
+ self.probs = probs
+ assert probs.dim() == 2, "Expected 2D tensor for probs"
+ assert indices.dim() == 2, "Expected 2D tensor for indices"
+ tokens_per_expert = self.preprocess(indices)
+
+ # Flatten the input tensor
+ # hidden_states: [S/TP, B, H] -> [S*B/TP, H]
+ hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
+
+ # Permutation 1: input to AlltoAll input
+ self.hiddden_shape_before_permute = hidden_states.shape
+ permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
+ hidden_states,
+ indices,
+ num_out_tokens=self.num_out_tokens,
+ padded_mode=self.drop_and_pad,
+ )
+
+ # Perform expert parallel AlltoAll communication
+ global_input_tokens = tensor_parallel.all_to_all(
+ parallel_state.get_tensor_and_expert_parallel_group(),
+ permutated_local_input_tokens,
+ self.output_splits,
+ self.input_splits,
+ )
+
+ # Permutation 2: AlltoAll output to expert input if num_local_experts > 1
+ if self.num_local_experts > 1:
+ if not self.drop_and_pad:
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
+ global_input_tokens, self.reversed_global_input_permutation_mapping = permute(
+ global_input_tokens, self.global_input_tokens_local_experts_indices
+ )
+ else:
+ global_input_tokens = global_input_tokens.reshape(
+ self.ep_size, self.num_local_experts, self.capacity, -1
+ )
+ global_input_tokens = (
+ global_input_tokens.transpose(0, 1)
+ .reshape(self.num_local_experts * self.ep_size * self.capacity, -1)
+ .contiguous()
+ )
+
+ return global_input_tokens, tokens_per_expert
+
+
+def alltoall_token_unpermutation_tp_extend_ep(
+ self, hidden_states: torch.Tensor, bias: torch.Tensor = None,
+):
+ """
+ Reverse the token permutation to restore the original order.
+
+ Args:
+ hidden_states (torch.Tensor): Output from local experts.
+ bias (torch.Tensor, optional): Bias tensor (not supported).
+
+ Returns:
+ Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ - Unpermuted token embeddings in the original order.
+ - None (bias is not supported).
+ """
+ assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
+
+ # Unpermutation 2: expert output to AlltoAll input
+ # hidden_states: [SEQL, H] -> [SEQL, H/TP]
+ if self.num_local_experts > 1:
+ if not self.drop_and_pad:
+ hidden_states = unpermute(
+ hidden_states, self.reversed_global_input_permutation_mapping,
+ )
+ else:
+ hidden_states = hidden_states.reshape(
+ self.num_local_experts, self.ep_size, self.capacity, -1
+ )
+ hidden_states = (
+ hidden_states.transpose(0, 1)
+ .reshape(self.ep_size * self.num_local_experts * self.capacity, -1)
+ .contiguous()
+ )
+
+ # Perform expert parallel AlltoAll communication
+ permutated_local_input_tokens = tensor_parallel.all_to_all(
+ parallel_state.get_tensor_and_expert_parallel_group(),
+ hidden_states,
+ self.input_splits,
+ self.output_splits,
+ )
+
+ # Unpermutation 1: AlltoAll output to output
+ output = unpermute(
+ permutated_local_input_tokens,
+ self.reversed_local_input_permutation_mapping,
+ probs=self.probs,
+ padded_mode=self.drop_and_pad,
+ restore_shape=self.hiddden_shape_before_permute,
+ )
+
+ # Reshape the output tensor
+ output = output.view(self.hidden_shape)
+ return output, None
+
+
+def allgather_token_permutation_new(self, global_indices_2_tuple, global_probs_2_tuple, global_hidden_states_2_tuple):
+ global_indices, gi_handle = global_indices_2_tuple
+ global_probs, gp_handle = global_probs_2_tuple
+ global_hidden_states, ghs_handle = global_hidden_states_2_tuple
+
+ local_hidden_states = None
+ tokens_per_expert = None
+
+ if (self.config.tensor_model_parallel_size > 1) or (
+ self.config.expert_model_parallel_size > 1
+ ):
+ with (torch.no_grad()):
+ gi_handle.wait()
+ global_local_mask = (global_indices >= self.local_expert_indices[0]) & \
+ (global_indices <= self.local_expert_indices[-1])
+
+ # masked_select -> reshape
+ local_indices = global_indices.masked_select(global_local_mask)
+ self.indices = torch.argsort(local_indices.float(), dim=0)
+ num_global_experts = self.num_local_experts * parallel_state.get_expert_model_parallel_world_size()
+ if get_args().moe_tp_extend_ep:
+ num_global_experts *= parallel_state.get_tensor_model_parallel_world_size()
+ all_tokens_per_expert = torch.histc(
+ global_indices,
+ bins=num_global_experts,
+ min=0,
+ max=num_global_experts
+ )
+ self.all_tokens_per_expert = all_tokens_per_expert.to(torch.long)
+ tokens_per_expert = self.all_tokens_per_expert[self.local_expert_indices[0]: self.local_expert_indices[-1] + 1]
+ self.global_local_map = global_local_mask.nonzero()[:, 0]
+
+ if self.router_topk > 1: # k > 1
+ gp_handle.wait()
+ # masked_select -> reshape
+ self.local_probs = global_probs.masked_select(global_local_mask)
+
+ ghs_handle.wait()
+ if cann_version_check:
+ local_hidden_states = global_hidden_states[self.global_local_map, :]
+ else:
+ self.global_local_map = self.global_local_map.view(-1, 1).expand(-1, self.hidden_shape[-1])
+ local_hidden_states = moe_gather.apply(global_hidden_states, self.global_local_map)
+ if self.num_local_experts > 1:
+ if cann_version_check:
+ permuted_local_hidden_states = local_hidden_states[self.indices, :]
+ else:
+ self.indices = self.indices.view(-1, 1).expand(-1, self.hidden_shape[-1])
+ permuted_local_hidden_states = moe_gather.apply(local_hidden_states, self.indices)
+ else:
+ permuted_local_hidden_states = local_hidden_states
+ return (
+ permuted_local_hidden_states,
+ tokens_per_expert,
+ self.global_local_map,
+ self.indices
+ )
+
+
+def allgather_token_unpermutation_new(self, hidden_states: torch.Tensor, bias: torch.Tensor = None):
+ # Stage1: unpermute the tokens and bias locally respectively.w
+ scores = self.local_probs.to(dtype=hidden_states.dtype)
+ if self.num_local_experts > 1:
+ if cann_version_check:
+ unpermuted_local_hidden = torch.zeros_like(hidden_states)
+ unpermuted_local_hidden.index_put_((self.indices,), hidden_states[:self.indices.shape[0], :],
+ accumulate=False)
+ else:
+ assert self.indices.shape == hidden_states.shape
+ unpermuted_local_hidden = moe_scatter.apply(hidden_states, self.indices)
+ else:
+ unpermuted_local_hidden = hidden_states
+
+ # Scale the expert output prior to reduction and subsequent to local unpermutation if k > 1.
+ if self.router_topk > 1:
+ unpermuted_local_hidden = unpermuted_local_hidden * scores.view(-1, 1)
+
+ unpermuted_local_bias = None
+ if self.add_bias:
+ assert bias is not None
+ unpermuted_local_bias = torch.zeros_like(hidden_states)
+ if cann_version_check:
+ unpermuted_local_bias.index_put_((self.indices,), bias[:self.indices.shape[0], :], accumulate=False)
+ else:
+ assert self.indices.shape == bias.shape
+ unpermuted_local_bias = unpermuted_local_bias.scatter(0, self.indices, bias)
+
+ if self.router_topk > 1:
+ unpermuted_local_bias = unpermuted_local_bias * scores.view(-1, 1)
+
+ output_total = unpermuted_local_hidden
+ output_bias_total = unpermuted_local_bias
+
+ # Unpermute the tokens across expert parallel devices.
+ if (self.config.tensor_model_parallel_size > 1) or (
+ self.config.expert_model_parallel_size > 1
+ ):
+ assert (
+ self.global_local_map is not None
+ ), "global_local_map is necessary for 'AllGather'."
+ ep_group_size = parallel_state.get_tensor_and_expert_parallel_world_size()
+ # hidden_shape: [SeqLen/TP, MBS, HiddenSize], global_num_tokens = SeqLen/TP*MBS*(TP*EP)
+ global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1] * ep_group_size
+ global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
+
+ if cann_version_check:
+ unpermuted_global_hidden = torch.zeros(global_hidden_shape, dtype=torch.float,
+ device=torch.cuda.current_device())
+ unpermuted_global_hidden = NewIndePut.apply(unpermuted_global_hidden, (self.global_local_map,),
+ unpermuted_local_hidden[:self.global_local_map.shape[0], :])
+ else:
+ unpermuted_global_hidden = torch.zeros(
+ global_hidden_shape, dtype=hidden_states.dtype, device=torch.cuda.current_device()
+ )
+ # Reshape global_local_map to be compatible with Tensor.scatter
+ assert self.global_local_map.shape == unpermuted_local_hidden.shape
+ unpermuted_global_hidden = unpermuted_global_hidden.scatter_add(
+ 0, self.global_local_map, unpermuted_local_hidden)
+
+ output_total = unpermuted_global_hidden
+ if self.add_bias:
+ # Unpermute the bias across expert parallel devices.
+ unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
+ if cann_version_check:
+ unpermuted_global_bias.index_put_((self.global_local_map,),
+ unpermuted_local_bias[:self.global_local_map.shape[0], :],
+ accumulate=True)
+ else:
+ unpermuted_global_bias = unpermuted_global_bias.scatter_add(
+ 0, self.global_local_map, unpermuted_local_bias
+ )
+
+ output_bias_total = tensor_parallel.reduce_scatter_to_sequence_parallel_region_from_moe(
+ unpermuted_global_bias
+ )
+ # bias is duplicated across tensor parallelism ranks;
+ # reduce scatter reduces bias across tensor parallel_ranks
+ output_bias_total = (output_bias_total / parallel_state.get_tensor_model_parallel_world_size())
+ else:
+ if self.router_topk > 1:
+ global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1]
+ global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
+ unpermuted_global_hidden = torch.zeros(
+ global_hidden_shape,
+ dtype=hidden_states.dtype,
+ device=torch.cuda.current_device()
+ )
+ if cann_version_check:
+ output_total = unpermuted_global_hidden.index_put((self.global_local_map,),
+ unpermuted_local_hidden[
+ :self.global_local_map.shape[0], :],
+ accumulate=True)
+ else:
+ output_total = unpermuted_global_hidden.scatter_add(
+ 0, self.global_local_map, unpermuted_local_hidden
+ )
+
+ if self.add_bias:
+ unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
+ if cann_version_check:
+ output_bias_total = unpermuted_global_bias.index_put((self.global_local_map,),
+ unpermuted_local_bias[
+ :self.global_local_map.shape[0], :],
+ accumulate=True)
+ else:
+ output_bias_total = unpermuted_global_bias.scatter_add(
+ 0, self.global_local_map, unpermuted_local_bias
+ )
+
+ if self.router_topk == 1:
+ output_total = output_total * scores
+ if self.add_bias:
+ assert output_bias_total is not None
+ if self.router_topk == 1:
+ output_bias_total = output_bias_total * scores
+ output_bias_total = output_bias_total.view(self.hidden_shape)
+ else:
+ output_bias_total = None
+
+ return output_total, output_bias_total
+
+
+def alltoall_token_permutation_new(
+ self, hidden_states: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor, shared_experts, save_tensors, shared_expert_gate, moe_ctx=None
+):
+ moe_hierarchical_alltoallv = get_args().moe_hierarchical_alltoallv
+ self.hidden_shape = hidden_states.shape
+ self.probs = probs
+ assert probs.dim() == 2, "Expected 2D tensor for probs"
+ assert indices.dim() == 2, "Expected 2D tensor for indices"
+ if moe_hierarchical_alltoallv:
+ ep_group = parallel_state.get_expert_model_parallel_group()
+ _, indices, indices_handle = async_all_gather(indices, group=ep_group)
+ indices_handle.wait()
+ save_tensors.append(indices)
+ _, hidden_states_ep, hidden_states_ep_handle = async_all_gather(hidden_states, group=ep_group)
+ else:
+ indices_ep, hidden_states_ep, hidden_states_ep_handle = None, None, None
+ save_tensors.append(indices_ep)
+
+ def alltoall_token_permutation1(hidden_states, indices, *args):
+ if moe_hierarchical_alltoallv:
+ _, self.probs, probs_handle = async_all_gather(self.probs, group=ep_group)
+ tokens_per_expert = self.preprocess(indices, hidden_states)
+ args[1].wait() # hidden_states_ep_handle
+ save_tensors.append(args[0]) # hidden_states_ep
+ # hidden_states: [S/TP, B, H] -> [S*B/TP, H]
+ hidden_states = args[0].view(-1, self.hidden_shape[-1])
+ self.hidden_shape_before_permute = hidden_states.shape
+ # Permutation 1: input to AlltoAll input
+ if self.cuda_sync_point == "before_permutation_1":
+ torch.cuda.current_stream().synchronize()
+ probs_handle.wait()
+ self.probs = self.probs.detach()
+ self.probs.requires_grad = True
+ save_tensors.append(self.probs)
+ permutated_local_input_tokens, permuted_probs, self.reversed_local_input_permutation_mapping = permute_with_ep(
+ hidden_states, indices, probs=self.probs, topk=self.router_topk,
+ gb_inputs_splits=self.input_splits_tp_ep,
+ )
+ self.permuted_probs = permuted_probs
+ else:
+ tokens_per_expert = self.preprocess(indices)
+ save_tensors.append(args[0])
+ if get_args().moe_experts_pipeline_degree:
+ tokens_per_expert = tokens_per_expert.cpu()
+
+ # Flatten the input tensor
+ # hidden_states: [S/TP, B, H] -> [S*B/TP, H]
+ hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
+
+ # Perform tensor parallel AlltoAll communication
+ # hidden_states: [S*B/TP, H] -> [S*B, H/TP]
+ if not get_args().moe_tp_extend_ep and parallel_state.get_tensor_model_parallel_world_size() > 1:
+ hidden_states = tensor_parallel.all_to_all_sp2hp(hidden_states)
+
+ # Permutation 1: input to AlltoAll input
+ self.hiddden_shape_before_permute = hidden_states.shape
+ if self.cuda_sync_point == "before_permutation_1":
+ torch.cuda.current_stream().synchronize()
+ scores_ep = None
+ save_tensors.append(scores_ep)
+ permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
+ hidden_states,
+ indices,
+ num_out_tokens=self.num_out_tokens,
+ padded_mode=self.drop_and_pad,
+ )
+ return tokens_per_expert, permutated_local_input_tokens
+
+ (tokens_per_expert, permutated_local_input_tokens), *_ = forward_func(alltoall_token_permutation1,
+ (hidden_states, indices,
+ hidden_states_ep, hidden_states_ep_handle))
+
+ # permute 1
+ save_tensors.append(permutated_local_input_tokens)
+
+ # Perform expert parallel AlltoAll communication
+ ep_group = parallel_state.get_expert_model_parallel_group()
+ if get_args().moe_tp_extend_ep:
+ ep_group = parallel_state.get_tensor_and_expert_parallel_group()
+
+ # Perform expert parallel AlltoAll communication
+ if self.cuda_sync_point == "before_ep_alltoall":
+ torch.cuda.current_stream().synchronize()
+ if moe_hierarchical_alltoallv:
+ tp_group = parallel_state.get_tensor_model_parallel_group()
+ _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
+ permutated_local_input_tokens,
+ self.output_splits,
+ self.input_splits,
+ tp_group,
+ )
+ else:
+ _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
+ permutated_local_input_tokens,
+ self.output_splits,
+ self.input_splits,
+ ep_group,
+ )
+
+ # shared experts
+ if shared_experts is not None:
+ (share_experts_output, _), *_ = forward_func(shared_experts, (hidden_states, moe_ctx))
+ if parallel_state.get_tensor_model_parallel_world_size() > 1 and shared_expert_gate is None:
+ share_experts_graph, share_experts_output, rs_shared_experts_handle = async_reduce_scatter(share_experts_output, parallel_state.get_tensor_model_parallel_group(),
+ event=permute1_ep_all_to_all_handle, stream=torch.npu.default_stream())
+ share_experts_output = (share_experts_graph, share_experts_output, rs_shared_experts_handle)
+ if shared_expert_gate is not None:
+ with torch.enable_grad():
+ # tp not support shared expert gate for now
+ if parallel_state.get_tensor_model_parallel_world_size() > 1:
+ share_experts_output = reduce_scatter_to_sequence_parallel_region(share_experts_output)
+ share_experts_output = torch.nn.functional.sigmoid(shared_expert_gate(hidden_states)) * share_experts_output
+ else:
+ share_experts_output = None
+
+ if permute1_ep_all_to_all_handle is not None:
+ permute1_ep_all_to_all_handle.wait()
+ permutated_local_input_tokens.untyped_storage().resize_(0)
+
+ def alltoall_token_permutation2(global_input_tokens):
+ # Permutation 2: AlltoAll output to expert input if num_local_experts > 1
+ if self.num_local_experts > 1:
+ if not self.drop_and_pad:
+ if self.comm_stream is not None:
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
+ global_input_tokens, self.reversed_global_input_permutation_mapping = permute(
+ global_input_tokens, self.global_input_tokens_local_experts_indices
+ )
+ else:
+ global_input_tokens = global_input_tokens.reshape(
+ self.ep_size, self.num_local_experts, self.capacity, -1
+ )
+ global_input_tokens = (
+ global_input_tokens.transpose(0, 1)
+ .reshape(self.num_local_experts * self.ep_size * self.capacity, -1)
+ .contiguous()
+ )
+ # Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens.
+ # global_input_tokens: [SEQL, H/TP] -> [SEQL, H]
+ need_tp_comm = (not get_args().moe_tp_extend_ep and
+ parallel_state.get_tensor_model_parallel_world_size() > 1 and
+ self.config.moe_grouped_gemm) and get_args().moe_experts_pipeline_degree == 0
+ if need_tp_comm:
+ global_input_tokens = tensor_parallel.all_gather_last_dim_from_tensor_parallel_region(
+ global_input_tokens
+ )
+ if self.cuda_sync_point == "before_finish":
+ torch.cuda.current_stream().synchronize()
+
+ return global_input_tokens
+
+ # token 重排2 input
+ (global_input_tokens), global_input_tokens_detach = forward_func(alltoall_token_permutation2,
+ global_input_tokens)
+ save_tensors.append(global_input_tokens_detach)
+ save_tensors.append(global_input_tokens)
+ global_input_tokens_detach.untyped_storage().resize_(0)
+
+ return share_experts_output, global_input_tokens, tokens_per_expert
+
+
+def alltoall_token_unpermutation_new(
+ self, hidden_states, bias, save_tensors
+):
+ moe_hierarchical_alltoallv = get_args().moe_hierarchical_alltoallv
+
+ def alltoall_token_unpermutation1(hidden_states):
+ assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
+
+ # Perform tensor parallel Reduce-Scatter
+ # hidden_states: [SEQL, H] -> [SEQL, H/TP]
+ if not get_args().moe_tp_extend_ep and parallel_state.get_tensor_model_parallel_world_size() > 1 and get_args().moe_experts_pipeline_degree == 0:
+ hidden_states = tensor_parallel.reduce_scatter_last_dim_to_tensor_parallel_region(hidden_states)
+
+ # Unpermutation 2: expert output to AlltoAll input
+ if self.num_local_experts > 1:
+ if not self.drop_and_pad:
+ hidden_states = unpermute(
+ hidden_states, self.reversed_global_input_permutation_mapping,
+ )
+ else:
+ hidden_states = hidden_states.reshape(
+ self.num_local_experts, self.ep_size, self.capacity, -1
+ )
+ hidden_states = (
+ hidden_states.transpose(0, 1)
+ .reshape(self.ep_size * self.num_local_experts * self.capacity, -1)
+ .contiguous()
+ )
+ return hidden_states
+ if get_args().moe_experts_pipeline_degree:
+ with torch.enable_grad():
+ hidden_states = alltoall_token_unpermutation1(hidden_states)
+ save_tensors.append(hidden_states)
+ else:
+ hidden_states, unpermute1_input_detach = forward_func(alltoall_token_unpermutation1, hidden_states)
+ save_tensors.append(unpermute1_input_detach)
+ save_tensors.append(hidden_states)
+ unpermute1_input_detach.untyped_storage().resize_(0)
+
+ ep_group = parallel_state.get_expert_model_parallel_group()
+ if get_args().moe_tp_extend_ep:
+ ep_group = parallel_state.get_tensor_and_expert_parallel_group()
+ # Perform expert parallel AlltoAll communication
+ # hidden_states: [SEQL, H] -> [SEQL, H/TP]
+ if moe_hierarchical_alltoallv:
+ tp_group = parallel_state.get_tensor_model_parallel_group()
+ _, permutated_local_input_tokens, handle = async_all_to_all(
+ hidden_states,
+ self.input_splits,
+ self.output_splits,
+ tp_group
+ )
+ else:
+ _, permutated_local_input_tokens, handle = async_all_to_all(
+ hidden_states,
+ self.input_splits,
+ self.output_splits,
+ ep_group
+ )
+ if handle is not None:
+ handle.wait()
+ hidden_states.untyped_storage().resize_(0)
+
+ def alltoall_token_unpermutation2(permutated_local_input_tokens):
+ # Unpermutation 1: AlltoAll output to output
+ if get_args().moe_zero_memory != "disable":
+ output = UnpermuteWithoutActivation.apply(
+ permutated_local_input_tokens,
+ self.reversed_local_input_permutation_mapping,
+ self.probs
+ )
+ else:
+ if moe_hierarchical_alltoallv:
+ unpermute_with_ep_input_tensors_list = [permutated_local_input_tokens,
+ self.reversed_local_input_permutation_mapping,
+ self.permuted_probs]
+ output = unpermute_with_ep(
+ unpermute_with_ep_input_tensors_list,
+ restore_shape=self.hidden_shape_before_permute,
+ probs=self.probs,
+ topk=self.router_topk
+ )
+ else:
+ output = unpermute(
+ permutated_local_input_tokens,
+ self.reversed_local_input_permutation_mapping,
+ probs=self.probs,
+ padded_mode=self.drop_and_pad,
+ restore_shape=self.hiddden_shape_before_permute,
+ )
+ if moe_hierarchical_alltoallv:
+ return output
+ # Perform tensor parallel AlltoAll communication
+ # output: [S*B, H/TP] -> [S*B/TP, H]
+ if not get_args().moe_tp_extend_ep and parallel_state.get_tensor_model_parallel_world_size() > 1:
+ output = tensor_parallel.all_to_all_hp2sp(output)
+
+ # Reshape the output tensor
+ output = output.view(self.hidden_shape)
+ return output
+
+ output, unpermute2_input_detach = forward_func(alltoall_token_unpermutation2, permutated_local_input_tokens)
+ save_tensors.append(unpermute2_input_detach)
+ should_resize = not self.drop_and_pad and not moe_hierarchical_alltoallv and \
+ not get_args().use_fused_moe_token_permute_and_unpermute or get_args().moe_zero_memory != "disable"
+ if should_resize:
+ unpermute2_input_detach.untyped_storage().resize_(0)
+ save_tensors.append(output)
+
+ if moe_hierarchical_alltoallv:
+ ep_group = parallel_state.get_expert_model_parallel_group()
+ _, output, output_handle = async_reduce_scatter(output, group=ep_group)
+ output_handle.wait()
+ output = output.view(self.hidden_shape)
+ return output, None
+
+
+def allgather_token_permutation_npu(self, hidden_states: torch.Tensor, max_prob: torch.Tensor, max_ind: torch.Tensor):
+ self.hidden_shape = hidden_states.shape
+ # [S/TP, B, H] -> [S*B/TP, H]
+ hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
+
+ # Permute the tokens across the expert parallel devices.
+ if (self.config.tensor_model_parallel_size > 1) or (
+ self.config.expert_model_parallel_size > 1
+ ):
+ with torch.no_grad():
+ global_indices = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
+ max_ind
+ )
+ # Create a mask of mapping between global and local tokens where each
+ # element is True if it's between the local_expert_indices
+ global_local_mask = (global_indices >= self.local_expert_indices[0]) & (
+ global_indices <= self.local_expert_indices[-1]
+ )
+ local_indices = global_indices.masked_select(global_local_mask)
+
+ if self.router_topk > 1: # k > 1
+ global_probs = tensor_parallel.gather_from_sequence_parallel_region_to_moe(max_prob)
+ self.local_probs = global_probs.masked_select(global_local_mask)
+ else:
+ self.local_probs = max_prob
+
+ # [S*B/TP, H] -> [S*B, H]
+ global_hidden_states = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
+ hidden_states, use_global_buffer=True
+ )
+ # Reshape global_local_mask to be compatible with Tensor.gather
+ global_local_map = global_local_mask.nonzero()[:, 0]
+ self.global_local_map = global_local_map.view(-1, 1).expand(-1, hidden_states.shape[-1])
+ local_hidden_states = moe_gather.apply(global_hidden_states, self.global_local_map)
+ else:
+ if self.router_topk > 1:
+ global_local_mask = torch.ones_like(max_ind).bool()
+ local_indices = max_ind.masked_select(global_local_mask)
+ self.local_probs = max_prob.masked_select(global_local_mask)
+ global_local_map = global_local_mask.nonzero()[:, 0]
+ self.global_local_map = global_local_map.view(-1, 1).expand(
+ -1, hidden_states.shape[-1]
+ )
+ local_hidden_states = torch.gather(hidden_states, 0, self.global_local_map)
+ else:
+ local_indices = max_ind
+ self.local_probs = max_prob
+ local_hidden_states = hidden_states
+ self.global_local_map = None
+
+ with torch.no_grad():
+ # The indices of local_indices that give its sorted order along dim 0.
+ self.indices = torch.argsort(local_indices, dim=0)
+ # use 0.7.0 implement for better performance
+ tokens_per_expert = torch.histc(
+ local_indices,
+ bins=self.num_local_experts,
+ min=self.local_expert_indices[0],
+ max=self.local_expert_indices[-1],
+ )
+ tokens_per_expert = tokens_per_expert.to(torch.long)
+
+ # Stage2: permute the tokens locally so that they are grouped by their expert assignment
+ # Reshape indices to be compatible with Tensor.gather
+ self.indices = self.indices.view(-1, 1).expand(-1, hidden_states.shape[-1])
+ if self.num_local_experts > 1:
+ permuted_local_hidden_states = moe_gather.apply(local_hidden_states, self.indices)
+ else:
+ permuted_local_hidden_states = local_hidden_states
+ return (
+ permuted_local_hidden_states,
+ tokens_per_expert,
+ )
+
+
+def alltoall_preprocess_npu(self, indices: torch.Tensor):
+ # use 0.7.0 implement for better performance
+ num_local_tokens_per_expert = torch.histc(
+ indices, bins=self.num_experts, min=0, max=self.num_experts
+ )
+ # num_local_tokens_per_expert: [num_experts]
+
+ ep_size = self.config.expert_model_parallel_size
+ if self.drop_and_pad:
+ # probs: [num_experts, capacity]
+ self.capacity = self.probs.size(1)
+ num_tokens_per_local_expert = torch.full(
+ (self.num_local_experts,), self.capacity * self.ep_size, dtype=torch.long,
+ device=torch.cuda.current_device()
+ )
+ return num_tokens_per_local_expert
+ elif self.config.moe_expert_capacity_factor is not None:
+ # Token drop but no pad.
+ self.num_out_tokens = num_local_tokens_per_expert.sum().to(
+ torch.device("cpu"), non_blocking=True
+ )
+ self.cuda_sync_point = "before_permutation_1"
+ elif ep_size > 1:
+ # Token dropless and enable ep.
+ self.cuda_sync_point = "before_ep_alltoall"
+ else:
+ # Token dropless and no ep.
+ self.cuda_sync_point = "before_finish"
+
+ if ep_size > 1:
+ # ===================================================
+ # Calculate input_splits, output_splits for alltoall-v.
+ # ===================================================
+ self.input_splits = (
+ num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
+ .sum(axis=1)
+ .to(torch.device("cpu"), non_blocking=True)
+ .numpy()
+ )
+ num_global_tokens_per_expert = _gather_along_first_dim_expert_parallel(
+ num_local_tokens_per_expert
+ ).reshape(ep_size, self.num_experts)
+ self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[
+ :, self.local_expert_indices[0]: self.local_expert_indices[-1] + 1
+ ]
+ self.output_splits = (
+ self.num_global_tokens_per_local_expert.sum(axis=-1).to(torch.device("cpu")).numpy()
+ )
+ num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(axis=0)
+ # ===================================================
+ # num_global_tokens_per_expert: [ep_size, num_experts]
+ # num_global_tokens_per_local_expert: [ep_size, num_local_experts]
+ # num_tokens_per_local_expert: [num_local_experts]
+ # ===================================================
+ else:
+ self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
+ -1, self.num_experts
+ )
+ num_tokens_per_local_expert = num_local_tokens_per_expert
+
+ if self.num_local_experts > 1:
+ # No further synchronization is needed because torch.repeat_interleave() calls stream
+ # synchronization internally when the `output_size` parameter is not provided.
+ self.cuda_sync_point = "no_sync"
+ self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
+ self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()
+ )
+
+ return num_tokens_per_local_expert
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/__init__.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..20e5a4690a85d90545f92d6f9ccdada665435bfd
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/grouped_mlp_2d.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/grouped_mlp_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..c704524a81396db8190194befa3cac27114c21ea
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/grouped_mlp_2d.py
@@ -0,0 +1,595 @@
+# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+from copy import deepcopy
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+from torch.cuda.amp import custom_bwd
+from torch.cuda.amp import custom_fwd
+
+from megatron.core import parallel_state
+from megatron.core.dist_checkpointing import ShardedTensor
+from megatron.core.dist_checkpointing.mapping import ReplicaId
+from megatron.core.dist_checkpointing.mapping import ShardedTensorFactory
+from megatron.core.jit import jit_fuser
+from megatron.core.tensor_parallel.layers import _initialize_affine_weight_gpu
+from megatron.core.transformer.module import MegatronModule
+from megatron.core.transformer.moe import grouped_gemm_util as gg
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.transformer.utils import make_sharded_object_for_checkpoint
+from megatron.core.utils import divide
+from megatron.training import get_args
+from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm
+from mindspeed.core.tensor_parallel.comm_group_api import TPYCollectiveComm
+from mindspeed.core.tensor_parallel.layers import _initialize_affine_weight_cpu_2d
+from mindspeed.core.tensor_parallel.comm_utils import _split_along_last_dim
+from mindspeed.core.tensor_parallel.comm_utils import sync_reduce_scatter_along_first_dim
+from mindspeed.core.tensor_parallel.comm_utils import sync_gather_along_first_dim
+from mindspeed.core.tensor_parallel.comm_utils import sync_gather_along_last_dim
+from mindspeed.core.fusions.fused_bias_swiglu import fused_swiglu
+from mindspeed.ops.gmm import GMMFunction
+
+
+G_FORWARD_PADDING_SIZE = 0
+G_BACKWARD_PADDING_SIZE = 0
+
+
+class GroupedMLP2D(MegatronModule):
+ """An efficient implementation of the Experts layer using CUTLASS GroupedGEMM.
+
+ This class is designed to execute multiple experts in parallel, thereby maximizing computational efficiency.
+ """
+
+ def __init__(self, num_local_experts: int, config: TransformerConfig):
+ super().__init__(config=config)
+ self.config: TransformerConfig = config
+ self.num_local_experts = num_local_experts
+ gg.assert_grouped_gemm_is_available()
+ assert (
+ config.add_bias_linear == False
+ ), "bias in the expert layer is not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."
+
+ self.expert_parallel = config.expert_model_parallel_size > 1
+ if self.config.gated_linear_unit:
+ if self.config.activation_func not in (F.silu, F.gelu):
+ raise ValueError("Activation function must be silu or gelu when using GroupedMLP.")
+
+ self.activation_func = fused_swiglu
+ else:
+ self.activation_func = self.config.activation_func
+
+ # How many feature each rank holds for fc1 and fc2, respectively.
+ self.moe_extended_tp = config.moe_extended_tp
+
+ self.config = config
+ self.num_local_experts = num_local_experts
+ gg.assert_grouped_gemm_is_available()
+ assert config.add_bias_linear is False, (
+ "bias in the expert layer is not supported in Grouped GEMM yet, "
+ "please set '--disable-bias-linear' instead."
+ )
+
+ self.init_paras()
+
+ def remove_extra_states_check(self, incompatible_keys):
+ """
+ Remove _extra_state from unexpected keys.
+ These keys are for dist ckpt compatibility with SequentialMLP.
+ """
+ keys = deepcopy(incompatible_keys.unexpected_keys)
+ for key in keys:
+ if "_extra_state" in key:
+ incompatible_keys.unexpected_keys.remove(key)
+
+ self.register_load_state_dict_post_hook(remove_extra_states_check)
+
+ def init_paras(self):
+ config = self.config
+ # How many feature each rank holds for fc1.
+ all_local_expert_fc1_output_size = self.config.ffn_hidden_size * self.num_local_experts
+ expert_fc1_output_size = self.config.ffn_hidden_size
+ if config.gated_linear_unit:
+ # Project to 4h. If using swiglu double the output width,
+ # see https://arxiv.org/pdf/2002.05202.pdf
+ all_local_expert_fc1_output_size *= 2
+ expert_fc1_output_size *= 2
+
+ tpx_comm_world_sz = TPXCollectiveComm.get_comm_group_world_size()
+ tpy_comm_world_sz = TPYCollectiveComm.get_comm_group_world_size()
+ assert self.config.hidden_size % tpy_comm_world_sz == 0, (
+ "fc1 input size should be " "divisible by tp-y"
+ )
+ assert (
+ all_local_expert_fc1_output_size % tpx_comm_world_sz == 0
+ ), "fc1 output size should be divisible by tp-x"
+ # h/y
+ # 2e*dff_h/x
+ all_local_experts_fc1_output_size_per_partition = divide(
+ all_local_expert_fc1_output_size, tpx_comm_world_sz
+ )
+ # How many feature each rank holds for fc2.
+ all_local_experts_fc2_input_size = self.config.ffn_hidden_size * self.num_local_experts
+ assert (
+ all_local_experts_fc2_input_size % tpx_comm_world_sz == 0
+ ), "all local expert fc2 output size should be divisible by tp-y"
+ assert self.config.hidden_size % tpy_comm_world_sz == 0, (
+ "fc2 input size should be " "divisible by tp-x"
+ )
+ # e*dff_h/x
+ all_local_experts_fc2_input_size_per_partition = divide(
+ all_local_experts_fc2_input_size, tpx_comm_world_sz
+ )
+ # h/y
+ # Note: The current kernel implementations of grouped_gemm
+ # does not support transposition with CUTLASS grouped GEMM
+ # (https://github.com/fanshiqing/grouped_gemm/blob/main/csrc/grouped_gemm.cu#L355-L358)
+ # and as a result we avoid allocate the transpose of weights.
+ # Initialize weight.
+ if config.use_cpu_initialization:
+ w1s = [] # e1: splited_w1, e2: splited_w1 ..
+ w2s = [] # e1: splited_w2, e2: splited_w2 ..
+ master_w1s = []
+ master_w2s = []
+ for idx in range(self.num_local_experts):
+ # [h/y, 2*dff_h/x]
+ w1 = Parameter(
+ torch.empty(
+ self.config.hidden_size // tpy_comm_world_sz,
+ expert_fc1_output_size // tpx_comm_world_sz,
+ dtype=config.params_dtype,
+ )
+ )
+
+ master_w1 = _initialize_affine_weight_cpu_2d(w1, 1, return_master_weight=True, config=self.config)
+ w1s.append(w1)
+ master_w1s.append(master_w1)
+ # [dff_h/x, h/y]
+ w2 = Parameter(
+ torch.empty(
+ self.config.ffn_hidden_size // tpx_comm_world_sz,
+ self.config.hidden_size // tpy_comm_world_sz,
+ dtype=config.params_dtype,
+ )
+ )
+ master_w2 = _initialize_affine_weight_cpu_2d(w2, 0, return_master_weight=True, config=self.config)
+ w2s.append(w2)
+ master_w2s.append(master_w2)
+
+ self.master_weight1 = Parameter(torch.cat(master_w1s, dim=-1).contiguous().npu())
+ self.master_weight2 = Parameter(torch.cat(master_w2s, dim=0).contiguous().npu())
+ # [h/y, e*2*dff_h/x]
+ self.weight1 = Parameter(torch.cat(w1s, dim=-1).contiguous().npu())
+ # [e*dff_h/x, h/y]
+ self.weight2 = Parameter(torch.cat(w2s, dim=0).contiguous().npu())
+ else:
+ # [h/y, 2e*dff_h/x]
+ self.weight1 = Parameter(
+ torch.empty(
+ divide(self.config.hidden_size, tpy_comm_world_sz),
+ all_local_experts_fc1_output_size_per_partition,
+ device=torch.cuda.current_device(),
+ dtype=config.params_dtype,
+ )
+ )
+ # [e*dff_h/x, h/y]
+ self.weight2 = Parameter(
+ torch.empty(
+ all_local_experts_fc2_input_size_per_partition,
+ divide(self.config.hidden_size, tpy_comm_world_sz),
+ device=torch.cuda.current_device(),
+ dtype=config.params_dtype,
+ )
+ )
+ if config.perform_initialization:
+ _initialize_affine_weight_gpu(
+ self.weight1,
+ config.init_method,
+ partition_dim=1,
+ expert_parallel=self.expert_parallel,
+ )
+ _initialize_affine_weight_gpu(
+ self.weight2,
+ config.output_layer_init_method,
+ partition_dim=0,
+ expert_parallel=self.expert_parallel,
+ )
+
+ setattr(self.weight1, "allreduce", not self.expert_parallel)
+ setattr(self.weight2, "allreduce", not self.expert_parallel)
+
+ def forward(self, permuted_local_hidden_states, tokens_per_expert):
+ grouped_mlp_paras = dict()
+ grouped_mlp_paras['tokens_per_expert'] = tokens_per_expert
+ grouped_mlp_paras['hidden_size'] = self.config.hidden_size
+ grouped_mlp_paras['num_local_experts'] = self.num_local_experts
+ grouped_mlp_paras['gemm_fusion'] = get_args().gemm_gradient_accumulation_fusion
+ grouped_mlp_paras['tp_y'] = get_args().tp_y
+
+ # [n, h] -> [n1/y, 2e*dff_h/x]
+ fc1_output = CustomGMM2DFC1.apply(permuted_local_hidden_states, self.weight1, grouped_mlp_paras)
+
+ # [n1/y, 2e*dff_h/x] -> [n1/y, e*dff_h/x]
+ intermediate_parallel = self.activation_func(fc1_output)
+
+ # [n1/y, e*dff_h/x] -> [n, h] partial-x
+ fc2_output = CustomGMM2DFC2.apply(intermediate_parallel, self.weight2, grouped_mlp_paras)
+
+ return fc2_output, None
+
+ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+ """Maps local expert to global experts."""
+ if self.moe_extended_tp:
+ raise NotImplementedError(
+ "Currently distributed checkpointing is not supported for moe_extended_tp"
+ )
+
+ sharded_state_dict = {}
+ num_global_experts = (
+ parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts
+ )
+ local_expert_indices_offset = (
+ parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
+ )
+ tp_size = TPXCollectiveComm.get_comm_group_world_size()
+ tp_rank = TPXCollectiveComm.get_comm_rank()
+
+ prepend_axis_num = len(sharded_offsets)
+ replica_id = (
+ 0,
+ 0,
+ parallel_state.get_data_modulo_expert_parallel_rank(with_context_parallel=True),
+ )
+
+ @torch.no_grad()
+ def sh_ten_build_fn(
+ key: str,
+ t: torch.Tensor,
+ replica_id: ReplicaId,
+ flattened_range: Optional[slice],
+ tp_axis: int,
+ with_glu: bool,
+ ):
+ if tp_axis == 0:
+ real_shape = (self.num_local_experts, self.config.hidden_size // get_args().tp_y, -1)
+ elif tp_axis == 1:
+ real_shape = (self.num_local_experts, -1, self.config.hidden_size // get_args().tp_y)
+ assert with_glu == False
+ else:
+ raise ValueError("tp_axis should be 0 or 1.")
+ if flattened_range is None:
+ t = t.view(real_shape).transpose(-1, -2)
+ if with_glu:
+ local_tensors = torch.chunk(t, 2, -2)
+ sub_states = [
+ ShardedTensor.from_rank_offsets(
+ key,
+ local_tensors[0].contiguous(),
+ *sharded_offsets,
+ (
+ prepend_axis_num,
+ parallel_state.get_expert_model_parallel_rank(),
+ parallel_state.get_expert_model_parallel_world_size(),
+ ),
+ (prepend_axis_num + 1, tp_rank, tp_size * 2),
+ replica_id=replica_id,
+ prepend_axis_num=prepend_axis_num,
+ ),
+ ShardedTensor.from_rank_offsets(
+ key,
+ local_tensors[1].contiguous(),
+ *sharded_offsets,
+ (
+ prepend_axis_num,
+ parallel_state.get_expert_model_parallel_rank(),
+ parallel_state.get_expert_model_parallel_world_size(),
+ ),
+ (prepend_axis_num + 1, tp_size + tp_rank, tp_size * 2),
+ replica_id=replica_id,
+ prepend_axis_num=prepend_axis_num,
+ ),
+ ]
+ else:
+ sub_states = ShardedTensor.from_rank_offsets(
+ key,
+ t.contiguous(),
+ *sharded_offsets,
+ (
+ prepend_axis_num,
+ parallel_state.get_expert_model_parallel_rank(),
+ parallel_state.get_expert_model_parallel_world_size(),
+ ),
+ (prepend_axis_num + 1 + tp_axis, tp_rank, tp_size),
+ replica_id=replica_id,
+ prepend_axis_num=prepend_axis_num,
+ )
+ else:
+ raise NotImplementedError(
+ "Currently GroupedMLP does not support distributed checkpointing "
+ "with the distributed optimizer."
+ )
+ return sub_states
+
+ @torch.no_grad()
+ def sh_ten_merge_fn(sub_state_dict, tp_axis: int, with_glu: bool):
+ if tp_axis == 0:
+ weight_shape = (self.config.hidden_size, -1)
+ elif tp_axis == 1:
+ weight_shape = (-1, self.config.hidden_size)
+ assert with_glu == False
+ else:
+ raise ValueError("tp_axis should be 0 or 1.")
+ if with_glu:
+ sub_state_dict = torch.cat(sub_state_dict, -2)
+ return sub_state_dict.transpose(-1, -2).reshape(weight_shape)
+
+ state_dict = self.state_dict(prefix="", keep_vars=True)
+ # To align with SequentialMLP, the weight tensors are transposed,
+ # and the tp_axis is also for the transposed tensors
+ for name, tensor in state_dict.items():
+ if name == "weight1":
+ tp_axis = 0
+ with_glu = self.config.gated_linear_unit
+ wkey = f"{prefix}experts.linear_fc1.weight"
+ else:
+ tp_axis = 1
+ with_glu = False
+ wkey = f"{prefix}experts.linear_fc2.weight"
+ sharded_state_dict[f"{prefix}{name}"] = ShardedTensorFactory(
+ wkey,
+ tensor,
+ partial(sh_ten_build_fn, tp_axis=tp_axis, with_glu=with_glu),
+ partial(sh_ten_merge_fn, tp_axis=tp_axis, with_glu=with_glu),
+ replica_id,
+ )
+
+ replica_id = (
+ 0,
+ parallel_state.get_tensor_model_parallel_rank(),
+ parallel_state.get_data_modulo_expert_parallel_rank(with_context_parallel=True),
+ )
+ # Add fake _extra_state to be compatible with SequentialMLP
+ for expert_local_idx in range(self.num_local_experts):
+ expert_global_idx = local_expert_indices_offset + expert_local_idx
+ expert_sharded_offsets = (
+ *sharded_offsets,
+ (len(sharded_offsets), expert_global_idx, num_global_experts),
+ )
+ for mod in ["linear_fc1", "linear_fc2"]:
+ sharded_state_dict[
+ f"{prefix}expert{expert_global_idx}.{mod}._extra_state"
+ ] = make_sharded_object_for_checkpoint(
+ None, f"{prefix}experts.{mod}._extra_state", expert_sharded_offsets, replica_id,
+ )
+
+ return sharded_state_dict
+
+
+class CustomGMM2DFC1(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, activation_input, weight, grouped_mlp_paras):
+ # activation_input: [n, h], weight: [h/y, 2e*dff_h/x]
+
+ ctx.grouped_mlp_paras = grouped_mlp_paras
+ ctx.weight = weight
+
+ num_local_experts = grouped_mlp_paras.get('num_local_experts')
+ hidden_size = grouped_mlp_paras.get('hidden_size')
+ tokens_per_expert = grouped_mlp_paras.get('tokens_per_expert')
+ gemm_fusion = grouped_mlp_paras.get('gemm_fusion')
+ tp_y = grouped_mlp_paras.get('tp_y')
+
+ # [n, h] -> [n, h/y]
+ activation_input = _split_along_last_dim(activation_input, TPYCollectiveComm)
+ ctx.save_for_backward(activation_input)
+
+ # [h/y, 2e*dff_h/x]-> [2e*dff_h/x, h/y]
+ w1 = weight.transpose(0, -1).contiguous()
+ # [2e*dff_h/x, h/y] -> [e, 2*dff_h/x, h/y]
+ w1 = w1.view(num_local_experts, -1, hidden_size // tp_y)
+ # [e, 2*dff_h/x, h/y] -> [e, h/y, 2*dff_h/x]
+ w1 = w1.transpose(1, -1).contiguous()
+
+ # [n, h/y] @ [e, h/y, 2*dff_h/x] -> [n, 2e*dff_h/x] partial-y
+ fc1_output = gg.ops.gmm(
+ activation_input,
+ w1,
+ tokens_per_expert,
+ trans_b=False,
+ gemm_fusion=gemm_fusion,
+ original_weight=weight
+ )
+
+ # padding for reduce scatter, [n, 2e*dff_h/x] partial-y -> [n1, 2e*dff_h/x] partial-y
+ global G_FORWARD_PADDING_SIZE
+ n_tokens, h = fc1_output.shape
+ rs_size = TPYCollectiveComm.get_comm_group_world_size()
+ remaining = n_tokens - n_tokens // rs_size * rs_size
+ G_FORWARD_PADDING_SIZE = rs_size - remaining if remaining else 0
+ if G_FORWARD_PADDING_SIZE != 0:
+ padding_tensor = torch.zeros(
+ G_FORWARD_PADDING_SIZE, h, dtype=fc1_output.dtype, device=fc1_output.device
+ )
+ fc1_output = torch.cat((fc1_output, padding_tensor), dim=0)
+
+ # [n1, 2e*dff_h/x] partial-y -> [n1/y, 2e*dff_h/x]
+ fc1_output = sync_reduce_scatter_along_first_dim(fc1_output, TPYCollectiveComm)
+
+ return fc1_output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ # grad_output shape: [n1/y, 2e*dff_h/x]
+
+ # activation_input shape: [n, h/y]
+ activation_input, = ctx.saved_tensors
+ grouped_mlp_paras = ctx.grouped_mlp_paras
+
+ # weight shape: [h/y, 2e*dff_h/x]
+ weight = ctx.weight
+
+ num_local_experts = grouped_mlp_paras.get('num_local_experts')
+ tokens_per_expert = grouped_mlp_paras.get('tokens_per_expert')
+ hidden_size = grouped_mlp_paras.get('hidden_size')
+ gemm_fusion = grouped_mlp_paras.get('gemm_fusion')
+ tp_y = grouped_mlp_paras.get('tp_y')
+
+ # weight shape: [h/y, 2e*dff_h/x] -> [2e*dff_h/x, h/y]
+ w1 = weight.t().contiguous()
+ # [2e*dff_h/x, h/y] -> [e, 2*dff_h/x, h/y]
+ w1 = w1.view(num_local_experts, -1, hidden_size // tp_y)
+
+ # [n1/y, 2e*dff_h/x] -> [n1, 2e*dff_h/x]
+ total_grad_output = sync_gather_along_first_dim(grad_output, TPYCollectiveComm)
+
+ # unpadding, [n1, 2e*dff_h/x] -> [n, 2e*dff_h/x]
+ global G_BACKWARD_PADDING_SIZE
+ if G_BACKWARD_PADDING_SIZE != 0:
+ real_input_num = total_grad_output.shape[0] - G_BACKWARD_PADDING_SIZE
+ total_grad_output = total_grad_output[:real_input_num, :]
+
+ # [n, 2e*dff_h/x] @ [e, 2*dff_h/x, h/y] = [n, h/y] partial-x
+ grad_gmm_output = gg.ops.gmm(
+ total_grad_output,
+ w1,
+ tokens_per_expert,
+ trans_b=False,
+ gemm_fusion=gemm_fusion,
+ )
+
+ group_list = torch.cumsum(tokens_per_expert, dim=0)
+ # [h/y, n] @ [n, 2e*dff_h/x] = [e, h/y, 2*dff_h/x]
+ grad_weight_output = GMMFunction.builder.load().npu_gmm(
+ [activation_input.t()],
+ [total_grad_output],
+ [],
+ group_list,
+ 2,
+ 0)[0]
+
+ # [e, h/y, 2*dff_h/x] -> [e, 2*dff_h/x, h/y]
+ grad_weight_output = grad_weight_output.transpose(1, -1).contiguous()
+
+ # [e, 2*dff_h/x, h/y] -> [2e*dff_h/x, h/y]
+ grad_weight_output = grad_weight_output.view(-1, grad_weight_output.shape[-1])
+ # [2e*dff_h/x, h/y] -> [h/y, 2e*dff_h/x]
+ grad_weight_output = grad_weight_output.transpose(0, 1).contiguous()
+
+ # [n, h/y] partial-x -> [n, h] partial-x
+ grad_gmm_output = sync_gather_along_last_dim(grad_gmm_output, TPYCollectiveComm)
+
+ return grad_gmm_output, grad_weight_output, None
+
+
+class CustomGMM2DFC2(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, activation_input, weight, grouped_mlp_paras):
+ # activation_input shape: [n1/y, e*dff_h/x], weight shape: [e*dff_h/x, h/y]
+
+ ctx.grouped_mlp_paras = grouped_mlp_paras
+ ctx.weight = weight
+
+ num_local_experts = grouped_mlp_paras.get('num_local_experts')
+ hidden_size = grouped_mlp_paras.get('hidden_size')
+ tokens_per_expert = grouped_mlp_paras.get('tokens_per_expert')
+ gemm_fusion = grouped_mlp_paras.get('gemm_fusion')
+ tp_y = grouped_mlp_paras.get('tp_y')
+
+ # [e*dff_h/x, h/y] -> [e, dff_h/x, h/y]
+ w2 = weight.view(num_local_experts, -1, hidden_size // tp_y)
+
+ # [n1/y, e*dff_h/x] -> [n1, e*dff_h/x]
+ total_input = sync_gather_along_first_dim(activation_input, TPYCollectiveComm)
+
+ # unpadding, [n1, e*dff_h/x] -> [n, e*dff_h/x]
+ global G_FORWARD_PADDING_SIZE
+ if G_FORWARD_PADDING_SIZE != 0:
+ real_input_num = total_input.shape[0] - G_FORWARD_PADDING_SIZE
+ total_input = total_input[:real_input_num, :]
+
+ ctx.save_for_backward(total_input)
+
+ # [n, e*dff_h/x] @ [e, dff_h/x, h/y] -> [n, h/y] partial-x
+ fc2_output = gg.ops.gmm(
+ total_input,
+ w2,
+ tokens_per_expert,
+ trans_b=False,
+ gemm_fusion=gemm_fusion,
+ original_weight=weight
+ )
+
+ # [n, h/y] partial-x -> [n, h] partial-x
+ fc2_output = sync_gather_along_last_dim(fc2_output, TPYCollectiveComm)
+
+ return fc2_output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ # grad_output shape: [n, h]
+
+ # activation_input shape: [n, e*dff_h/x]
+ activation_input, = ctx.saved_tensors
+ grouped_mlp_paras = ctx.grouped_mlp_paras
+
+ # weight 2 shape: [e*dff_h/x, h/y]
+ weight = ctx.weight
+
+ num_local_experts = grouped_mlp_paras.get('num_local_experts')
+ tokens_per_expert = grouped_mlp_paras.get('tokens_per_expert')
+ hidden_size = grouped_mlp_paras.get('hidden_size')
+ gemm_fusion = grouped_mlp_paras.get('gemm_fusion')
+ tp_y = grouped_mlp_paras.get('tp_y')
+
+ # weight shape: [e*dff_h/x, h/y] -> [e, dff_h/x, h/y]
+ w2 = weight.view(num_local_experts, -1, hidden_size // tp_y)
+ # [e, dff_h/x, h/y] -> [e, h/y, dff_h/x]
+ w2 = w2.transpose(1, -1).contiguous()
+
+ # [n, h] -> [n, h/y]
+ grad_output = _split_along_last_dim(grad_output, TPYCollectiveComm)
+
+ # [n, h/y] @ [e, h/y, dff_h/x] = [n, e*dff_h/x] partial-y
+ partial_grad_gmm_output = gg.ops.gmm(
+ grad_output,
+ w2,
+ tokens_per_expert,
+ trans_b=False,
+ gemm_fusion=gemm_fusion,
+ )
+
+ # padding for reduce scatter, [n, e*dff_h/x] -> [n1, e*dff_h/x]
+ global G_BACKWARD_PADDING_SIZE
+ n_tokens, h = partial_grad_gmm_output.shape
+ rs_size = TPYCollectiveComm.get_comm_group_world_size()
+ remaining = n_tokens - n_tokens // rs_size * rs_size
+ G_BACKWARD_PADDING_SIZE = rs_size - remaining if remaining else 0
+ if G_BACKWARD_PADDING_SIZE != 0:
+ padding_tensor = torch.zeros(G_BACKWARD_PADDING_SIZE, h, dtype=partial_grad_gmm_output.dtype,
+ device=partial_grad_gmm_output.device)
+ partial_grad_gmm_output = torch.cat((partial_grad_gmm_output, padding_tensor), dim=0)
+
+ # [n1, e*dff_h/x] -> [n1/y, e*dff_h/x]
+ grad_gmm_output = sync_reduce_scatter_along_first_dim(partial_grad_gmm_output, TPYCollectiveComm)
+
+ group_list = torch.cumsum(tokens_per_expert, dim=0)
+ # [e*dff_h/x, n] @ [n, h/y] = [e, dff_h/x, h/y]
+ grad_weight_output = GMMFunction.builder.load().npu_gmm(
+ [activation_input.t()],
+ [grad_output],
+ [],
+ group_list,
+ 2,
+ 0)[0]
+ # [e, dff_h/x, h/y] -> [e*dff_h/x, h/y]
+ grad_weight_output = grad_weight_output.view(-1, grad_weight_output.shape[-1])
+
+ return grad_gmm_output, grad_weight_output, None
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/moe_allgather_token_dispatcher_2d.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/moe_allgather_token_dispatcher_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..564a63dc28ffdb2b0696247e7460180d40ad10c8
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/moe_allgather_token_dispatcher_2d.py
@@ -0,0 +1,279 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+from typing import List
+
+import torch
+
+from megatron.core import parallel_state
+from megatron.core.transformer.moe.moe_utils import moe_gather
+from megatron.core.transformer.moe.moe_utils import moe_scatter
+from megatron.core.transformer.moe.token_dispatcher import MoETokenDispatcher
+from megatron.core.transformer.transformer_config import TransformerConfig
+from mindspeed.core.tensor_parallel.comm_autograd_function import \
+ auto_grad_reduce_scatter_along_first_dim
+from mindspeed.core.tensor_parallel.comm_autograd_function import \
+ auto_grad_sync_gather_along_first_dim_rs
+from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm
+from mindspeed.core.tensor_parallel.comm_group_api import TPXEPCollectiveComm
+from mindspeed.core.transformer.moe.token_dispatcher import NewIndePut
+from mindspeed.core.transformer.moe.token_dispatcher import cann_version_check
+
+
+class MoEAllGatherTokenDispatcher2D(MoETokenDispatcher):
+ """
+ AllGather Based Token dispatcher.
+ """
+
+ def __init__(
+ self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig,
+ ) -> None:
+ """
+ Initialize the zero token dropping router.
+ """
+ super().__init__(config=config)
+ self.num_local_experts = num_local_experts
+ self.num_experts = config.num_moe_experts
+ assert self.num_local_experts > 0, "Expected at least one expert"
+ self.local_expert_indices = local_expert_indices
+ assert len(self.local_expert_indices) > 0, "Expected at least one local expert index"
+ self.router_topk = config.moe_router_topk
+ self.add_bias = config.add_bias_linear
+
+ # self.local_probs: probs of global token assignment to local experts.
+ self.local_probs = None
+
+ # self.indices: The indices of `local_indices`
+ self.indices = None
+
+ # self.global_local_map: 2D tensor
+ self.global_local_map = None
+
+ def token_permutation(
+ self, hidden_states: torch.Tensor, topk_probs: torch.Tensor, topk_indices: torch.Tensor
+ ):
+ """Dispatch tokens to local experts. It's composed of two stages:
+ (1) Permute the tokens across the expert parallel devices. After this stage,
+ each device receives all the tokens assigned to its local set of experts
+ in its local HBM.
+ (2) Permute the tokens locally so that they are grouped by their expert
+ assignment.
+ After the stage (1), the tokens are grouped by which device
+ they came from. We re-order them locally for subsequent efficient computation.
+
+ Args:
+ hidden_states: input tokens of shape [s/(cp*x), b, h]
+ topk_probs: probs of local token assignment to global experts
+ with shape: [sb/(cp*x), topK]
+ topk_indices: token assignment to local experts with shape: [sb/(cp*x), topK]
+
+ Returns:
+ permuted_local_hidden_states: Permutation of tokens to local experts group.
+ tokens_per_expert: the number of tokens each local expert to process.
+ """
+
+ self.hidden_shape = hidden_states.shape
+ # [S/TP, B, H] -> [S*B/(cp*x), H]
+ hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
+
+ # Permute the tokens across the expert parallel devices.
+ if TPXCollectiveComm.get_comm_group_world_size() > 1 or self.config.expert_model_parallel_size > 1:
+ # [S*B/(cp*x), H] -> [S*B, H]
+ with torch.no_grad():
+ # [sb/x, topk] -> [sb*ep, topK]
+ global_indices = auto_grad_sync_gather_along_first_dim_rs(topk_indices, TPXEPCollectiveComm)
+
+ # [sb/x, topk] -> [sb*ep, topK]
+ global_probs = auto_grad_sync_gather_along_first_dim_rs(topk_probs, TPXEPCollectiveComm)
+ # [S/x, b, h] -> [sb*ep, h]
+ global_hidden_states = auto_grad_sync_gather_along_first_dim_rs(hidden_states, TPXEPCollectiveComm)
+
+ with torch.no_grad():
+ global_local_mask = (global_indices >= self.local_expert_indices[0]) & (
+ global_indices <= self.local_expert_indices[-1])
+ local_indices = global_indices.masked_select(global_local_mask)
+ self.indices = torch.argsort(local_indices.float(), dim=0)
+ num_global_experts = self.num_local_experts * parallel_state.get_expert_model_parallel_world_size()
+
+ all_tokens_per_expert = torch.histc(global_indices, bins=num_global_experts, min=0,
+ max=num_global_experts - 1, )
+ self.all_tokens_per_expert = all_tokens_per_expert.to(torch.long)
+ tokens_per_expert = self.all_tokens_per_expert[
+ self.local_expert_indices[0]: self.local_expert_indices[-1] + 1]
+ self.global_local_map = global_local_mask.nonzero()[:, 0]
+
+ if self.router_topk > 1:
+ self.local_probs = global_probs.masked_select(global_local_mask)
+ else:
+ self.local_probs = topk_probs
+
+ if cann_version_check:
+ local_hidden_states = global_hidden_states[self.global_local_map, :]
+ else:
+ self.global_local_map = (self.global_local_map.view(-1, 1).expand(-1, hidden_states.shape[-1]))
+ local_hidden_states = moe_gather.apply(global_hidden_states, self.global_local_map)
+ else:
+ if self.router_topk > 1:
+ global_local_mask = torch.ones_like(topk_indices).bool()
+ local_indices = topk_indices.masked_select(global_local_mask)
+ self.local_probs = topk_probs.masked_select(global_local_mask)
+ self.global_local_map = global_local_mask.nonzero()[:, 0]
+ if cann_version_check:
+ local_hidden_states = hidden_states[self.global_local_map, :]
+ else:
+ self.global_local_map = self.global_local_map.view(-1, 1).expand(-1, hidden_states.shape[-1])
+ local_hidden_states = torch.gather(hidden_states, 0, self.global_local_map)
+ else:
+ local_indices = topk_indices
+ self.local_probs = topk_probs
+ local_hidden_states = hidden_states
+ self.global_local_map = None
+
+ with torch.no_grad():
+ # The indices of local_indices that give its sorted order along dim 0.
+ self.indices = torch.argsort(local_indices, dim=0)
+ # use 0.7.0 implement for better performance
+ tokens_per_expert = torch.histc(local_indices, bins=self.num_local_experts,
+ min=self.local_expert_indices[0], max=self.local_expert_indices[-1], )
+ tokens_per_expert = tokens_per_expert.to(torch.long)
+ self.all_tokens_per_expert = tokens_per_expert
+
+ if self.num_local_experts > 1:
+ if cann_version_check:
+ permuted_local_hidden_states = local_hidden_states[self.indices, :]
+ else:
+ self.indices = self.indices.view(-1, 1).expand(-1, hidden_states.shape[-1])
+ permuted_local_hidden_states = moe_gather.apply(local_hidden_states, self.indices)
+ else:
+ permuted_local_hidden_states = local_hidden_states
+
+ return permuted_local_hidden_states, tokens_per_expert
+
+
+ def token_unpermutation(
+ self,
+ hidden_states: torch.Tensor,
+ bias: torch.Tensor = None,
+ ):
+ """
+ Reverse process of `dispatch()` which permutes the output of local
+ experts locally and across expert parallel rank into the original order to
+ produce the final output.
+
+ Args:
+ hidden_states: 2D tensor of shape [sum_tokens_of_all_local_experts, HiddenSize],
+ output of local experts.
+ bias (optional): The bias tensor.
+
+ Returns:
+ output_total: un-permuted updated hidden states output from all local experts
+ with shape of [SeqLen/TP, MBS, HiddenSize]
+ """
+ # Stage1: unpermute the tokens and bias locally respectively.
+ scores = self.local_probs.to(dtype=hidden_states.dtype)
+ if self.num_local_experts > 1:
+ if cann_version_check:
+ unpermuted_local_hidden = torch.zeros_like(hidden_states)
+ unpermuted_local_hidden.index_put_((self.indices,),
+ hidden_states[:self.indices.shape[0], :],
+ accumulate=False)
+ else:
+ assert self.indices.shape == hidden_states.shape
+ unpermuted_local_hidden = moe_scatter.apply(hidden_states, self.indices)
+ else:
+ unpermuted_local_hidden = hidden_states
+
+ # Scale the expert output prior to reduction and subsequent to local unpermutation if k > 1.
+ if self.router_topk > 1:
+ unpermuted_local_hidden = unpermuted_local_hidden * scores.view(-1, 1)
+
+ unpermuted_local_bias = None
+ if self.add_bias:
+ assert bias is not None
+ unpermuted_local_bias = torch.zeros_like(hidden_states)
+ if cann_version_check:
+ unpermuted_local_bias.index_put_((self.indices,), bias[:self.indices.shape[0], :],
+ accumulate=False)
+ else:
+ assert self.indices.shape == bias.shape
+ unpermuted_local_bias = unpermuted_local_bias.scatter(0, self.indices, bias)
+ if self.router_topk > 1:
+ unpermuted_local_bias = unpermuted_local_bias * scores.view(-1, 1)
+
+ output_total = unpermuted_local_hidden
+ output_bias_total = unpermuted_local_bias
+
+ # Unpermute the tokens across expert parallel devices.
+ if TPXCollectiveComm.get_comm_group_world_size() > 1 or self.config.expert_model_parallel_size > 1:
+ assert (self.global_local_map is not None), \
+ "global_local_map is necessary for `AllGather`."
+ ep_group_size = TPXEPCollectiveComm.get_comm_group_world_size()
+ # hidden_shape: [SeqLen/TP, MBS, HiddenSize], glboal_num_tokens = SeqLen/TP*MBS*(TP*EP)
+ global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1] * ep_group_size
+ global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
+ if cann_version_check:
+ unpermuted_global_hidden = torch.zeros(global_hidden_shape, dtype=torch.float,
+ device=torch.cuda.current_device())
+ unpermuted_global_hidden = NewIndePut.apply(unpermuted_global_hidden,
+ (self.global_local_map,),
+ unpermuted_local_hidden[
+ :self.global_local_map.shape[0], :])
+ else:
+ assert self.global_local_map.shape == unpermuted_local_hidden.shape
+ unpermuted_global_hidden = moe_scatter.apply(unpermuted_local_hidden,
+ self.global_local_map, global_hidden_shape)
+
+ output_total = auto_grad_reduce_scatter_along_first_dim(unpermuted_global_hidden, TPXEPCollectiveComm)
+ if self.add_bias:
+ # Unpermute the bias across expert parallel devices.
+ unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
+ if cann_version_check:
+ unpermuted_global_bias.index_put_((self.global_local_map,),
+ unpermuted_local_bias[
+ :self.global_local_map.shape[0], :],
+ accumulate=True)
+ else:
+ unpermuted_global_bias = unpermuted_global_bias.scatter_add(0,
+ self.global_local_map, unpermuted_local_bias)
+
+ output_bias_total = auto_grad_reduce_scatter_along_first_dim(unpermuted_global_bias,
+ TPXEPCollectiveComm)
+ # bias is duplicated across tensor parallelism ranks;
+ # reduce scatter reduces bias across tensor parallel_ranks
+ output_bias_total = (output_bias_total /
+ TPXCollectiveComm.get_comm_group_world_size())
+ else:
+ if self.router_topk > 1:
+ global_num_tokens = self.hidden_shape[0] * self.hidden_shape[1]
+ global_hidden_shape = [global_num_tokens, hidden_states.shape[-1]]
+ unpermuted_global_hidden = torch.zeros(global_hidden_shape,
+ dtype=hidden_states.dtype, device=torch.cuda.current_device(), )
+ if cann_version_check:
+ output_total = unpermuted_global_hidden.index_put((self.global_local_map,),
+ unpermuted_local_hidden[
+ :self.global_local_map.shape[
+ 0], :], accumulate=True)
+ else:
+ output_total = unpermuted_global_hidden.scatter_add(0, self.global_local_map,
+ unpermuted_local_hidden)
+ if self.add_bias:
+ unpermuted_global_bias = torch.zeros_like(unpermuted_global_hidden)
+ if cann_version_check:
+ output_bias_total = unpermuted_global_bias.index_put(
+ (self.global_local_map,),
+ unpermuted_local_bias[:self.global_local_map.shape[0], :],
+ accumulate=True)
+ else:
+ output_bias_total = unpermuted_global_bias.scatter_add(0,
+ self.global_local_map, unpermuted_local_bias)
+
+ if self.router_topk == 1:
+ output_total = output_total * scores
+ output_total = output_total.view(self.hidden_shape)
+ if self.add_bias:
+ assert output_bias_total is not None
+ if self.router_topk == 1:
+ output_bias_total = output_bias_total * scores
+ output_bias_total = output_bias_total.view(self.hidden_shape)
+ else:
+ output_bias_total = None
+
+ return output_total, output_bias_total
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/moe_layer_2d.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/moe_layer_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..68bed32233574f4fa1f5bb7d0c26ff7ea5080806
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/moe_layer_2d.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2023; NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+
+from megatron.core.transformer.mlp import MLPSubmodules
+from megatron.core.transformer.moe.moe_layer import BaseMoELayer
+from megatron.core.transformer.transformer_config import TransformerConfig
+from mindspeed.core.tensor_parallel.comm_autograd_function import auto_grad_scatter_along_last_dim
+from mindspeed.core.tensor_parallel.comm_autograd_function import \
+ auto_grad_sync_gather_along_last_dim
+from mindspeed.core.tensor_parallel.comm_group_api import TPYCollectiveComm
+from mindspeed.core.transformer.moe.tp_2d.grouped_mlp_2d import GroupedMLP2D
+from mindspeed.core.transformer.moe.tp_2d.moe_allgather_token_dispatcher_2d import \
+ MoEAllGatherTokenDispatcher2D
+from mindspeed.core.transformer.moe.tp_2d.sequential_mlp_2d import SequentialMLP2D
+from mindspeed.core.transformer.moe.tp_2d.topk_router_2d import TopKRouter2D
+
+
+class MoELayer2D(BaseMoELayer):
+ """2D Mixture of experts Layer **currently only supports allgather gmm**.
+
+ """
+
+ def __init__(
+ self, config: TransformerConfig, submodules: MLPSubmodules = None, layer_number: int = None
+ ):
+ super(MoELayer2D, self).__init__(config=config, layer_number=layer_number)
+ self.submodules = submodules
+ self.router = TopKRouter2D(config=self.config)
+ if self.config.moe_grouped_gemm:
+ self.experts = GroupedMLP2D(self.num_local_experts, self.config)
+ else:
+ assert isinstance(self.submodules, MLPSubmodules)
+ self.experts = SequentialMLP2D(self.num_local_experts, self.config, self.submodules)
+ if config.moe_token_dispatcher_type == "allgather":
+ self.token_dispatcher = MoEAllGatherTokenDispatcher2D(
+ self.num_local_experts, self.local_expert_indices, config=self.config
+ )
+ else:
+ raise ValueError(
+ f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
+ )
+ self.moe_layer_recompute = config.moe_layer_recompute
+
+ def forward(self, hidden_states: torch.Tensor):
+ # [s/x, b, h/y] -> [s/x, b, h]
+ hidden_states = auto_grad_sync_gather_along_last_dim(hidden_states, TPYCollectiveComm)
+
+ # [sb/x, h] => [sb/x, topK], [sb/x, topK]
+ topk_probs, topk_indices = self.router(hidden_states)
+
+ (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(hidden_states, topk_probs,
+ topk_indices)
+ expert_output, bias = self.experts(dispatched_input, tokens_per_expert)
+ output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, bias)
+
+ # [s/x, b, h] -> [s/x, b, h/y]
+ output = auto_grad_scatter_along_last_dim(output, TPYCollectiveComm)
+ if mlp_bias:
+ mlp_bias = auto_grad_scatter_along_last_dim(mlp_bias, TPYCollectiveComm)
+
+ return output, mlp_bias
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/sequential_mlp_2d.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/sequential_mlp_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..b76d1fa64d0e488f4c9945051de8066f25557062
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/sequential_mlp_2d.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+from megatron.core import parallel_state
+from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
+from megatron.core.transformer.module import MegatronModule
+from megatron.core.transformer.mlp import MLP, MLPSubmodules
+from megatron.core.transformer.transformer_config import TransformerConfig
+
+
+class SequentialMLP2D(MegatronModule):
+ """An implementation of the Experts layer using a sequence of MLP layers.
+ This class executes each expert sequentially.
+ """
+
+ def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLPSubmodules):
+ super().__init__(config=config)
+ self.add_bias = config.add_bias_linear
+ self.moe_extended_tp = config.moe_extended_tp
+ self.num_local_experts = num_local_experts
+ self.local_experts = torch.nn.ModuleList()
+ for _ in range(self.num_local_experts):
+ expert = MLP(self.config, submodules, is_expert=True)
+ self.local_experts.append(expert)
+
+ def forward(self, permuted_local_hidden_states, tokens_per_expert):
+
+ output_local = torch.zeros_like(permuted_local_hidden_states)
+ output_bias_local = None
+ if self.add_bias:
+ output_bias_local = torch.zeros_like(permuted_local_hidden_states)
+
+ cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
+ # Insert zero at the begining for offset index's convenience
+ zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
+ cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
+ for expert_num, expert in enumerate(self.local_experts):
+ start = cumsum_num_tokens[expert_num]
+ end = cumsum_num_tokens[expert_num + 1]
+ hidden = permuted_local_hidden_states[start:end]
+ output, output_bias = expert(hidden)
+
+ output_local[start:end] = output
+ if self.add_bias:
+ output_bias = output_bias.expand_as(output)
+ output_bias_local[start:end, :] = output_bias
+ return output_local, output_bias_local
+
+ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
+ """Maps local expert to global experts."""
+ if self.moe_extended_tp:
+ raise NotImplementedError(
+ 'Currently distributed checkpointing is not supported for moe_extended_tp'
+ )
+
+ sharded_state_dict = {}
+ num_global_experts = (
+ parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts
+ )
+ local_expert_indices_offset = (
+ parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
+ )
+
+ expert_sharded_prefix = f'{prefix}experts.'
+ for expert_local_idx, expert in enumerate(self.local_experts):
+ expert_global_idx = local_expert_indices_offset + expert_local_idx
+ expert_state_dict_prefix = f'{prefix}local_experts.{expert_local_idx}.'
+ expert_sharded_offsets = (
+ *sharded_offsets,
+ (len(sharded_offsets), expert_global_idx, num_global_experts),
+ )
+
+ expert_state_dict = expert.sharded_state_dict(
+ expert_state_dict_prefix, expert_sharded_offsets, metadata
+ )
+ # Remove expert layers indexing from sharded keys
+ replace_prefix_for_sharding(
+ expert_state_dict, expert_state_dict_prefix, expert_sharded_prefix
+ )
+ # Adjust replica ids - replication along DP modulo EP
+ for k, sh_ten in expert_state_dict.items():
+ replica_id = sh_ten.replica_id
+ assert (
+ len(replica_id) == 3
+ ), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
+ sh_ten.replica_id = (
+ *replica_id[:2],
+ parallel_state.get_data_modulo_expert_parallel_rank(with_context_parallel=True),
+ )
+
+ sharded_state_dict.update(expert_state_dict)
+ return sharded_state_dict
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/topk_router_2d.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/topk_router_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dea9493326c2458cbf3da699b2fddeb59c34c7f
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/tp_2d/topk_router_2d.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2023; NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+
+from megatron.core.transformer import TransformerConfig
+from megatron.core.transformer.moe.moe_utils import save_to_aux_losses_tracker
+from megatron.core.transformer.moe.moe_utils import switch_load_balancing_loss_func
+from megatron.core.transformer.moe.moe_utils import z_loss_func
+from megatron.core.transformer.moe.router import TopKRouter
+from mindspeed.core.tensor_parallel_x_union_cp import TensorParallelXUnionCP
+from mindspeed.core.tensor_parallel_y_union_cp import TensorParallelYUnionCP
+from mindspeed.moe.utils import MoEAuxLossAutoScaler
+
+
+class TopKRouter2D(TopKRouter):
+ """Route each token to the top-k experts for 2-D tensor parallel."""
+
+ def __init__(self, config: TransformerConfig) -> None:
+ super().__init__(config)
+ setattr(self.weight, 'sequence_parallel', False)
+ setattr(self.weight, "2d_tp", True)
+
+ def apply_load_balancing_loss(self, probs: torch.Tensor,
+ num_local_tokens_per_expert: torch.Tensor,
+ activation: torch.Tensor):
+ """Applies auxiliary loss to the MoE layer.
+
+ Args:
+ probs (torch.Tensor): The probs output by the router for each token, with shape:
+ [sb / (x * cp), E]
+ num_local_tokens_per_expert (torch.Tensor): The number of tokens per expert
+ with shape: [E]
+ activation, ie top_k_probs (torch.Tensor): The activation tensor to attach the gradient function to.
+ with shape: [sb/(x*cp), topK]
+
+ Returns:
+ torch.Tensor: The activation tensor with the attached gradient function with
+ shape: [sb/(x*cp), topK]
+ """
+ moe_aux_loss_coeff = self.config.moe_aux_loss_coeff
+ sequence_partition_group = None
+ if self.config.moe_token_dispatcher_type == "allgather":
+ tp_x_cp_group = TensorParallelXUnionCP()
+ sequence_partition_group = tp_x_cp_group.group
+ elif self.config.moe_token_dispatcher_type == "alltoall":
+ tp_y_cp_group = TensorParallelYUnionCP()
+ sequence_partition_group = tp_y_cp_group.group
+ moe_aux_loss_coeff /= tp_y_cp_group.get_parallel_group_world_size()
+
+ aux_loss = switch_load_balancing_loss_func(probs, num_local_tokens_per_expert, self.topk,
+ moe_aux_loss_coeff, sequence_partition_group=sequence_partition_group)
+ save_to_aux_losses_tracker("load_balancing_loss", aux_loss / moe_aux_loss_coeff,
+ self.layer_number, self.config.num_layers, reduce_group=sequence_partition_group)
+
+ activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
+ return activation
+
+ def apply_z_loss(self, logits):
+ """Encourages the router's logits to remain small to enhance stability.
+ Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
+
+ Args:
+ logits (torch.Tensor): The logits of the router.
+
+ Returns:
+ torch.Tensor: The logits after applying the z-loss.
+ """
+ if self.config.moe_z_loss_coeff is not None and self.training:
+ moe_z_loss_coeff = (
+ self.config.moe_z_loss_coeff
+ / TensorParallelXUnionCP().get_parallel_group_world_size()
+ )
+ z_loss = z_loss_func(logits, moe_z_loss_coeff)
+ logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
+ save_to_aux_losses_tracker(
+ "z_loss",
+ z_loss / moe_z_loss_coeff,
+ self.layer_number,
+ self.config.num_layers,
+ )
+ return logits
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/moe/unpermute_without_activation.py b/model/train/yoco_moe/mindspeed/core/transformer/moe/unpermute_without_activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..9826ae279bb898fd7bad680920bfb2b79932a1af
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/moe/unpermute_without_activation.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+import torch
+import torch.distributed
+import torch.distributed as dist
+import torch_npu
+from megatron.training import get_args
+from mindspeed.core.transformer.moe.moe_utils import (set_swap_status, get_swap_status,
+ set_prob_backward_need_tensors, get_swap_stream)
+
+
+class UnpermuteWithoutActivation(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx,
+ permuted_tokens: torch.Tensor,
+ sorted_indices: torch.Tensor,
+ probs: torch.Tensor = None,
+ padded_mode: bool = False,
+ restore_shape: torch.Size = None,
+ ):
+ """Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their corresponding probabilities.
+
+ Args:
+ permuted_tokens (torch.Tensor): The tensor of permuted tokens to be unpermuted.
+ sorted_indices (torch.Tensor): The tensor of sorted indices used to unpermute the tokens.
+ probs (torch.Tensor, optional): The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities.
+ padded_mode (bool, optional): If True, indicating the indices are padded to [num_expert, capacity] to denote selected tokens per expert. Defaults to False.
+ restore_shape (torch.Size, optional): The input shape before permutation, only used in padding mode. Defaults to None.
+
+ Returns:
+ torch.Tensor: The unpermuted tokens, optionally merged with probabilities.
+ """
+ moe_hierarchical_alltoallv = get_args().moe_hierarchical_alltoallv
+ if padded_mode:
+ raise ValueError("moe-zero-memory temporally does not support padded mode")
+
+ if sorted_indices.numel() != permuted_tokens.size(0):
+ raise AssertionError("")
+ saved_tensors = [sorted_indices]
+
+ with torch.no_grad():
+ if probs is not None:
+ # Unpermute and merge the tokens with their probabilities
+ num_unpermuted_tokens = probs.numel()
+ saved_tensors.append(probs)
+ ctx.topk = probs.size(1)
+ ctx.probs_shape = probs.shape
+ ctx.probs_dtype = probs.dtype
+ else:
+ # Unpermute the tokens without merge
+ num_unpermuted_tokens = permuted_tokens.size(0)
+ ctx.topk = 1
+ ctx.save_for_backward(*saved_tensors)
+ if moe_hierarchical_alltoallv:
+ unpermuted_tokens = torch.zeros(
+ [ctx.topk * probs.shape[0], permuted_tokens.shape[-1]],
+ dtype=permuted_tokens.dtype,
+ device=permuted_tokens.device,
+ )
+ unpermuted_tokens = \
+ unpermuted_tokens.scatter(0, sorted_indices.unsqueeze(1).expand(-1, permuted_tokens.shape[1]),
+ permuted_tokens)
+ elif not get_args().use_fused_moe_token_permute_and_unpermute:
+ unpermuted_tokens = torch.zeros(
+ [num_unpermuted_tokens, permuted_tokens.shape[-1]],
+ dtype=permuted_tokens.dtype,
+ device=permuted_tokens.device,
+ )
+ unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)
+ else:
+ unpermuted_tokens = permuted_tokens.index_select(0, sorted_indices)
+
+ ctx.permuted_tokens_shape = permuted_tokens.shape
+ ctx.unpermuted_tokens_shape = unpermuted_tokens.shape
+ unpermuted_tokens = unpermuted_tokens.reshape(-1, ctx.topk, permuted_tokens.size(-1))
+ permuted_tokens.untyped_storage().resize_(0)
+
+ if probs is not None:
+ tensor_to_swap = unpermuted_tokens
+ unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
+ swap_stream, last_tensor = get_swap_status()
+ if last_tensor is not None:
+ torch.npu.current_stream().wait_stream(swap_stream)
+ last_tensor.untyped_storage().resize_(0)
+ forward_event = torch.npu.Event()
+ forward_event.record()
+ set_swap_status(tensor_to_swap)
+ ctx.tensor_cpu = torch.empty(tensor_to_swap.shape, dtype=tensor_to_swap.dtype, pin_memory=True, device='cpu')
+ with torch_npu.npu.stream(swap_stream):
+ swap_stream.wait_event(forward_event)
+ ctx.tensor_cpu.untyped_storage().copy_(tensor_to_swap.untyped_storage(), non_blocking=True)
+ ctx.swap_event = torch.npu.Event()
+ ctx.swap_event.record()
+
+ ctx.matmul_output_shape = unpermuted_tokens.shape
+ unpermuted_tokens = unpermuted_tokens.sum(dim=1)
+
+ return unpermuted_tokens
+
+ @staticmethod
+ def backward(ctx, *args):
+ moe_hierarchical_alltoallv = get_args().moe_hierarchical_alltoallv
+ if ctx.topk > 1:
+ (indices, probs) = ctx.saved_tensors
+ else:
+ (indices,) = ctx.saved_tensors
+ ctx.save_for_backward()
+
+ if ctx.topk > 1:
+ matmul_output_grad = args[0].unsqueeze(dim=1).expand(ctx.matmul_output_shape)
+ backward_event1 = torch.npu.Event()
+ backward_event1.record()
+ swap_stream = get_swap_stream()
+ unpermuted_tokens = torch.empty(ctx.tensor_cpu.shape, dtype=ctx.tensor_cpu.dtype, device=torch.npu.current_device())
+ with torch_npu.npu.stream(swap_stream):
+ swap_stream.wait_event(backward_event1)
+ swap_stream.wait_event(ctx.swap_event)
+ unpermuted_tokens.untyped_storage().copy_(ctx.tensor_cpu.untyped_storage(), non_blocking=True)
+ ctx.tensor_cpu = None
+ unpermuted_tokens_grad = matmul_output_grad * probs.unsqueeze(-1)
+ set_prob_backward_need_tensors(matmul_output_grad, unpermuted_tokens)
+ else:
+ unpermuted_tokens_grad = args[0]
+
+ unpermuted_tokens_grad = unpermuted_tokens_grad.reshape(ctx.unpermuted_tokens_shape)
+
+ if not get_args().use_fused_moe_token_permute_and_unpermute or moe_hierarchical_alltoallv:
+ permuted_tokens_grad = unpermuted_tokens_grad.index_select(0, indices)
+ else:
+ permuted_tokens_grad = torch.zeros(size=ctx.permuted_tokens_shape, dtype=unpermuted_tokens_grad.dtype,
+ device=torch.npu.current_device())
+ permuted_tokens_grad.index_add_(0, indices, unpermuted_tokens_grad)
+
+ return permuted_tokens_grad, None, None, None, None
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/transformer.py b/model/train/yoco_moe/mindspeed/core/transformer/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c35573936e684268fa791e9a3895f90359a751b
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/transformer.py
@@ -0,0 +1,289 @@
+import enum
+import os
+from functools import wraps
+
+from contextlib import nullcontext
+import torch
+import torch_npu
+import torch.nn.functional as F
+
+from megatron import core
+from megatron.training import get_args
+from megatron.core.num_microbatches_calculator import get_num_microbatches
+from megatron.core import tensor_parallel, parallel_state, mpu
+from megatron.core.utils import make_viewless_tensor
+from megatron.legacy.model.transformer import bias_dropout_add_fused_train, get_bias_dropout_add, bias_dropout_add_fused_inference
+from megatron.legacy.model.enums import AttnMaskType, LayerType, AttnType
+from mindspeed.model.transformer import should_recompute_activation
+from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput
+from mindspeed.core.fusions.fused_bias_swiglu import fused_swiglu
+from mindspeed.core.transformer.moe.moe_utils import only_recompute_activation
+
+
+def parallel_transformer_layer_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ from megatron.core.transformer.moe.moe_layer import MoELayer
+ from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP
+ fn(self, *args, **kwargs)
+ if self.mlp.__class__ is MoELayer:
+ if self.mlp.experts.__class__ is GroupedMLP:
+ self.mlp.experts.layer_number = self.layer_number
+ if self.mlp.experts.__class__ is SequentialMLP:
+ for expert in self.mlp.experts.local_experts:
+ expert.layer_number = self.layer_number
+ global_args = get_args()
+ if global_args.n_shared_experts:
+ self.mlp.shared_experts.layer_number = self.layer_number
+ else:
+ self.mlp.layer_number = self.layer_number
+
+ return wrapper
+
+
+def parallel_transformer_checkpointed_forward_wrapper(forward_func):
+ @wraps(forward_func)
+ def row_parallel_forward(*args, **kwargs):
+ global_args = get_args()
+ if global_args.recompute_method != 'block' and not global_args.swap_attention:
+ output = forward_func(*args, **kwargs)
+ else:
+ output = parallel_transformer_checkpointed_forward(*args, **kwargs)
+ return output
+
+ return row_parallel_forward
+
+
+def parallel_transformer_checkpointed_forward(self, hidden_states, attention_mask,
+ encoder_output, enc_dec_attn_mask,
+ rotary_pos_emb, is_first_microbatch):
+ """Forward method with activation checkpointing."""
+
+ def custom(start, end):
+ def custom_forward(*args, **kwargs):
+ x_, *args = args
+ for index in range(start, end):
+ layer = self._get_layer(index)
+ x_ = layer(x_, *args, **kwargs)
+ return x_
+
+ return custom_forward
+
+ global_args = get_args()
+ num_layers_per_pipeline_rank = global_args.num_layers // global_args.pipeline_model_parallel_size
+ if self.recompute_method == 'uniform':
+ # Uniformly divide the total number of Transformer layers and
+ # checkpoint the input activation of each divided chunk.
+ # A method to further reduce memory usage reducing checkpoints.
+ if not global_args.swap_attention:
+ l = 0
+ while l < num_layers_per_pipeline_rank:
+ hidden_states = tensor_parallel.checkpoint(
+ custom(l, l + self.recompute_num_layers),
+ self.distribute_saved_activations,
+ hidden_states, attention_mask,
+ encoder_output, enc_dec_attn_mask,
+ None, None, None, None, rotary_pos_emb)
+
+ l += self.recompute_num_layers
+ else:
+ for l in range(num_layers_per_pipeline_rank):
+ hidden_states = custom(l, l + 1)(
+ hidden_states, attention_mask,
+ encoder_output, enc_dec_attn_mask,
+ None, None, None, None, rotary_pos_emb)
+ elif self.recompute_method == 'block':
+ # Checkpoint the input activation of only a set number of individual
+ # Transformer layers and skip the rest.
+ # A method fully use the device memory removing redundant re-computation.
+ vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()
+ vpp_size = global_args.virtual_pipeline_model_parallel_size
+ if vpp_rank is None or not global_args.enable_recompute_layers_per_pp_rank:
+ vpp_rank = 0
+ if vpp_size is None or not global_args.enable_recompute_layers_per_pp_rank:
+ vpp_size = 1
+ for l in range(self.num_layers):
+ # The number of layers each pipeline rank recomputes is self.recompute_num_layers.
+ # If self.recompute_num_layers cannot divide exactly the number of layers in each pp rank,
+ # we try to balance the number of recomputed layers in each model chunk.
+ # e.g. with 8 layers, 2 stages, and 2 virtual stages, the assignment of
+ # layers to stages like (each list is a model chunk):
+ # Stage 0: [0, 1] [4, 5]
+ # Stage 1: [2, 3] [6, 7]
+ # With self.recompute_num_layers = 2, we will recompute layers 0,4 for stage 0, and 2,6 for stage 1.
+ # With self.recompute_num_layers = 3, we will recompute layers 0,1,4 for stage 0, and 2,3,6 for stage 1.
+ def should_recompute():
+ if global_args.reduce_recompute_for_last_chunk:
+ def is_last_layer():
+ return (l == self.num_layers - 1) and mpu.is_pipeline_last_stage()
+
+ return ((l * vpp_size + vpp_rank) < self.recompute_num_layers) and not is_last_layer()
+ else:
+ return (l * vpp_size + vpp_rank) < self.recompute_num_layers
+
+ if should_recompute() and not global_args.swap_attention:
+ hidden_states = tensor_parallel.checkpoint(
+ custom(l, l + 1),
+ self.distribute_saved_activations,
+ hidden_states, attention_mask,
+ encoder_output, enc_dec_attn_mask,
+ None, None, None, None, rotary_pos_emb)
+ else:
+ hidden_states = custom(l, l + 1)(
+ hidden_states, attention_mask,
+ encoder_output, enc_dec_attn_mask,
+ None, None, None, None, rotary_pos_emb)
+ else:
+ raise ValueError("Invalid activation recompute method.")
+
+ return hidden_states
+
+
+def core_mlp_forward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ self.layer_number = getattr(self, "layer_number", None)
+ is_recompute_activation = should_recompute_activation(self.layer_number)
+ if get_args().moe_alltoall_overlap_comm and not isinstance(args[-1], torch.Tensor):
+ moe_ctx = args[-1]
+ args = args[:-1]
+
+ def activation_function(*function_args):
+ intermediate, bias = function_args
+ if bias is not None:
+ intermediate = intermediate + bias
+ if self.config.gated_linear_unit:
+ assert (self.config.activation_func == F.silu), 'Activation function must be silu when using fused_swiglu'
+ if not hasattr(self, 'origin_activation_func'):
+ self.origin_activation_func = self.activation_func
+ self.activation_func = fused_swiglu
+ intermediate = self.activation_func(intermediate)
+ else:
+ intermediate = self.activation_func(intermediate)
+
+ return intermediate
+
+ moe_zero_memory = get_args().moe_zero_memory
+ if not (is_recompute_activation or moe_zero_memory != "disable"):
+ if hasattr(self, 'origin_activation_func'):
+ self.activation_func = self.origin_activation_func
+ output, output_bias = fn(self, *args, **kwargs)
+ elif moe_zero_memory == "level1" and not only_recompute_activation(self.layer_number):
+ if self.shared_expert:
+ self.activation_function = activation_function
+ hidden_states = args[0]
+ fc1_out_parallel, bias_parallel = self.linear_fc1(hidden_states)
+ act_out_parallel = activation_function(fc1_out_parallel, bias_parallel)
+ output, output_bias = self.linear_fc2(act_out_parallel)
+ fc1_out_parallel.untyped_storage().resize_(0)
+ act_out_parallel.untyped_storage().resize_(0)
+ moe_ctx.shared_fc1_out = fc1_out_parallel
+ moe_ctx.shared_act_out = act_out_parallel
+ else:
+ output, output_bias = fn(self, *args, **kwargs)
+ else:
+ hidden_states = args[0]
+ intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)
+ self.activation_checkpoint_manager = CheckpointWithoutOutput()
+ intermediate_parallel = self.activation_checkpoint_manager.checkpoint(activation_function,
+ False,
+ intermediate_parallel,
+ bias_parallel)
+ # [s, b, h]
+ output, output_bias = self.linear_fc2(intermediate_parallel)
+
+ # discard the output of the activation function,
+ # which will be restored by recomputation during backward.
+ self.activation_checkpoint_manager.discard_output()
+
+ # when backward to output of dense_4h_to_h,
+ # recompute and restore the output of activation function.
+ if output.requires_grad:
+ output.register_hook(self.activation_checkpoint_manager.recompute)
+ return output, output_bias
+ return wrapper
+
+
+def norm_recompute_forward(
+ self,
+ hidden_states,
+ attention_mask,
+ context=None,
+ context_mask=None,
+ rotary_pos_emb=None,
+ inference_params=None,
+ packed_seq_params=None,
+):
+ # hidden_states: [s, b, h]
+
+ # Residual connection.
+ residual = hidden_states
+
+ # Optional Input Layer norm
+ self.norm_ckpt1 = CheckpointWithoutOutput()
+ input_layernorm_output = self.norm_ckpt1.checkpoint(self.input_layernorm, False, hidden_states)
+
+ # Self attention.
+ attention_output_with_bias = self.self_attention(
+ input_layernorm_output,
+ attention_mask=attention_mask,
+ inference_params=inference_params,
+ rotary_pos_emb=rotary_pos_emb,
+ packed_seq_params=packed_seq_params,
+ )
+
+ self.norm_ckpt1.discard_output()
+ if self.training:
+ attention_output_with_bias[0].register_hook(self.norm_ckpt1.recompute)
+
+ with self.bias_dropout_add_exec_handler():
+ hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
+ attention_output_with_bias, residual, self.hidden_dropout
+ )
+
+ # Residual connection.
+ residual = hidden_states
+
+ # Optional Layer norm after self-attention
+ pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states)
+
+ # Cross attention.
+ attention_output_with_bias = self.cross_attention(
+ pre_cross_attn_layernorm_output,
+ attention_mask=context_mask,
+ key_value_states=context,
+ inference_params=inference_params,
+ )
+
+ if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
+ context = attention_output_with_bias["context"]
+
+ with self.bias_dropout_add_exec_handler():
+ hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)(
+ attention_output_with_bias, residual, self.hidden_dropout
+ )
+
+ # Residual connection.
+ residual = hidden_states
+
+ # Optional Layer norm post the cross-attention.
+ self.norm_ckpt2 = CheckpointWithoutOutput()
+ pre_mlp_layernorm_output = self.norm_ckpt2.checkpoint(self.pre_mlp_layernorm, False, hidden_states)
+
+ # MLP.
+ mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
+
+ if self.training:
+ self.norm_ckpt2.discard_output()
+ mlp_output_with_bias[0].register_hook(self.norm_ckpt2.recompute)
+
+ with self.bias_dropout_add_exec_handler():
+ hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
+ mlp_output_with_bias, residual, self.hidden_dropout
+ )
+
+ output = make_viewless_tensor(
+ inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
+ )
+
+ return output, context
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/transformer_block.py b/model/train/yoco_moe/mindspeed/core/transformer/transformer_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad0063d762af0de2127a88e90860b74e9326652d
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/transformer_block.py
@@ -0,0 +1,223 @@
+from functools import wraps
+import torch
+from torch import Tensor
+
+from megatron.core import tensor_parallel, parallel_state, mpu
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.transformer.module import MegatronModule
+from megatron.core.transformer.spec_utils import build_module
+from megatron.training import get_args
+from megatron.core.transformer.custom_layers.transformer_engine import TENorm
+from mindspeed.core.tensor_parallel.comm_autograd_function import auto_grad_sync_gather_along_last_dim, \
+ auto_grad_sync_gather_along_first_dim
+from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm, TPYCollectiveComm
+
+
+def transformer_block_checkpointed_forward_wrapper(forward_func):
+ @wraps(forward_func)
+ def row_parallel_forward(*args, **kwargs):
+ global_args = get_args()
+ if global_args.recompute_method != 'block' and not global_args.swap_attention:
+ output = forward_func(*args, **kwargs)
+ else:
+ output = transformer_block_checkpointed_forward(*args, **kwargs)
+ return output
+
+ return row_parallel_forward
+
+
+def transformer_block_checkpointed_forward(
+ self,
+ hidden_states: Tensor,
+ attention_mask: Tensor,
+ context: Tensor,
+ context_mask: Tensor,
+ rotary_pos_emb: Tensor,
+ packed_seq_params: PackedSeqParams,
+):
+ """Forward method with activation checkpointing."""
+
+ def custom(start: int, end: int):
+ def custom_forward(
+ hidden_states,
+ attention_mask,
+ context,
+ context_mask,
+ rotary_pos_emb,
+ ):
+ for index in range(start, end):
+ layer = self._get_layer(index)
+ hidden_states, context = layer(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ context=context,
+ context_mask=context_mask,
+ rotary_pos_emb=rotary_pos_emb,
+ inference_params=None,
+ packed_seq_params=packed_seq_params,
+ )
+ return hidden_states, context
+
+ return custom_forward
+
+ def checkpoint_handler(forward_func):
+ if self.config.fp8:
+ from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
+
+ return te_checkpoint(
+ forward_func,
+ self.config.distribute_saved_activations,
+ tensor_parallel.random.get_cuda_rng_tracker,
+ parallel_state.get_tensor_model_parallel_group(),
+ hidden_states,
+ attention_mask,
+ context,
+ context_mask,
+ rotary_pos_emb,
+ )
+ else:
+ return tensor_parallel.checkpoint(
+ forward_func,
+ self.config.distribute_saved_activations,
+ hidden_states,
+ attention_mask,
+ context,
+ context_mask,
+ rotary_pos_emb,
+ )
+
+ # Checkpoint the input activation of only a set number of individual
+ # Transformer layers and skip the rest.
+ # A method fully use the device memory removing redundant re-computation.
+ global_args = get_args()
+ if self.config.recompute_method == 'uniform':
+ # Uniformly divide the total number of Transformer layers and
+ # checkpoint the input activation of each divided chunk.
+ # A method to further reduce memory usage reducing checkpoints.
+ if not global_args.swap_attention:
+ l = 0
+ while l < self.num_layers_per_pipeline_rank:
+ hidden_states = checkpoint_handler(custom(l, l + 1))
+
+ l += self.config.recompute_num_layers
+ else:
+ for l in range(self.num_layers_per_pipeline_rank):
+ hidden_states, context = custom(l, l + 1)(
+ hidden_states,
+ attention_mask,
+ context,
+ context_mask,
+ rotary_pos_emb,
+ )
+ elif self.config.recompute_method == 'block':
+ vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()
+ vpp_size = self.config.virtual_pipeline_model_parallel_size
+ if vpp_rank is None or not global_args.enable_recompute_layers_per_pp_rank:
+ vpp_rank = 0
+ if vpp_size is None or not global_args.enable_recompute_layers_per_pp_rank:
+ vpp_size = 1
+ for l in range(self.num_layers_per_pipeline_rank):
+ # The number of layers each pipeline rank recomputes is self.recompute_num_layers.
+ # If self.recompute_num_layers cannot divide exactly the number of layers in each pp rank,
+ # we try to balance the number of recomputed layers in each model chunk.
+ # e.g. with 8 layers, 2 stages, and 2 virtual stages, the assignment of
+ # layers to stages like (each list is a model chunk):
+ # Stage 0: [0, 1] [4, 5]
+ # Stage 1: [2, 3] [6, 7]
+ # With self.recompute_num_layers = 2, we will recompute layers 0,4 for stage 0, and 2,6 for stage 1.
+ # With self.recompute_num_layers = 3, we will recompute layers 0,1,4 for stage 0, and 2,3,6 for stage 1.
+ def should_recompute():
+ if global_args.reduce_recompute_for_last_chunk:
+ def is_last_layer():
+ return (l == self.num_layers_per_pipeline_rank - 1) and mpu.is_pipeline_last_stage()
+
+ return ((l * vpp_size + vpp_rank) < self.config.recompute_num_layers) and not is_last_layer()
+ else:
+ return (l * vpp_size + vpp_rank) < self.config.recompute_num_layers
+
+ if should_recompute() and not global_args.swap_attention:
+ hidden_states, context = checkpoint_handler(custom(l, l + 1))
+ else:
+ hidden_states, context = custom(l, l + 1)(
+ hidden_states,
+ attention_mask,
+ context,
+ context_mask,
+ rotary_pos_emb,
+ )
+
+ return hidden_states
+
+
+class NoopTransformerLayer(MegatronModule):
+ def __init__(self, layer_number):
+ super().__init__(None)
+ self.layer_number = layer_number
+
+ def forward(self, hidden_states, attention_mask, context, context_mask, rotary_pos_emb, inference_params, packed_seq_params):
+ return hidden_states.clone(), context
+
+
+def _get_layer_offset(args):
+ num_layers = args.num_layers
+ pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
+
+ num_layers_per_pipeline_rank = (
+ num_layers // parallel_state.get_pipeline_model_parallel_world_size()
+ )
+
+ if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+ vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
+ vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
+
+ total_num_layers = num_layers
+ num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
+ total_virtual_chunks = total_num_layers // vp_size
+ offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)
+
+ else:
+ # Each stage gets a contiguous set of layers.
+ if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+ offset = pipeline_rank * num_layers_per_pipeline_rank
+ else:
+ offset = 0
+ return offset
+
+
+def _build_layers(self):
+ args = get_args()
+
+ def build_layer(layer_spec, layer_number):
+ global_layer_number = _get_layer_offset(args) + layer_number
+ if (hasattr(args, 'noop_layers') and isinstance(args.noop_layers, set)
+ and global_layer_number - 1 in args.noop_layers):
+ return NoopTransformerLayer(global_layer_number)
+ return build_module(layer_spec, config=self.config, layer_number=layer_number,)
+
+ self.layers = torch.nn.ModuleList(
+ [
+ build_layer(layer_spec, i + 1)
+ for i, layer_spec in enumerate(self.submodules.layer_specs)
+ ]
+ )
+
+ if self.submodules.layer_norm and self.post_process and self.post_layer_norm:
+ self.final_layernorm = build_module(
+ self.submodules.layer_norm,
+ config=self.config,
+ hidden_size=self.config.hidden_size,
+ eps=self.config.layernorm_epsilon,
+ )
+ else:
+ self.final_layernorm = None # Either this or nn.Identity
+
+
+def transformer_block_forward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ hidden_states = fn(*args, **kwargs)
+ if get_args().tp_2d and parallel_state.is_pipeline_last_stage():
+ hidden_states = auto_grad_sync_gather_along_first_dim(hidden_states, TPXCollectiveComm)
+ hidden_states = auto_grad_sync_gather_along_last_dim(hidden_states, TPYCollectiveComm)
+ return hidden_states
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/core/transformer/transformer_config.py b/model/train/yoco_moe/mindspeed/core/transformer/transformer_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf75ae0a7997abf39b7f288a23cb46c73544d58
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/transformer/transformer_config.py
@@ -0,0 +1,187 @@
+# Copyright (c) 2023; NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
+from dataclasses import make_dataclass, field
+from functools import wraps
+
+import torch.nn.functional as F
+
+from megatron.core.transformer import TransformerConfig
+from megatron.core.utils import init_method_normal, scaled_init_method_normal
+from megatron.training import get_args
+
+
+def transformer_config_post_init(self):
+ super(TransformerConfig, self).__post_init__()
+ if self.fp16 and self.bf16:
+ raise ValueError(
+ f'Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True.'
+ )
+ args = get_args()
+ world_size = args.tp_x if args.tp_2d else self.tensor_model_parallel_size
+ if self.num_attention_heads % world_size != 0:
+ if not args.unaligned_linear:
+ raise ValueError(
+ f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
+ f"tensor_model_parallel_size ({world_size})."
+ )
+
+ if self.ffn_hidden_size is None:
+ self.ffn_hidden_size = 4 * self.hidden_size
+
+ if self.kv_channels is None:
+ self.kv_channels = self.hidden_size // self.num_attention_heads
+
+ if self.num_query_groups is None:
+ self.num_query_groups = self.num_attention_heads
+
+ if self.num_query_groups % world_size != 0:
+ if not args.unaligned_linear:
+ raise ValueError(
+ f"num_query_groups ({self.num_query_groups}) must be a multiple of "
+ f"tensor_model_parallel_size ({world_size})."
+ )
+
+ if self.apply_query_key_layer_scaling:
+ self.attention_softmax_in_fp32 = True
+
+ if self.expert_model_parallel_size > 1 and self.num_moe_experts is None:
+ raise ValueError(f'num_moe_experts must be non None to use expert-parallel.')
+
+ if self.num_moe_experts is not None and self.num_moe_experts <= 0:
+ raise ValueError(f'num_moe_experts must be non-negative.')
+
+ if self.moe_expert_capacity_factor is not None:
+ if self.moe_token_dispatcher_type != "alltoall":
+ raise ValueError(
+ f'moe_expert_capacity_factor only works with alltoall token dispatcher'
+ )
+ if self.moe_expert_capacity_factor < 0:
+ self.moe_expert_capacity_factor = None
+ if self.moe_router_load_balancing_type not in ["aux_loss", "none"]:
+ raise ValueError(
+ f'moe_expert_capacity_factor only works with aux_loss or none load balancing'
+ )
+
+ if self.moe_pad_expert_input_to_capacity:
+ if self.moe_expert_capacity_factor is None:
+ raise ValueError(
+ f'moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity'
+ )
+
+ if self.cpu_offloading and (
+ self.cpu_offloading_num_layers < 0 or self.cpu_offloading_num_layers >= self.num_layers
+ ):
+ raise ValueError(
+ f'CPU offloading can be done only for layers less than {self.num_layers}'
+ )
+
+ if self.cpu_offloading and self.pipeline_model_parallel_size > 1:
+ raise ValueError(
+ f'Currently there is no support for Pipeline parallelism with CPU offloading'
+ )
+
+ if self.cpu_offloading and self.recompute_granularity is not None:
+ raise ValueError(
+ f'CPU offloading does not work when activation recomputation is enabled'
+ )
+
+ if self.recompute_granularity is not None:
+ if self.recompute_granularity not in ['full', 'selective']:
+ raise ValueError(
+ f'When using recompute_granuarlity: {self.recompute_granularity} must be "full" or "selective".'
+ )
+
+ if self.recompute_method is not None:
+ if self.recompute_method not in ['block', 'uniform']:
+ raise ValueError(
+ f'recompute_method: {self.recompute_method} must be "block" or "uniform".'
+ )
+ elif self.recompute_granularity != 'selective':
+ raise ValueError(
+ f'Using recompute_granularity: {self.recompute_granularity} so recompute_method must be "block" or "uniform"'
+ )
+
+ if self.recompute_granularity != 'selective' and self.recompute_num_layers is None:
+ raise ValueError(
+ f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be between '
+ f'1 and num_layers_per_pipeline_rank: {self.num_layers // self.pipeline_model_parallel_size}'
+ )
+ elif (
+ self.recompute_granularity == 'selective' and self.recompute_num_layers is not None
+ ):
+ raise ValueError(
+ f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be None.'
+ )
+
+ if self.distribute_saved_activations and self.sequence_parallel:
+ raise ValueError(
+ f'distribute_saved_activations: {self.distribute_saved_activations} must be false when sequence parallel is enabled: {self.sequence_parallel}'
+ )
+
+ if self.virtual_pipeline_model_parallel_size is not None:
+ if not self.num_layers % self.virtual_pipeline_model_parallel_size == 0:
+ raise ValueError(
+ f'num_layers: {self.num_layers} must be divisible by virtual_model_parallel_size {self.virtual_pipeline_model_parallel_size}'
+ )
+
+ if self.apply_query_key_layer_scaling:
+ self.attention_softmax_in_fp32 = True
+
+ if self.bias_activation_fusion:
+ if self.activation_func not in [F.gelu, F.silu]:
+ raise ValueError(
+ "When bias_activation_fusion is True, activation function should be either gelu or swiglu"
+ )
+ if (
+ self.activation_func == F.gelu
+ and not self.gated_linear_unit
+ and not self.add_bias_linear
+ ):
+ raise ValueError(
+ "When bias_activation_fusion is True, gated_linear_unit is False, "
+ "and activation function is gelu, add_bias_linear must also be True."
+ )
+ if self.activation_func_fp8_input_store:
+ if self.activation_func != F.silu or not self.gated_linear_unit:
+ raise ValueError("Storing activation input in FP8 is supported only for SwiGLU.")
+ if self.apply_rope_fusion and self.rotary_interleaved:
+ raise ValueError(f'rotary_interleaved does not work with apply_rope_fusion.')
+
+ if self.init_method is None:
+ self.init_method = init_method_normal(self.init_method_std)
+
+ if self.output_layer_init_method is None:
+ self.output_layer_init_method = scaled_init_method_normal(
+ self.init_method_std, self.num_layers
+ )
+
+ if self.moe_extended_tp:
+ if self.moe_token_dispatcher_type != 'allgather':
+ raise ValueError(
+ "Moe extended TP parallelism only applies to allgather based token dispatcher."
+ )
+ extended_tp_size = self.tensor_model_parallel_size * self.expert_model_parallel_size
+ if self.ffn_hidden_size % extended_tp_size != 0:
+ raise ValueError(
+ f'ffn_hidden_size: {self.ffn_hidden_size} must be divisible by extended_tp_size {extended_tp_size}'
+ )
+
+
+def transformer_config_post_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self):
+ fn(self)
+ args = get_args()
+ fields = []
+ for key, value in vars(args).items():
+ field_name = str(key)
+ field_type = type(value)
+ if not hasattr(self, key):
+ field_def = (field_name, field_type, field(init=False))
+ fields.append(field_def)
+ self.__class__ = make_dataclass(self.__class__.__name__, fields=fields, bases=(self.__class__,))
+
+ for key, value in vars(args).items():
+ if not hasattr(self, key):
+ setattr(self, key, value)
+ return wrapper
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/core/weight_grad_store.py b/model/train/yoco_moe/mindspeed/core/weight_grad_store.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb50a0999ffda7f1224adf707143852ef86fb0f9
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/core/weight_grad_store.py
@@ -0,0 +1,241 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import operator
+import queue
+from functools import reduce
+import torch
+import torch_npu
+
+from megatron.core.parallel_state import (
+ get_tensor_model_parallel_group,
+ get_tensor_model_parallel_world_size
+)
+from megatron.training import get_args
+
+
+def gather(input_slice, stream):
+ world_size = get_tensor_model_parallel_world_size()
+ dim_size = list(input_slice.size())
+ dim_size[0] = dim_size[0] * world_size
+
+ all_gather_buffer = torch.empty(
+ dim_size, dtype=input_slice.dtype, device=torch.cuda.current_device(), requires_grad=False
+ )
+ handle = None
+ forward_event = torch.npu.Event()
+ forward_event.record()
+ if get_args().use_nanopipe_swap:
+ swap_event = WeightGradStore.swap_event.pop(0)
+ with torch.no_grad():
+ with torch_npu.npu.stream(stream):
+ stream.wait_event(forward_event)
+ if get_args().use_nanopipe_swap:
+ stream.wait_event(swap_event)
+ handle = torch.distributed._all_gather_base(
+ all_gather_buffer, input_slice, group=get_tensor_model_parallel_group(), async_op=True
+ )
+
+ # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+ # gather is scheduled before the input gradient computation
+ return all_gather_buffer, handle
+
+
+def swap_d2h(ori_tensor, stream):
+ storage_size = ori_tensor.storage().size()
+ tensor_cpu = torch.empty(ori_tensor.shape, dtype=ori_tensor.dtype, pin_memory=True, device='cpu')
+ forward_event = torch.npu.Event()
+ forward_event.record()
+ with torch.no_grad():
+ with torch_npu.npu.stream(stream):
+ stream.wait_event(forward_event)
+ tensor_cpu.storage().copy_(ori_tensor.storage(), non_blocking=True)
+ WeightGradStore.ori_storage.append(ori_tensor)
+
+ return storage_size, tensor_cpu
+
+
+def swap_h2d(ori_tensor, tensor_cpu, storage_size, stream):
+ with torch.no_grad():
+ with torch_npu.npu.stream(stream):
+ ori_tensor.storage().resize_(storage_size)
+ ori_tensor.storage().copy_(tensor_cpu.storage(), non_blocking=True)
+
+
+class WeightGradStore:
+ cache = []
+ weight_grad_queue = queue.Queue()
+ store_grad_cache = []
+ grad_store = []
+ swap_event = []
+ prefetch_stream = None
+ gather_stream = None
+ host_tensors_gradoutput = []
+ host_pipe_experts_grad = []
+ host_tensors_input = []
+ ori_storage = []
+ is_decoupleBlock = False
+ grad_overlap_count = 0
+ interval_per_layers_count = 0
+ interval_per_layers = []
+
+ @classmethod
+ def put(cls, total_input, grad_output, weight, sequence_parallel, in_row=False, pipe_experts=False):
+ if get_args().use_nanopipe_swap:
+ if cls.prefetch_stream is None:
+ cls.prefetch_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
+ if grad_output is not None:
+ cls.host_tensors_gradoutput.append(swap_d2h(grad_output, cls.prefetch_stream))
+ cls.host_tensors_input.append(swap_d2h(total_input, cls.prefetch_stream))
+ cls.interval_per_layers_count += 1
+ cls.cache.append((total_input, grad_output, weight, sequence_parallel, in_row, pipe_experts))
+
+ @classmethod
+ def flush(cls):
+ cls.interval_per_layers.append(cls.interval_per_layers_count)
+ cls.interval_per_layers_count = 0
+
+ @classmethod
+ def save_grad_output(cls, grad):
+ if get_args().use_nanopipe_swap:
+ if cls.prefetch_stream is None:
+ cls.prefetch_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
+ cls.host_pipe_experts_grad.append(swap_d2h(grad, cls.prefetch_stream))
+ cls.grad_store.append(grad)
+
+ @classmethod
+ def start_decouple(cls):
+ cls.is_decoupleBlock = True
+
+ @classmethod
+ def end_decouple(cls):
+ cls.is_decoupleBlock = False
+
+ @classmethod
+ def overlap_all_gather(cls):
+ # used for grad_output all gather in RowParallel and input all gather in ColumnParallel.
+ if len(cls.cache) > 0:
+ [input, grad_output_slice, weight, sequence_parallel, in_row, pipe_experts] = cls.cache.pop(0)
+ if not sequence_parallel:
+ return (input, grad_output_slice, weight, sequence_parallel, in_row, pipe_experts), None
+ if not in_row:
+ total_input, handle = gather(input, cls.gather_stream)
+ grad_output = grad_output_slice
+ else:
+ if pipe_experts and not get_args().use_nanopipe_swap:
+ grad_output_slice = cls.grad_store.pop(0)
+ grad_output, handle = gather(grad_output_slice, cls.gather_stream)
+ total_input = input
+ return [total_input, grad_output, weight, sequence_parallel, in_row, pipe_experts], handle
+ else:
+ raise Exception("All Gather empty queue.")
+
+ @classmethod
+ def swap_tensors(cls):
+ if get_args().use_nanopipe_swap:
+ if cls.prefetch_stream is None:
+ cls.prefetch_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
+ cls.prefetch_stream.wait_stream(torch.npu.current_stream())
+ for cache_id in range(len(cls.cache)):
+ cls.cache[cache_id] = list(cls.cache[cache_id])
+ if cls.cache[cache_id][-1] and cls.cache[cache_id][1] is None:
+ cls.cache[cache_id][1] = cls.grad_store.pop(0)
+ input, grad_output_slice, weight, sequence_parallel, in_row, pipe_experts = cls.cache[cache_id]
+ if pipe_experts:
+ storage_size_g, tensor_cpu_g = cls.host_pipe_experts_grad.pop(0)
+ else:
+ storage_size_g, tensor_cpu_g = cls.host_tensors_gradoutput.pop(0)
+ storage_size_i, tensor_cpu_i = cls.host_tensors_input.pop(0)
+ swap_h2d(grad_output_slice, tensor_cpu_g, storage_size_g, cls.prefetch_stream)
+ swap_h2d(input, tensor_cpu_i, storage_size_i, cls.prefetch_stream)
+ cls.swap_event.append((cls.prefetch_stream.record_event()))
+
+ @classmethod
+ def overlap_matmul(cls, grad_store_cache):
+ total_input, grad_output, weight, sequence_parallel, in_row, pipe_experts = grad_store_cache
+ grad_output = grad_output.contiguous()
+ sb = grad_output.shape[0] * grad_output.shape[1]
+ # Convert the tensor shapes to 2D for execution compatibility
+ grad_output = grad_output.view(
+ sb, grad_output.shape[2]
+ )
+ total_input = total_input.view(
+ sb, total_input.shape[2]
+ )
+ if get_args().gradient_accumulation_fusion:
+ import fused_weight_gradient_mlp_cuda
+ if weight.main_grad.dtype == torch.float32:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+ total_input, grad_output, weight.main_grad
+ )
+ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+ total_input, grad_output, weight.main_grad
+ )
+ else:
+ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+ else:
+ grad_weight = grad_output.t().matmul(total_input)
+ weight.main_grad.data.add_(grad_weight)
+ cls.grad_overlap_count += 1
+
+ @classmethod
+ def pop(cls, overlap_arg=None):
+ if len(cls.cache) == 0:
+ return
+ if cls.gather_stream is None:
+ cls.gather_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
+ if get_args().overlap_grad_reduce:
+ if overlap_arg is None:
+ raise RuntimeError("overlap_arg is invalid")
+ pipeline_parallel_size, nano_flag, synchronized_model_chunks, grad_sync_func, model = overlap_arg
+ model_chunk_id = len(nano_flag) - 1
+ input, grad_output_slice, weight, sequence_parallel, in_row, pipe_experts = cls.cache.pop(0)
+ if not sequence_parallel:
+ grad_output = grad_output_slice
+ handle = None
+ else:
+ if pipe_experts and not get_args().use_nanopipe_swap:
+ grad_output_slice = cls.grad_store.pop(0)
+ grad_output, handle = gather(grad_output_slice, cls.gather_stream)
+ layers_count = 0
+ cls.store_grad_cache = (input, grad_output, weight, sequence_parallel, in_row, pipe_experts)
+ while len(cls.cache) > 0:
+ if handle is not None:
+ handle.wait()
+ next_grad_cache, handle = cls.overlap_all_gather()
+ cls.overlap_matmul(cls.store_grad_cache)
+ if get_args().overlap_grad_reduce:
+ if cls.grad_overlap_count == cls.interval_per_layers[0]:
+ cls.interval_per_layers.pop(0)
+ layers_count += 1
+ if layers_count == pipeline_parallel_size:
+ if model_chunk_id not in synchronized_model_chunks and not nano_flag[model_chunk_id]:
+ grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
+ synchronized_model_chunks.add(model_chunk_id)
+ model_chunk_id -= 1
+ layers_count = 0
+ cls.grad_overlap_count = 0
+ cls.store_grad_cache = next_grad_cache
+ if handle is not None:
+ handle.wait()
+ cls.overlap_matmul(cls.store_grad_cache)
+ if get_args().overlap_grad_reduce:
+ if model_chunk_id not in synchronized_model_chunks and not nano_flag[model_chunk_id]:
+ grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
+ synchronized_model_chunks.add(model_chunk_id)
+ model_chunk_id -= 1
+ cls.grad_overlap_count = 0
+ cls.stored_grads = None
+ cls.store_grad_cache = None
+ cls.swap_event = []
+ cls.grad_store = []
+ cls.host_pipe_experts_grad = []
+ cls.interval_per_layers = []
+ cls.interval_per_layers_count = 0
+
+ @classmethod
+ def resize_ori_storage(cls, use_nano_swap):
+ if use_nano_swap and len(cls.ori_storage) > 0:
+ torch.npu.current_stream().wait_stream(cls.prefetch_stream)
+ for ori_storage_ in cls.ori_storage:
+ ori_storage_.storage().resize_(0)
+ cls.ori_storage = []
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/features_manager/__init__.py b/model/train/yoco_moe/mindspeed/features_manager/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b98d707f78b58ad0c7da83fe7133e89cfccccc1b
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/features_manager/__init__.py
@@ -0,0 +1,9 @@
+from .tensor_parallel.unaligned_linear_feature import UnalignedLinearFeature
+from .functional.profiler_default_feature import ProfilerDefaultFeature
+
+FEATURES_LIST = [
+ # Functional features
+ ProfilerDefaultFeature(),
+ # Tensor parallel features
+ UnalignedLinearFeature()
+]
diff --git a/model/train/yoco_moe/mindspeed/features_manager/feature.py b/model/train/yoco_moe/mindspeed/features_manager/feature.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d014dcb76fdaaef61f0a6102407625a1c800455
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/features_manager/feature.py
@@ -0,0 +1,38 @@
+import argparse
+
+
+class MindSpeedFeature:
+ def __init__(self, feature_name: str, optimization_level: int = 2):
+ self.feature_name = feature_name.strip().replace('-', '_')
+ self.optimization_level = optimization_level
+ self.default_patches = self.optimization_level == 0
+
+ def register_args(self, parser):
+ pass
+
+ def pre_validate_args(self, args):
+ pass
+
+ def validate_args(self, args):
+ pass
+
+ def post_validate_args(self, args):
+ pass
+
+ def register_patches(self, patch_manager, args):
+ ...
+
+ def incompatible_check(self, global_args, check_args):
+ if getattr(global_args, self.feature_name, None) and getattr(global_args, check_args, None):
+ raise AssertionError('{} and {} are incompatible.'.format(self.feature_name, check_args))
+
+ def dependency_check(self, global_args, check_args):
+ if getattr(global_args, self.feature_name, None) and not getattr(global_args, check_args, None):
+ raise AssertionError('{} requires {}.'.format(self.feature_name, check_args))
+
+ @staticmethod
+ def add_parser_argument_choices_value(parser, argument_name, new_choice):
+ for action in parser._actions:
+ exist_arg = isinstance(action, argparse.Action) and argument_name in action.option_strings
+ if exist_arg and action.choices is not None and new_choice not in action.choices:
+ action.choices.append(new_choice)
diff --git a/model/train/yoco_moe/mindspeed/features_manager/functional/__init__.py b/model/train/yoco_moe/mindspeed/features_manager/functional/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/features_manager/functional/profiler_default_feature.py b/model/train/yoco_moe/mindspeed/features_manager/functional/profiler_default_feature.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a737b3625cdc7f0f6e8dc7d1c9e5eab99a9e7a0
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/features_manager/functional/profiler_default_feature.py
@@ -0,0 +1,10 @@
+from mindspeed.features_manager.feature import MindSpeedFeature
+
+
+class ProfilerDefaultFeature(MindSpeedFeature):
+ def __init__(self):
+ super(ProfilerDefaultFeature, self).__init__('profiler-default', 0)
+
+ def register_patches(self, patch_manager, args):
+ from mindspeed.functional.profiler.profiler_info import get_nccl_options_add_group_info_wrapper
+ patch_manager.register_patch('megatron.core.parallel_state.get_nccl_options', get_nccl_options_add_group_info_wrapper)
diff --git a/model/train/yoco_moe/mindspeed/features_manager/tensor_parallel/__init__.py b/model/train/yoco_moe/mindspeed/features_manager/tensor_parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/features_manager/tensor_parallel/unaligned_linear_feature.py b/model/train/yoco_moe/mindspeed/features_manager/tensor_parallel/unaligned_linear_feature.py
new file mode 100644
index 0000000000000000000000000000000000000000..14471c7b60c675f7a75d566f370a736cbd837e6c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/features_manager/tensor_parallel/unaligned_linear_feature.py
@@ -0,0 +1,58 @@
+from argparse import ArgumentParser
+
+from mindspeed.features_manager.feature import MindSpeedFeature
+
+
+class UnalignedLinearFeature(MindSpeedFeature):
+
+ def __init__(self):
+ super().__init__('unaligned-linear')
+
+ def register_args(self, parser: ArgumentParser):
+ group = parser.add_argument_group(title=self.feature_name)
+ group.add_argument('--unaligned-linear', action='store_true',
+ help='Replace ColumnParallelLinear/RowParallelLinear with '
+ 'UnalignedColumnParallelLinearAdaptor/UnalignedRowParallelLinearAdaptor.')
+
+ def validate_args(self, args):
+ self.incompatible_check(args, 'use_ascend_mc2')
+ self.incompatible_check(args, 'tp_2d')
+ if args.unaligned_linear and args.num_experts and args.num_experts > 1:
+ raise AssertionError("The unaligned linear feature does not support the moe model.")
+ # self.dependency_check(..)
+
+ def register_patches(self, patch_manager, args):
+ from mindspeed.core.tensor_parallel.unaligned_layers.adaptor import divide_adaptor, \
+ scatter_to_sequence_parallel_region_adaptor, get_rotary_seq_len, UnalignedColumnParallelLinearAdaptor, \
+ UnalignedRowParallelLinearAdaptor, reduce_scatter_to_sequence_parallel_region_adaptor, \
+ gather_from_sequence_parallel_region_adaptor
+ from mindspeed.core.transformer.transformer_config import transformer_config_post_init
+ from mindspeed.core.transformer.dot_product_attention import dot_product_attention_init_wrapper
+ from mindspeed.core.transformer.attention import attention_init_wrapper
+ if getattr(args, self.feature_name, None):
+ patch_manager.register_patch('megatron.core.tensor_parallel.layers.ColumnParallelLinear',
+ UnalignedColumnParallelLinearAdaptor)
+ patch_manager.register_patch('megatron.core.tensor_parallel.layers.RowParallelLinear',
+ UnalignedRowParallelLinearAdaptor)
+
+ # To adapt to the distribution of MHA attention heads
+ patch_manager.register_patch('megatron.core.utils.divide', divide_adaptor)
+ patch_manager.register_patch(
+ 'megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding.get_rotary_seq_len',
+ get_rotary_seq_len)
+ patch_manager.register_patch('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__',
+ transformer_config_post_init)
+
+ # To adapt to the distribution of GQA attention heads
+ patch_manager.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention.__init__',
+ dot_product_attention_init_wrapper)
+ patch_manager.register_patch('megatron.core.transformer.attention.Attention.__init__',
+ attention_init_wrapper)
+ patch_manager.register_patch('megatron.core.tensor_parallel.mappings.gather_from_sequence_parallel_region',
+ gather_from_sequence_parallel_region_adaptor)
+
+ # To adapt to the sequence parallel feature
+ patch_manager.register_patch('megatron.core.tensor_parallel.mappings.scatter_to_sequence_parallel_region',
+ scatter_to_sequence_parallel_region_adaptor)
+ patch_manager.register_patch('megatron.core.tensor_parallel.mappings.reduce_scatter_to_sequence_parallel_region',
+ reduce_scatter_to_sequence_parallel_region_adaptor)
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/functional/__init__.py b/model/train/yoco_moe/mindspeed/functional/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/functional/profiler/__init__.py b/model/train/yoco_moe/mindspeed/functional/profiler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/functional/profiler/profiler_info.py b/model/train/yoco_moe/mindspeed/functional/profiler/profiler_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..afd515c51eafc57d7fde6952550fe38513611743
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/functional/profiler/profiler_info.py
@@ -0,0 +1,19 @@
+from functools import wraps
+import torch_npu
+
+
+def get_nccl_options_add_group_info_wrapper(get_nccl_options):
+ @wraps(get_nccl_options)
+ def wrapper(pg_name, nccl_comm_cfgs):
+ options = get_nccl_options(pg_name, nccl_comm_cfgs)
+ if hasattr(torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options, 'hccl_config'):
+ options = options if options is not None else torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
+ try:
+ # torch_npu not support inplace update
+ hccl_config = options.hccl_config
+ hccl_config.update({'group_name': pg_name})
+ options.hccl_config = hccl_config
+ except TypeError as e:
+ pass # compatible with old torch_npu version
+ return options
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/initialize.py b/model/train/yoco_moe/mindspeed/initialize.py
new file mode 100644
index 0000000000000000000000000000000000000000..168b17d389d573d2eca10a67b4c016e831d462d4
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/initialize.py
@@ -0,0 +1,65 @@
+import time
+from functools import wraps
+import torch
+import torch_npu
+from megatron.training import get_args
+from megatron.training.utils import print_rank_0
+from megatron.training.initialize import _warmup_jit_function
+from mindspeed.core.tensor_parallel.ascend_turbo.initialize import initialize_cfg_from_args
+from .utils import extend_seed_all
+
+
+def _compile_dependencies():
+ if torch.distributed.get_rank() == 0:
+ start_time = time.time()
+ print('> compiling dataset index builder ...')
+ from megatron.core.datasets.utils import compile_helpers
+ compile_helpers()
+ print('>>> done with dataset index builder. Compilation time: {:.3f} '
+ 'seconds'.format(time.time() - start_time), flush=True)
+
+
+def set_jit_fusion_options_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ def _jit_set_nvfuser_enabled(option):
+ pass
+ torch._C._jit_set_nvfuser_enabled = _jit_set_nvfuser_enabled
+ fn(*args, **kwargs)
+ args = get_args()
+ if args.jit_compile:
+ torch_npu.npu.set_compile_mode(jit_compile=True)
+
+ return wrapper
+
+
+def coc_registration_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ res = fn(*args, **kwargs)
+ from mindspeed.core.tensor_parallel.lcal_coc.user_config import initialize_coc_from_cfg
+ args = get_args()
+ initialize_coc_from_cfg(args)
+ return res
+
+ return wrapper
+
+
+def mc2_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ res = fn(*args, **kwargs)
+ args = get_args()
+ initialize_cfg_from_args(args)
+ return res
+
+ return wrapper
+
+
+def deter_comp_wrapper(fn):
+ @wraps(fn)
+ def wrapper(seed_, data_parallel_random_init=False):
+ fn(seed_, data_parallel_random_init=False)
+ extend_seed_all(seed_)
+ print_rank_0("deterministic computing is applied for npu.")
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/megatron_adaptor.py b/model/train/yoco_moe/mindspeed/megatron_adaptor.py
new file mode 100644
index 0000000000000000000000000000000000000000..241a8768b4513f0e918fa6ca02dd99f686bc8907
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/megatron_adaptor.py
@@ -0,0 +1,1145 @@
+import os
+import sys
+import shutil
+import argparse
+import time
+from functools import wraps
+from multiprocessing import Lock
+import torch
+from torch.distributed import all_gather_into_tensor, reduce_scatter_tensor
+from torch_npu.contrib import transfer_to_npu
+from mindspeed.features_manager import FEATURES_LIST
+from .arguments import process_args
+
+
+_ARGS = None
+
+
+def add_args(args, key, value):
+ if key is not None:
+ key = key[2:].replace('-', '_')
+ if value is None:
+ value = True
+ elif len(value) == 1:
+ value = value[0]
+ setattr(args, key, value)
+
+
+def parser_unknown_args(args, unknown):
+ i = 0
+ key = value = None
+ while i < len(unknown):
+ if unknown[i].startswith("--"):
+ add_args(args, key, value)
+ key = unknown[i]
+ value = None
+ else:
+ if value is None:
+ value = [unknown[i]]
+ else:
+ value.append(unknown[i])
+ i += 1
+ add_args(args, key, value)
+
+
+def get_mindspeed_args():
+ global _ARGS
+ if _ARGS is None:
+ parser = argparse.ArgumentParser(description='MindSpeed Arguments', allow_abbrev=False)
+ _ARGS, unknown = process_args(parser).parse_known_args()
+ parser_unknown_args(_ARGS, unknown)
+ return _ARGS
+
+
+def dummy_jit(fn):
+ def wrapper(*args, **kwargs):
+ return fn(*args, **kwargs)
+
+ return wrapper
+
+
+def lcm(a, b):
+ import math
+ return (a * b) // math.gcd(a, b)
+
+
+def type_wrapper(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ res = fn(*args, **kwargs)
+ if isinstance(res, str):
+ res = res.replace('npu', 'cuda')
+ return res
+
+ return wrapper
+
+
+def version_wrapper(fn):
+ @wraps(fn)
+ def wrapper(name, *args, **kwargs):
+ if name == 'transformer-engine':
+ return '0.0'
+ res = fn(name, *args, **kwargs)
+ return res
+
+ return wrapper
+
+
+# Patch view method to ensure tensor is contiguous before performing view
+def ensure_contiguous_wrapper(fn):
+ def wrapper(tensor, *args, **kwargs):
+ if not tensor.is_contiguous():
+ tensor = tensor.contiguous()
+ return fn(tensor, *args, **kwargs)
+
+ return wrapper
+
+
+def multi_tensor_applier(op, noop_flag_buffer, tensor_lists, *args):
+ return op(noop_flag_buffer, tensor_lists, *args)
+
+
+def multi_tensor_l2norm(overflow_buf, tensor_lists, per_parameter):
+ total_norm = 0.0
+ norm_type = 2.0
+ ret_per_tensor = [] if per_parameter else None
+ for grads_for_norm in tensor_lists:
+ for grad in grads_for_norm:
+ grad_norm = torch.norm(grad, norm_type)
+ total_norm += grad_norm ** norm_type
+ if per_parameter:
+ ret_per_tensor.append(total_norm.clone())
+ if not tensor_lists:
+ grad_norm = torch.cuda.FloatTensor([0])
+ total_norm = grad_norm ** norm_type
+ return total_norm ** (1 / norm_type), ret_per_tensor
+
+
+def multi_tensor_scale(overflow_buf, tensor_lists, scale):
+ if len(tensor_lists) != 2:
+ raise AssertionError('The size of tensor list must be 2, but got {}'.format(len(tensor_lists)))
+ if len(tensor_lists[0]) != len(tensor_lists[1]):
+ raise AssertionError('The size of tensor list must be same, but got {} and {}'.format(len(tensor_lists[0]),
+ len(tensor_lists[1])))
+
+ with torch.no_grad():
+ for i in range(len(tensor_lists[0])):
+ tensor_lists[1][i].copy_(tensor_lists[0][i] * scale)
+
+
+def te_adaptation(aspm):
+ aspm.register_patch('torch.compile', torch.jit.script)
+ # Need replace modules before import megatron
+ aspm.register_patch('importlib.metadata.version', version_wrapper)
+ aspm.register_patch('transformer_engine.pytorch.LayerNormLinear', torch.nn.Module, create_dummy=True)
+ aspm.register_patch('transformer_engine.pytorch.DotProductAttention', torch.nn.Module, create_dummy=True)
+ aspm.register_patch('transformer_engine.pytorch.Linear', torch.nn.Module, create_dummy=True)
+ aspm.register_patch('transformer_engine.common.recipe.DelayedScaling', torch.nn.Module, create_dummy=True)
+ aspm.register_patch('flash_attn.flash_attn_interface.flash_attn_unpadded_func', create_dummy=True)
+
+
+def apex_adaptation(aspm):
+ from .core.fusions.fused_layer_norm import fused_layer_norm_affine
+ from .ops.npu_matmul_add import npu_matmul_add_fp32, npu_matmul_add_fp16
+ aspm.register_patch('amp_C.multi_tensor_l2norm', multi_tensor_l2norm, create_dummy=True)
+ aspm.register_patch('amp_C.multi_tensor_scale', multi_tensor_scale, create_dummy=True)
+ aspm.register_patch('fused_layer_norm_cuda', create_dummy=True)
+ aspm.register_patch('apex.multi_tensor_apply.multi_tensor_applier', multi_tensor_applier, create_dummy=True)
+ aspm.register_patch('apex.normalization.fused_layer_norm.fused_layer_norm_affine', fused_layer_norm_affine,
+ create_dummy=True)
+ aspm.register_patch('fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32', npu_matmul_add_fp32, create_dummy=True)
+ aspm.register_patch('fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16', npu_matmul_add_fp16, create_dummy=True)
+
+
+def torch_adaptation(aspm):
+ aspm.register_patch('torch.nn.parameter.Parameter.type', type_wrapper)
+ aspm.register_patch('torch.Tensor.type', type_wrapper)
+ aspm.register_patch('torch.Tensor.view', ensure_contiguous_wrapper)
+ aspm.register_patch('torch.distributed._all_gather_base', all_gather_into_tensor)
+ aspm.register_patch('torch.distributed._reduce_scatter_base', reduce_scatter_tensor)
+ # lmc is supported python >=3.9
+ if sys.version_info < (3, 9):
+ aspm.register_patch('math.lcm', lcm, create_dummy=True)
+
+
+def communication_adaptation(aspm, mindspeed_args):
+ if mindspeed_args.disable_gloo_group:
+ from mindspeed.optimizer.distrib_optimizer import get_parameter_state_dp_zero_hccl, \
+ load_parameter_state_from_dp_zero_hccl
+ from mindspeed.core.parallel_state import (get_data_parallel_group_gloo_replace,
+ get_data_modulo_expert_parallel_group_gloo_replace,
+ new_group_wrapper)
+ from mindspeed.utils import check_param_hashes_across_dp_replicas_hccl
+
+ aspm.register_patch('megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.get_parameter_state_dp_zero',
+ get_parameter_state_dp_zero_hccl)
+ aspm.register_patch('megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.load_parameter_state_from_dp_zero',
+ load_parameter_state_from_dp_zero_hccl)
+ aspm.register_patch('megatron.core.utils.check_param_hashes_across_dp_replicas',
+ check_param_hashes_across_dp_replicas_hccl)
+
+ aspm.register_patch('megatron.core.parallel_state.get_data_parallel_group_gloo',
+ get_data_parallel_group_gloo_replace)
+ aspm.register_patch('megatron.core.parallel_state.get_data_modulo_expert_parallel_group_gloo',
+ get_data_modulo_expert_parallel_group_gloo_replace)
+ aspm.register_patch('torch.distributed.new_group', new_group_wrapper)
+
+
+def mcore_models_adaptation_l0(aspm):
+ from .core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec_wrapper
+ from .core.parallel_state import get_nccl_options_wrapper
+ # Replace FusedLayerNorm with MindSpeed's PTNorm operator in get_gpt-layer
+ aspm.register_patch('megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_local_spec',
+ get_gpt_layer_local_spec_wrapper)
+ aspm.register_patch('megatron.core.parallel_state.get_nccl_options', get_nccl_options_wrapper)
+
+
+def mcore_models_adaptation(aspm, mindspeed_args):
+ import megatron.core
+ megatron.core.jit.jit_fuser = dummy_jit
+
+ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
+ from .core.models.common.embeddings.rotary_pos_embedding import get_pos_emb_on_this_cp_rank, \
+ rotary_embedding_init_wrapper
+ aspm.register_patch('megatron.core.models.common.embeddings.rotary_pos_embedding.get_pos_emb_on_this_cp_rank',
+ get_pos_emb_on_this_cp_rank)
+ aspm.register_patch('megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec',
+ get_gpt_layer_local_spec)
+ aspm.register_patch('megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding.__init__',
+ rotary_embedding_init_wrapper)
+ from .core.models.common.embeddings.language_model_embedding import language_model_embedding_forward_wrapper
+ aspm.register_patch('megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
+ language_model_embedding_forward_wrapper)
+ from .core.models.common.embeddings.rotary_pos_embedding import rotary_embedding_get_rotary_seq_len_wrapper
+ aspm.register_patch('megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding.get_rotary_seq_len',
+ rotary_embedding_get_rotary_seq_len_wrapper)
+ # Fix DDP scaling factor with Context Parallel
+ from .core.data_parallel.distributed_data_parallel import distributed_data_parallel_init_with_cp
+ aspm.register_patch('megatron.core.distributed.distributed_data_parallel.DistributedDataParallel.__init__',
+ distributed_data_parallel_init_with_cp)
+
+ if not mindspeed_args.automated_pipeline and mindspeed_args.noop_layers:
+ from .core.transformer.transformer_block import _build_layers
+ from .core.transformer.moe.moe_utils import track_moe_metrics
+ from megatron.core.transformer.transformer_block import TransformerBlock
+ from mindspeed.training import num_floating_point_wrapper
+ TransformerBlock._build_layers = _build_layers
+ aspm.register_patch('megatron.training.training.num_floating_point_operations', num_floating_point_wrapper)
+ aspm.register_patch('megatron.core.transformer.moe.moe_utils.track_moe_metrics', track_moe_metrics)
+
+
+ if mindspeed_args.recompute_norm:
+ from .core.models.gpt.gpt_layer_specs import build_norm_recompute_layer_wrapper
+ aspm.register_patch('megatron.core.transformer.transformer_block.TransformerBlock._build_layers', build_norm_recompute_layer_wrapper)
+
+ if getattr(mindspeed_args, 'reset_attention_mask', False):
+ from .core.datasets.gpt_dataset import _get_ltor_masks_and_position_ids, collate_wrapper
+ from .utils import get_batch_on_this_cp_rank_wrapper
+ aspm.register_patch('megatron.core.datasets.gpt_dataset._get_ltor_masks_and_position_ids', _get_ltor_masks_and_position_ids)
+ aspm.register_patch('torch.utils.data._utils.collate.default_collate', collate_wrapper)
+ aspm.register_patch('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank_wrapper)
+
+ from mindspeed.core.pipeline_parallel.p2p_communication import _p2p_ops_eod
+ aspm.register_patch('megatron.core.pipeline_parallel.p2p_communication._p2p_ops', _p2p_ops_eod)
+ from mindspeed.core.models.gpt.gpt_model import gpt_forward_wrapper
+ aspm.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_forward_wrapper)
+ from .core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb_thd
+ aspm.register_patch('megatron.core.models.common.embeddings.rotary_pos_embedding.apply_rotary_pos_emb_thd', apply_rotary_pos_emb_thd)
+ from .core.transformer.attention import attention_forward
+ aspm.register_patch('megatron.core.transformer.attention.Attention.forward', attention_forward)
+
+ from .core.models.common.embeddings.rotary_pos_embedding import rotary_forward
+ aspm.register_patch('megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding.forward', rotary_forward)
+
+
+def mcore_transformer_adaptation_l0(aspm):
+ import megatron.core
+ from .core.transformer.custom_layers.transformer_engine import PTNorm
+ from .core.transformer.dot_product_attention import dot_product_attention_forward_wrapper, \
+ dot_product_attention_init
+ megatron.core.transformer.transformer_block.LayerNormImpl = PTNorm
+ aspm.register_patch('megatron.core.transformer.custom_layers.transformer_engine.TENorm', PTNorm)
+ # Add cp parameters to dot_deduct_mattention init, and add fusion attention support for alibi in non cp situations
+ aspm.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention.__init__',
+ dot_product_attention_init)
+ aspm.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention.forward',
+ dot_product_attention_forward_wrapper)
+
+
+def mcore_transformer_adaptation(aspm, args):
+ from .core.transformer.module import megatron_module_init_wrapper
+ from .core.transformer.attention import (attention_init, SelfAttentionSubmodules,
+ self_attention_init_wrapper, attention_forward_wrapper)
+ from .core.transformer.transformer_block import transformer_block_checkpointed_forward_wrapper
+ from .core.transformer.transformer import parallel_transformer_layer_init_wrapper
+ from .core.transformer.transformer import core_mlp_forward_wrapper
+ from .core.transformer.mlp import mlp_init_2d_wrapper
+ from .core.transformer.transformer_block import transformer_block_forward_wrapper
+ aspm.register_patch('megatron.core.transformer.attention.SelfAttentionSubmodules', SelfAttentionSubmodules)
+ aspm.register_patch('megatron.core.transformer.attention.SelfAttention.__init__', self_attention_init_wrapper)
+ aspm.register_patch("megatron.core.transformer.attention.Attention.forward", attention_forward_wrapper)
+ aspm.register_patch('megatron.core.transformer.attention.Attention.__init__', attention_init)
+ aspm.register_patch('megatron.core.transformer.module.MegatronModule.__init__', megatron_module_init_wrapper)
+ aspm.register_patch('megatron.core.transformer.transformer_block.TransformerBlock._checkpointed_forward',
+ transformer_block_checkpointed_forward_wrapper)
+ aspm.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer.__init__',
+ parallel_transformer_layer_init_wrapper)
+ aspm.register_patch('megatron.core.transformer.mlp.MLP.forward',
+ core_mlp_forward_wrapper)
+ aspm.register_patch('megatron.core.transformer.mlp.MLP.__init__', mlp_init_2d_wrapper)
+ aspm.register_patch('megatron.core.transformer.transformer_block.TransformerBlock.forward',
+ transformer_block_forward_wrapper)
+ if hasattr(args, "multi_head_latent_attention") and args.multi_head_latent_attention:
+ from mindspeed.core.transformer.attention import self_attention_init_mla_wrapper
+ aspm.register_patch('megatron.core.transformer.attention.SelfAttention.__init__', self_attention_init_mla_wrapper)
+
+
+def mcore_parallel_state_adaptation(aspm):
+ from .core.parallel_state import initialize_model_parallel_wrapper
+ from .core.parallel_state import destroy_model_parallel_wrapper
+ from .core.memory.auto_pipeline.autopipeline_solver import destroy_model_parallel_profiling_wrapper
+ from .core.parallel_state import get_context_parallel_group_for_send_recv_overlap
+ aspm.register_patch('megatron.core.parallel_state.initialize_model_parallel',
+ initialize_model_parallel_wrapper)
+ aspm.register_patch('megatron.core.parallel_state.destroy_model_parallel',
+ destroy_model_parallel_wrapper)
+ aspm.register_patch('megatron.core.parallel_state.destroy_model_parallel',
+ destroy_model_parallel_profiling_wrapper)
+ aspm.register_patch('megatron.core.parallel_state.get_context_parallel_group_for_send_recv_overlap',
+ get_context_parallel_group_for_send_recv_overlap)
+
+
+def mcore_fusions_adaptation(aspm, args):
+ from .core.fusions.fused_bias_swiglu import SwiGLUFunction, BiasSwiGLUFunction
+ from .core.fusions.fused_layer_norm import FusedLayerNormAffineFunction, FastLayerNormFN
+ from .core.fusions.fused_softmax import is_kernel_available, ScaledUpperTriangMaskedSoftmax, ScaledMaskedSoftmax, \
+ ScaledSoftmax, forward_fused_softmax
+ from .core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb_bshd
+ aspm.register_patch('megatron.core.fusions.fused_layer_norm.FusedLayerNormAffineFunction',
+ FusedLayerNormAffineFunction)
+ aspm.register_patch('megatron.core.fusions.fused_layer_norm.FastLayerNormFN', FastLayerNormFN)
+ aspm.register_patch('megatron.core.fusions.fused_softmax.ScaledUpperTriangMaskedSoftmax',
+ ScaledUpperTriangMaskedSoftmax)
+ aspm.register_patch('megatron.core.fusions.fused_softmax.ScaledMaskedSoftmax', ScaledMaskedSoftmax)
+ aspm.register_patch('megatron.core.fusions.fused_softmax.ScaledSoftmax', ScaledSoftmax)
+ aspm.register_patch('megatron.core.fusions.fused_softmax.FusedScaleMaskSoftmax.is_kernel_available',
+ is_kernel_available)
+ aspm.register_patch('megatron.core.fusions.fused_softmax.FusedScaleMaskSoftmax.forward_fused_softmax',
+ forward_fused_softmax)
+ aspm.register_patch('megatron.core.fusions.fused_bias_swiglu.SwiGLUFunction', SwiGLUFunction)
+ aspm.register_patch('megatron.core.fusions.fused_bias_swiglu.BiasSwiGLUFunction', BiasSwiGLUFunction)
+
+ aspm.register_patch('megatron.core.models.common.embeddings.rotary_pos_embedding.apply_rotary_pos_emb_bshd',
+ apply_rotary_pos_emb_bshd)
+ if hasattr(args, 'use_fused_moe_token_permute_and_unpermute') and args.use_fused_moe_token_permute_and_unpermute:
+ from .core.fusions.npu_moe_token_permute import permute_wrapper
+ from .core.fusions.npu_moe_token_unpermute import unpermute_wrapper
+ aspm.register_patch('megatron.core.transformer.moe.moe_utils.permute', permute_wrapper)
+ aspm.register_patch('megatron.core.transformer.moe.moe_utils.unpermute', unpermute_wrapper)
+ if args.npu_deterministic:
+ from mindspeed.initialize import deter_comp_wrapper
+ aspm.register_patch('megatron.training.initialize._set_random_seed', deter_comp_wrapper)
+
+
+def mcore_optimizer_adapation(aspm, mindspeed_args):
+ from .optimizer.distrib_optimizer import reuse_fp32_param_distrib_optimizer_init_wrapper
+ from .optimizer.optimizer import (step_with_ready_grads, prepare_grads,
+ reuse_fp32_param_init_wrapper, optimizer_config_init_wrapper)
+ from .core.distributed.param_and_grad_buffer import reuse_fp32_param_param_and_grad_buffer_init_wrapper
+ # optim relative.
+ aspm.register_patch('megatron.core.optimizer.optimizer.MixedPrecisionOptimizer.prepare_grads',
+ prepare_grads)
+ aspm.register_patch('megatron.core.optimizer.optimizer.MixedPrecisionOptimizer.step_with_ready_grads',
+ step_with_ready_grads)
+ aspm.register_patch('megatron.core.optimizer.optimizer.Float16OptimizerWithFloat16Params.__init__',
+ reuse_fp32_param_init_wrapper)
+ aspm.register_patch('megatron.core.optimizer.optimizer_config.OptimizerConfig.__init__',
+ optimizer_config_init_wrapper)
+ aspm.register_patch('megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__',
+ reuse_fp32_param_distrib_optimizer_init_wrapper)
+ aspm.register_patch('megatron.core.distributed.ParamAndGradBuffer.__init__',
+ reuse_fp32_param_param_and_grad_buffer_init_wrapper)
+
+ if mindspeed_args.param_and_grad_buffer_pad:
+ from .core.distributed.param_and_grad_buffer import param_and_grad_buffer_init_pad
+ aspm.register_patch('megatron.core.distributed.ParamAndGradBuffer.__init__',
+ param_and_grad_buffer_init_pad)
+
+
+def mcore_pipeline_parallel_adaptation(aspm, mindspeed_args):
+ from .core.pipeline_parallel.schedules import get_tensor_shapes_wrapper, get_forward_backward_func_wrapper
+ from .core.performance.auto_pipeline_perf.schedules import get_forward_backward_func_decorator, \
+ backward_step_decorator, forward_step_decorator
+
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.get_forward_backward_func',
+ get_forward_backward_func_wrapper)
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.get_forward_backward_func',
+ get_forward_backward_func_decorator)
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.backward_step',
+ backward_step_decorator)
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.forward_step',
+ forward_step_decorator)
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.get_tensor_shapes',
+ get_tensor_shapes_wrapper)
+ if mindspeed_args.optimize_vpp_send_recv_comm:
+ from .core.pipeline_parallel.p2p_communication import _p2p_ops_send_recv_overlap
+ aspm.register_patch('megatron.core.pipeline_parallel.p2p_communication._p2p_ops',
+ _p2p_ops_send_recv_overlap)
+ if mindspeed_args.variable_seq_lengths:
+ from .core.pipeline_parallel.p2p_communication import _communicate_shapes, _communicate
+ aspm.register_patch('megatron.core.pipeline_parallel.p2p_communication._communicate',
+ _communicate)
+ aspm.register_patch('megatron.core.pipeline_parallel.p2p_communication._communicate_shapes',
+ _communicate_shapes)
+
+
+def mcore_multiparam_pipeline_parallel_adaptation(aspm, mindspeed_args):
+ if mindspeed_args.use_multiparameter_pipeline_model_parallel:
+ from .core.pipeline_parallel.multiparameter_schedules import get_tensor_shapes_wrapper, forward_step_wrapper, \
+ recv_forward_wrapper, recv_backward_wrapper, send_forward_wrapper, send_backward_wrapper, \
+ send_forward_recv_backward_wrapper, send_backward_recv_forward_wrapper, backward_step_wrapper
+
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.get_tensor_shapes',
+ get_tensor_shapes_wrapper)
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.forward_step',
+ forward_step_wrapper)
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.backward_step',
+ backward_step_wrapper)
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.recv_forward',
+ recv_forward_wrapper)
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.recv_backward',
+ recv_backward_wrapper)
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.send_forward',
+ send_forward_wrapper)
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.send_backward',
+ send_backward_wrapper)
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.send_forward_recv_backward',
+ send_forward_recv_backward_wrapper)
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.send_backward_recv_forward',
+ send_backward_recv_forward_wrapper)
+
+
+def mcore_tensor_parallel_adaptation_l0(aspm):
+ from .core.tensor_parallel.random import _set_cuda_rng_state
+ aspm.register_patch('megatron.core.tensor_parallel.random._set_cuda_rng_state', _set_cuda_rng_state)
+
+
+def mcore_tensor_parallel_adaptation_l1(aspm):
+ from .core.tensor_parallel.cross_entropy import calculate_predicted_logits
+ # use logical negation followed by multiplication to achieve the same effect as setting selected elements to zero
+ aspm.register_patch('megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits',
+ calculate_predicted_logits)
+
+
+def mcore_tensor_parallel_adaptation(aspm, args):
+ from .core.tensor_parallel.random import checkpoint_wrapper
+ from .core.tensor_parallel.random import checkpoint_function_backward
+ from .core.tensor_parallel.layers import vocab_parallel_embedding_forward
+ from .core.tensor_parallel.layers import row_parallel_nocomm_optimizer_wrapper
+ from .core.tensor_parallel.layers import parallel_linear_init_wrapper
+
+ def has_recomputation_or_swap(args):
+ return (args.swap_attention or
+ args.recompute_in_bubble or
+ args.adaptive_recompute_device_swap or
+ args.recompute_in_advance or
+ args.adaptive_memory_optimization)
+
+ aspm.register_patch('megatron.core.tensor_parallel.random.CheckpointFunction.backward',
+ checkpoint_function_backward)
+ aspm.register_patch('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
+ vocab_parallel_embedding_forward)
+ aspm.register_patch('megatron.core.tensor_parallel.layers.RowParallelLinear.forward',
+ row_parallel_nocomm_optimizer_wrapper)
+ aspm.register_patch('megatron.core.tensor_parallel.layers.RowParallelLinear.__init__',
+ parallel_linear_init_wrapper)
+ aspm.register_patch('megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__',
+ parallel_linear_init_wrapper)
+ aspm.register_patch('megatron.core.tensor_parallel.random.checkpoint', checkpoint_wrapper)
+ if has_recomputation_or_swap(args):
+ from .core.tensor_parallel.layers import linear_forward_main_grad_wrapper, linear_backward_main_grad_wrapper
+ aspm.register_patch('megatron.core.tensor_parallel.layers.LinearWithGradAccumulationAndAsyncCommunication.forward',
+ linear_forward_main_grad_wrapper)
+ aspm.register_patch('megatron.core.tensor_parallel.layers.LinearWithGradAccumulationAndAsyncCommunication.backward',
+ linear_backward_main_grad_wrapper)
+
+
+def megatron_legacy_adaptation(aspm):
+ from .model.language_model import parallel_lm_logits, embedding_forward_wrapper
+ from .core.performance.auto_pipeline_perf.data_samplers import build_pretraining_data_loader_decorator
+ from .core.performance.auto_pipeline_perf.transformer import get_attention_mask_wrapper
+ aspm.register_patch('mindspeed.model.transformer.get_attention_mask', get_attention_mask_wrapper)
+ aspm.register_patch('megatron.legacy.data.data_samplers.build_pretraining_data_loader',
+ build_pretraining_data_loader_decorator)
+ aspm.register_patch('megatron.legacy.model.language_model.parallel_lm_logits', parallel_lm_logits)
+ aspm.register_patch('megatron.legacy.model.language_model.Embedding.forward', embedding_forward_wrapper)
+
+
+def legacy_model_fusions_adaptation(aspm):
+ from .core.fusions.fused_layer_norm import FusedLayerNormAffineFunction, FastLayerNormFN, fused_layer_norm_affine
+ from .core.fusions.fused_softmax import is_kernel_available, ScaledUpperTriangMaskedSoftmax, ScaledMaskedSoftmax, \
+ ScaledSoftmax, forward_fused_softmax
+ aspm.register_patch('megatron.legacy.model.fused_layer_norm.FusedLayerNormAffineFunction',
+ FusedLayerNormAffineFunction)
+ aspm.register_patch('megatron.legacy.model.fused_layer_norm.FastLayerNormFN', FastLayerNormFN)
+ aspm.register_patch('megatron.legacy.model.fused_layer_norm.fused_layer_norm_affine', fused_layer_norm_affine)
+ aspm.register_patch('megatron.legacy.model.fused_softmax.ScaledUpperTriangMaskedSoftmax',
+ ScaledUpperTriangMaskedSoftmax)
+ aspm.register_patch('megatron.legacy.model.fused_softmax.ScaledMaskedSoftmax', ScaledMaskedSoftmax)
+ aspm.register_patch('megatron.legacy.model.fused_softmax.ScaledSoftmax', ScaledSoftmax)
+ aspm.register_patch('megatron.legacy.model.fused_softmax.FusedScaleMaskSoftmax.is_kernel_available',
+ is_kernel_available)
+ aspm.register_patch('megatron.legacy.model.fused_softmax.FusedScaleMaskSoftmax.forward_fused_softmax',
+ forward_fused_softmax)
+
+
+def legacy_model_rms_norm_adaptation(aspm):
+ from .core.fusions.rms_norm import rms_norm_init_wrapper, rms_norm_forward_wrapper, rms_norm_norm_wrapper
+ aspm.register_patch('megatron.legacy.model.rms_norm.RMSNorm.__init__', rms_norm_init_wrapper)
+ aspm.register_patch('megatron.legacy.model.rms_norm.RMSNorm.forward', rms_norm_forward_wrapper)
+ aspm.register_patch('megatron.legacy.model.rms_norm.RMSNorm._norm', rms_norm_norm_wrapper)
+
+
+def legacy_model_transformer_l0(aspm):
+ from .model.transformer import parallel_mlp_init_wrapper, flash_self_attention_forward, \
+ flash_self_attention_init_wrapper, parallel_transformer_forward_wrapper, flash_self_attention_init_add_config_wrapper
+ from .model.transformer import parallel_attention_init, parallel_attention_forward
+ aspm.register_patch('megatron.legacy.model.transformer.ParallelTransformer.forward',
+ parallel_transformer_forward_wrapper)
+ aspm.register_patch('megatron.legacy.model.transformer.ParallelMLP.__init__', parallel_mlp_init_wrapper)
+ aspm.register_patch('megatron.legacy.model.transformer.FlashSelfAttention.forward', flash_self_attention_forward)
+ aspm.register_patch('megatron.legacy.model.transformer.FlashSelfAttention.__init__',
+ flash_self_attention_init_add_config_wrapper)
+ aspm.register_patch('megatron.legacy.model.transformer.FlashSelfAttention.__init__',
+ flash_self_attention_init_wrapper)
+ aspm.register_patch('megatron.legacy.model.transformer.ParallelAttention.__init__', parallel_attention_init)
+ aspm.register_patch('megatron.legacy.model.transformer.ParallelAttention.forward',
+ parallel_attention_forward)
+
+
+def legacy_model_transformer(aspm, args):
+ from .model.transformer import parallel_mlp_forward, parallel_transformer_init_wrapper, \
+ parallel_transformer_init
+ from .model.transformer import core_attention_init_wrapper, core_attention_forward
+ from .core.transformer.transformer import parallel_transformer_checkpointed_forward_wrapper
+ from .model.transformer import switch_mlp_init_wrapper, switch_mlp_forward_wrapper, \
+ parallel_transformer_layer_init_wrapper
+ if not args.automated_pipeline and args.noop_layers:
+ aspm.register_patch('megatron.legacy.model.transformer.ParallelTransformer.__init__', parallel_transformer_init)
+ aspm.register_patch('megatron.legacy.model.transformer.ParallelTransformer.__init__',
+ parallel_transformer_init_wrapper)
+ aspm.register_patch('megatron.legacy.model.transformer.ParallelMLP.forward', parallel_mlp_forward)
+ aspm.register_patch('megatron.legacy.model.transformer.CoreAttention.__init__', core_attention_init_wrapper)
+ aspm.register_patch('megatron.legacy.model.transformer.CoreAttention.forward', core_attention_forward)
+ aspm.register_patch('megatron.legacy.model.transformer.ParallelTransformer._checkpointed_forward',
+ parallel_transformer_checkpointed_forward_wrapper)
+ aspm.register_patch('megatron.legacy.model.transformer.SwitchMLP.__init__', switch_mlp_init_wrapper)
+ aspm.register_patch('megatron.legacy.model.transformer.SwitchMLP.forward', switch_mlp_forward_wrapper)
+ aspm.register_patch('megatron.legacy.model.transformer.ParallelTransformerLayer.__init__',
+ parallel_transformer_layer_init_wrapper)
+
+
+def megatron_training_adaptation_l0(aspm):
+ from .initialize import _compile_dependencies, set_jit_fusion_options_wrapper
+ from .utils import get_batch_on_this_cp_rank
+ from .training import pretrain, get_device_wrapper
+ from .arguments import parse_args_wrapper, validate_args_wrapper, core_transformer_config_from_args_wrapper
+ from .yaml_arguments import core_transformer_config_from_yaml_wrapper, print_args_wrapper
+
+ from .core.training import train_decorator, train_step_decorator
+ from .core.transformer.transformer_config import transformer_config_post_init_wrapper
+ aspm.register_patch('megatron.training.training.train', train_decorator)
+ aspm.register_patch('megatron.training.training.train_step', train_step_decorator)
+ aspm.register_patch('megatron.training.yaml_arguments.core_transformer_config_from_yaml',
+ core_transformer_config_from_yaml_wrapper)
+ aspm.register_patch('megatron.training.initialize._compile_dependencies', _compile_dependencies)
+ aspm.register_patch('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank)
+ aspm.register_patch('megatron.training.arguments.parse_args', parse_args_wrapper)
+ aspm.register_patch('megatron.training.arguments.validate_args', validate_args_wrapper)
+ aspm.register_patch('megatron.training.arguments._print_args', print_args_wrapper)
+ aspm.register_patch('megatron.training.yaml_arguments.validate_yaml', validate_args_wrapper)
+ aspm.register_patch('megatron.training.yaml_arguments._print_args', print_args_wrapper)
+ aspm.register_patch('megatron.training.arguments.core_transformer_config_from_args',
+ core_transformer_config_from_args_wrapper)
+ aspm.register_patch('megatron.training.initialize.set_jit_fusion_options', set_jit_fusion_options_wrapper)
+ aspm.register_patch('megatron.training.training.pretrain', pretrain)
+ aspm.register_patch('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__',
+ transformer_config_post_init_wrapper)
+ aspm.register_patch('megatron.training.dist_signal_handler.get_device', get_device_wrapper)
+
+
+def megatron_training_adaptation(aspm, mindspeed_args):
+ from .core.performance.auto_pipeline_perf.global_vars import get_num_microbatches_wrapper
+ from .core.training import training_log
+ from .utils import get_batch_on_this_tp_rank
+ from .tokenizer import build_tokenizer_wrapper
+ from .core.training import pretrain_decorator, setup_model_and_optimizer_decorator
+ aspm.register_patch('megatron.core.num_microbatches_calculator.get_num_microbatches', get_num_microbatches_wrapper)
+ aspm.register_patch('megatron.training.training.pretrain', pretrain_decorator)
+ aspm.register_patch('megatron.training.training.setup_model_and_optimizer', setup_model_and_optimizer_decorator)
+ aspm.register_patch('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
+ if mindspeed_args.op_cal_tflops:
+ aspm.register_patch('megatron.training.training.training_log', training_log)
+ aspm.register_patch('megatron.training.tokenizer.tokenizer.build_tokenizer', build_tokenizer_wrapper)
+
+
+def megatron_training_ema_adaptation(aspm, mindspeed_args):
+ if mindspeed_args.optimizer_selection == 'fused_ema_adamw':
+ from .checkpointing import generate_state_dict_ema_wrapper, save_checkpoint_ema_wrapper
+ from .optimizer.distrib_optimizer import ema_distrib_optimizer_init_wrapper
+ aspm.register_patch('megatron.training.checkpointing.save_checkpoint', save_checkpoint_ema_wrapper)
+ aspm.register_patch('megatron.training.checkpointing.generate_state_dict', generate_state_dict_ema_wrapper)
+ aspm.register_patch('megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__',
+ ema_distrib_optimizer_init_wrapper)
+ if hasattr(mindspeed_args, "ema_decay"):
+ from .optimizer.optimizer import get_megatron_optimizer_func_wrapper
+ aspm.register_patch('megatron.core.optimizer.get_megatron_optimizer',
+ get_megatron_optimizer_func_wrapper)
+ elif mindspeed_args.use_ema:
+ from .training import pretrain, train_step
+ from .checkpointing import save_checkpoint, _load_base_checkpoint
+ aspm.register_patch('megatron.training.training.train_step', train_step)
+ aspm.register_patch('megatron.training.checkpointing.save_checkpoint', save_checkpoint)
+ aspm.register_patch('megatron.training.checkpointing._load_base_checkpoint', _load_base_checkpoint)
+
+
+def memory_fragmentation_adaptation(aspm, args):
+ from megatron.legacy.model.transformer import ParallelTransformerLayer
+ if args.memory_fragmentation:
+ from .core.memory.memory_fragmentation.pluggable_allocator_adpator import change_allocator
+ time.sleep(10)
+ change_allocator()
+
+ from .core.memory.memory_fragmentation.memory_recorder import memory_recorder_wrapper
+ aspm.register_patch('megatron.training.training.setup_model_and_optimizer', memory_recorder_wrapper)
+
+ from .core.memory.memory_fragmentation.malloc_recorder import malloc_recorder_wrapper
+ aspm.register_patch('megatron.training.training.train_step', malloc_recorder_wrapper)
+
+ from .core.memory.memory_fragmentation.optimizer_init_precise import optimizer_init_wrapper
+ aspm.register_patch('megatron.core.optimizer.optimizer.MixedPrecisionOptimizer.step', optimizer_init_wrapper)
+
+ from .core.memory.adaptive_recomputing.adaptive_recompute import allowed_recomputing_module_wrapper
+ allowed_recomputing_module_wrapper(ParallelTransformerLayer)
+ from .core.memory.adaptive_recomputing.adaptive_recompute import setup_model_and_optimizer_wrapper
+ aspm.register_patch('megatron.training.training.setup_model_and_optimizer', setup_model_and_optimizer_wrapper)
+ if (args.adaptive_recompute_enable and not args.memory_fragmentation) or args.swap_attention:
+ from .core.memory.adaptive_recomputing.adaptive_recompute import allowed_recomputing_module_wrapper
+ if hasattr(args, "use_legacy_models") and not args.use_legacy_models:
+ from megatron.core.transformer.transformer_layer import TransformerLayer
+ allowed_recomputing_module_wrapper(TransformerLayer)
+ else:
+ allowed_recomputing_module_wrapper(ParallelTransformerLayer)
+ from .core.memory.adaptive_recomputing.adaptive_recompute import setup_model_and_optimizer_wrapper
+ aspm.register_patch('megatron.training.training.setup_model_and_optimizer', setup_model_and_optimizer_wrapper)
+ if args.smart_swap and (not args.memory_fragmentation and not args.adaptive_recompute_enable):
+ from .core.memory.smart_swap.swap_adaptor import change_allocator
+ time.sleep(10)
+ change_allocator()
+ from .core.memory.smart_swap.swap_megatron_adaptor import train_step_wrapper
+ aspm.register_patch('megatron.training.training.train_step', train_step_wrapper)
+ if args.adaptive_memory_optimization and not (args.adaptive_recompute_enable or args.memory_fragmentation or args.swap_attention or args.smart_swap):
+ from .core.memory.adaptive_memory.adaptive_memory_opt import addup_allowed_mem_adapt_module
+ if hasattr(args, "use_legacy_models") and args.use_legacy_models:
+ addup_allowed_mem_adapt_module(ParallelTransformerLayer)
+ else:
+ from megatron.core.transformer.transformer_layer import TransformerLayer
+ addup_allowed_mem_adapt_module(TransformerLayer)
+ from .core.memory.adaptive_memory.adaptive_memory_opt import setup_adapt_memory_optimizer_wrapper
+ aspm.register_patch('megatron.training.training.setup_model_and_optimizer', setup_adapt_memory_optimizer_wrapper)
+ from .core.memory.adaptive_recomputing.pluggable_allocator_adpator import change_allocator
+ time.sleep(10)
+ change_allocator()
+
+ if os.getenv('OOTB_OPTIMIZER_PROFILING', 'FALSE') == 'TRUE':
+ print(f"OOTB_OPTIMIZER_PROFILING success open")
+ from .core.memory.adaptive_recomputing.pluggable_allocator_adpator import change_allocator
+ import megatron.training
+ from mindspeed.auto_tuning.module.parse.recompute_parser import allowed_recompute_parser_module_wrapper
+ allowed_recompute_parser_module_wrapper(megatron.legacy.model.transformer.ParallelTransformerLayer)
+ from mindspeed.auto_tuning.module.parse.recompute_parser import setup_model_and_optimizer_decorator
+ aspm.register_patch('megatron.training.training.setup_model_and_optimizer', setup_model_and_optimizer_decorator)
+ print(f"setup_model_and_optimizer_decorator success")
+
+ if args.adaptive_recompute_enable or args.memory_fragmentation:
+ import megatron.training.initialize
+ aspm.register_patch('megatron.training.initialize_megatron', megatron.training.initialize.initialize_megatron)
+
+
+def mcore_moe_adaptation_l0(pm):
+ from .core.transformer.moe.grouped_gemm_util import Ops, grouped_gemm_is_available, get_device_capability, \
+ assert_grouped_gemm_is_available
+ pm.register_patch('megatron.core.transformer.moe.grouped_gemm_util.ops', Ops)
+ pm.register_patch('megatron.core.transformer.moe.grouped_gemm_util.grouped_gemm_is_available',
+ grouped_gemm_is_available)
+ pm.register_patch('megatron.core.transformer.moe.grouped_gemm_util.assert_grouped_gemm_is_available',
+ assert_grouped_gemm_is_available)
+ pm.register_patch('torch.cuda.get_device_capability', get_device_capability)
+
+
+def mcore_moe_adaptation(pm, args):
+ from .core.pipeline_parallel.schedules import forward_step
+ pm.register_patch('megatron.core.pipeline_parallel.schedules.forward_step',
+ forward_step)
+ if args.moe_permutation_async_comm:
+ if hasattr(args, 'moe_token_dispatcher_type') and args.moe_token_dispatcher_type == 'alltoall':
+ from .core.transformer.moe.experts import sequential_mlp_forward
+ from .core.transformer.moe.moe_utils import permute, unpermute
+ if args.moe_tp_extend_ep:
+ from .core.transformer.moe.token_dispatcher import (
+ preprocess_tp_extend_ep, alltoall_token_unpermutation_tp_extend_ep,
+ alltoall_token_permutation_tp_extend_ep
+ )
+ from .core.transformer.moe.router import routing_tp_extend_ep
+ from .core.transformer.moe.moe_layer import base_moe_init_wrapper
+ pm.register_patch('megatron.core.transformer.moe.moe_layer.BaseMoELayer.__init__',
+ base_moe_init_wrapper)
+ pm.register_patch(
+ 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.preprocess',
+ preprocess_tp_extend_ep)
+ pm.register_patch('megatron.core.transformer.moe.router.TopKRouter.routing', routing_tp_extend_ep)
+
+ if args.moe_alltoall_overlap_comm:
+ from .core.transformer.moe.token_dispatcher import alltoall_token_permutation_new, \
+ alltoall_token_unpermutation_new
+ from .core.transformer.moe.experts import group_mlp_forward
+ from .core.transformer.mlp import mlp_init
+ from .core.transformer.moe.moe_layer import moe_layer_init
+ pm.register_patch('megatron.core.transformer.mlp.MLP.__init__', mlp_init)
+ pm.register_patch('megatron.core.transformer.moe.experts.GroupedMLP.forward', group_mlp_forward)
+ pm.register_patch(
+ 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_permutation',
+ alltoall_token_permutation_new)
+ pm.register_patch(
+ 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_unpermutation',
+ alltoall_token_unpermutation_new)
+ pm.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.__init__', moe_layer_init)
+ else:
+ pm.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_permutation',
+ alltoall_token_permutation_tp_extend_ep)
+ pm.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_unpermutation',
+ alltoall_token_unpermutation_tp_extend_ep)
+ else:
+ from .core.transformer.moe.token_dispatcher import preprocess, alltoall_token_permutation, \
+ alltoall_token_unpermutation_with_bmm
+ pm.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.preprocess',
+ preprocess)
+ if args.moe_alltoall_overlap_comm:
+ from .core.transformer.moe.token_dispatcher import alltoall_token_permutation_new, \
+ alltoall_token_unpermutation_new
+ from .core.transformer.moe.experts import group_mlp_forward
+ from .core.transformer.mlp import mlp_init
+ from .core.transformer.moe.moe_layer import moe_layer_init
+ pm.register_patch('megatron.core.transformer.mlp.MLP.__init__', mlp_init)
+ pm.register_patch('megatron.core.transformer.moe.experts.GroupedMLP.forward', group_mlp_forward)
+ pm.register_patch(
+ 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_permutation',
+ alltoall_token_permutation_new)
+ pm.register_patch(
+ 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_unpermutation',
+ alltoall_token_unpermutation_new)
+ pm.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.__init__', moe_layer_init)
+ else:
+ pm.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_permutation',
+ alltoall_token_permutation)
+ if args.moe_bmm_mc2:
+ pm.register_patch(
+ 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_unpermutation',
+ alltoall_token_unpermutation_with_bmm)
+ pm.register_patch('megatron.core.transformer.moe.experts.SequentialMLP.forward', sequential_mlp_forward)
+ pm.register_patch('megatron.core.transformer.moe.moe_utils.permute', permute)
+ pm.register_patch('megatron.core.transformer.moe.moe_utils.unpermute', unpermute)
+ else:
+ from .core.transformer.moe.router import aux_loss_load_balancing
+ pm.register_patch('megatron.core.transformer.moe.router.TopKRouter.aux_loss_load_balancing', aux_loss_load_balancing)
+
+ if args.moe_tp_extend_ep:
+ from .core.transformer.moe.moe_layer import base_moe_init_wrapper
+ pm.register_patch('megatron.core.transformer.moe.moe_layer.BaseMoELayer.__init__', base_moe_init_wrapper)
+
+ if args.moe_allgather_overlap_comm:
+ from .core.transformer.moe.token_dispatcher import (allgather_token_permutation_new,
+ allgather_token_unpermutation_new)
+ from .core.transformer.moe.experts import group_mlp_forward
+ from .core.transformer.mlp import mlp_init
+ pm.register_patch('megatron.core.transformer.mlp.MLP.__init__', mlp_init)
+ pm.register_patch('megatron.core.transformer.moe.experts.GroupedMLP.forward', group_mlp_forward)
+ pm.register_patch(
+ 'megatron.core.transformer.moe.token_dispatcher.MoEAllGatherTokenDispatcher.token_permutation',
+ allgather_token_permutation_new)
+ pm.register_patch(
+ 'megatron.core.transformer.moe.token_dispatcher.MoEAllGatherTokenDispatcher.token_unpermutation',
+ allgather_token_unpermutation_new)
+ else:
+ from .core.transformer.moe.token_dispatcher import (allgather_token_permutation,
+ allgather_token_unpermutation)
+ pm.register_patch(
+ 'megatron.core.transformer.moe.token_dispatcher.MoEAllGatherTokenDispatcher.token_permutation',
+ allgather_token_permutation)
+ pm.register_patch(
+ 'megatron.core.transformer.moe.token_dispatcher.MoEAllGatherTokenDispatcher.token_unpermutation',
+ allgather_token_unpermutation)
+
+ from .core.transformer.moe.moe_layer import moe_layer_init_wrapper
+ pm.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.__init__', moe_layer_init_wrapper)
+ else:
+ if hasattr(args, 'moe_token_dispatcher_type') and args.moe_token_dispatcher_type == 'alltoall':
+ from .core.transformer.moe.token_dispatcher import alltoall_preprocess_npu, \
+ alltoall_token_unpermutation_with_bmm, alltoall_token_permutation_with_bmm
+ pm.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.preprocess',
+ alltoall_preprocess_npu)
+ if args.moe_bmm_mc2:
+ pm.register_patch(
+ 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_permutation',
+ alltoall_token_permutation_with_bmm)
+ pm.register_patch(
+ 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_unpermutation',
+ alltoall_token_unpermutation_with_bmm)
+ else:
+ from .core.transformer.moe.token_dispatcher import allgather_token_permutation_npu
+ pm.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAllGatherTokenDispatcher.token_permutation', allgather_token_permutation_npu)
+
+ from .core.transformer.moe.experts import groupedmlp_init_wrapper, groupedmlp_forward
+ pm.register_patch('megatron.core.transformer.moe.experts.GroupedMLP.__init__', groupedmlp_init_wrapper)
+ if not args.moe_alltoall_overlap_comm and not args.moe_allgather_overlap_comm:
+ pm.register_patch('megatron.core.transformer.moe.experts.GroupedMLP.forward', groupedmlp_forward)
+
+ if args.use_ascend_mc2 and not hasattr(args, 'moe_grouped_gemm'):
+ # MoE MLP not use mc2 linear
+ from .core.models.gpt.gpt_layer_specs import build_layers_wrapper
+ from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
+ from megatron.core.transformer.transformer_block import TransformerBlock
+ TransformerBlock._build_layers = build_layers_wrapper(TransformerBlock._build_layers,
+ ColumnParallelLinear.forward,
+ RowParallelLinear.forward)
+
+
+def deepspeed_moe_adaptation(pm, args):
+ if args.use_pipe_experts or args.use_nanopipe or args.ampipe_degree > 1:
+ from .core.tensor_parallel.layers import (row_parallel_moe, column_parallel_moe,
+ linear_with_grad_accumulation_and_async_allreduce_moe)
+ pm.register_patch('megatron.core.tensor_parallel.layers.RowParallelLinear.forward', row_parallel_moe)
+ pm.register_patch('megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward', column_parallel_moe)
+ pm.register_patch('megatron.core.tensor_parallel.layers.linear_with_grad_accumulation_and_async_allreduce',
+ linear_with_grad_accumulation_and_async_allreduce_moe)
+ if args.use_pipe_experts:
+ from .core.distributed.param_and_grad_buffer import pipe_register_grad_ready
+ pm.register_patch('megatron.core.distributed.ParamAndGradBuffer.register_grad_ready', pipe_register_grad_ready)
+ if args.ampipe_degree > 1:
+ from mindspeed.model.language_model import embedding_forward_ampipe
+ from mindspeed.model.transformer import parallel_transformer_forward_ampipe
+ from mindspeed.model.transformer import parallel_transformer_layer_forward_ampipe
+ pm.register_patch('megatron.legacy.model.language_model.Embedding.forward', embedding_forward_ampipe)
+ pm.register_patch('megatron.legacy.model.transformer.ParallelTransformer.forward',
+ parallel_transformer_forward_ampipe)
+ pm.register_patch('megatron.legacy.model.transformer.ParallelTransformerLayer.forward',
+ parallel_transformer_layer_forward_ampipe)
+
+
+def coc_adaptation(aspm, args):
+ from .initialize import coc_registration_wrapper, mc2_wrapper
+ if args.use_ascend_mc2:
+ from .core.memory.auto_pipeline.autopipeline import initialize_cfg_from_args_wrapper
+ aspm.register_patch('megatron.training.initialize.initialize_megatron', mc2_wrapper)
+ aspm.register_patch('mindspeed.core.tensor_parallel.ascend_turbo.initialize.initialize_cfg_from_args',
+ initialize_cfg_from_args_wrapper)
+ if args.use_ascend_coc:
+ aspm.register_patch('megatron.training.initialize.initialize_megatron', coc_registration_wrapper)
+
+
+def zero3_adaptation(aspm, args):
+ if args.enable_zero3:
+ from .core.data_parallel.distributed_data_parallel import distributed_data_parallel_init_zero3, \
+ distributed_data_parallel_zero_grad_wrapper
+ from .core.tensor_parallel.layers import (parallel_linear_init_zero3_wrapper,
+ column_parallel_linear_forward_zero3,
+ linear_forward_zero3_wrapper, linear_backward_zero3_wrapper,
+ row_parallel_linear_forward_zero3,
+ linear_with_grad_accumulation_and_async_allreduce_zero3)
+ from .optimizer.distrib_optimizer import (build_optimizer_group_ranges_zero3_wrapper,
+ _copy_main_params_to_model_params_zero3,
+ _copy_model_grads_to_main_grads_zero3,
+ build_model_and_main_param_groups_zero3_wrapper,
+ distributed_optimizer_zero3_init)
+ aspm.register_patch('megatron.core.tensor_parallel.layers.linear_with_grad_accumulation_and_async_allreduce',
+ linear_with_grad_accumulation_and_async_allreduce_zero3)
+ aspm.register_patch('megatron.core.tensor_parallel.layers.RowParallelLinear.__init__',
+ parallel_linear_init_zero3_wrapper)
+ aspm.register_patch('megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__',
+ parallel_linear_init_zero3_wrapper)
+ aspm.register_patch('megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward',
+ column_parallel_linear_forward_zero3)
+ aspm.register_patch('megatron.core.tensor_parallel.layers.RowParallelLinear.forward',
+ row_parallel_linear_forward_zero3)
+ aspm.register_patch(
+ 'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer._build_optimizer_group_ranges',
+ build_optimizer_group_ranges_zero3_wrapper)
+ aspm.register_patch(
+ 'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer._copy_main_params_to_model_params',
+ _copy_main_params_to_model_params_zero3)
+ aspm.register_patch(
+ 'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer._copy_model_grads_to_main_grads',
+ _copy_model_grads_to_main_grads_zero3)
+ aspm.register_patch(
+ 'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer._build_model_and_main_param_groups',
+ build_model_and_main_param_groups_zero3_wrapper)
+ aspm.register_patch('megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__',
+ distributed_optimizer_zero3_init)
+ aspm.register_patch(
+ 'megatron.core.tensor_parallel.layers.LinearWithGradAccumulationAndAsyncCommunication.forward',
+ linear_forward_zero3_wrapper)
+ aspm.register_patch(
+ 'megatron.core.tensor_parallel.layers.LinearWithGradAccumulationAndAsyncCommunication.backward',
+ linear_backward_zero3_wrapper)
+ aspm.register_patch('megatron.core.distributed.distributed_data_parallel.DistributedDataParallel.__init__',
+ distributed_data_parallel_init_zero3)
+ aspm.register_patch(
+ 'megatron.core.distributed.distributed_data_parallel.DistributedDataParallel.zero_grad_buffer',
+ distributed_data_parallel_zero_grad_wrapper)
+
+
+def tensor_2d_adaptation(aspm, args):
+ if args.tp_2d:
+ from mindspeed.core.tensor_parallel.tp_2d.norm_factory import get_norm_tp_2d
+ from mindspeed.core.tensor_parallel.tp_2d.norm_factory import _allreduce_layernorm_grads_wrapper
+ from mindspeed.core.models.common.embeddings.rotary_pos_embedding import rotary_embedding_forward_wrapper
+ from mindspeed.core.pipeline_parallel.flexible_schedules import forward_backward_pipelining_with_interleaving_patch
+ aspm.register_patch('megatron.legacy.model.utils.get_norm', get_norm_tp_2d)
+ aspm.register_patch('megatron.core.distributed.finalize_model_grads._allreduce_layernorm_grads',
+ _allreduce_layernorm_grads_wrapper)
+ aspm.register_patch('megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding.forward',
+ rotary_embedding_forward_wrapper)
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
+ forward_backward_pipelining_with_interleaving_patch)
+ from .core.transformer.transformer_config import transformer_config_post_init
+ aspm.register_patch('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__',
+ transformer_config_post_init)
+ from mindspeed.model.language_model import model_parallel_config_post_init_wrapper
+ aspm.register_patch('megatron.core.model_parallel_config.ModelParallelConfig.__post_init__',
+ model_parallel_config_post_init_wrapper)
+ from mindspeed.core.models.gpt.gpt_layer_specs import get_mlp_module_spec_wrapper
+ aspm.register_patch('megatron.core.models.gpt.gpt_layer_specs._get_mlp_module_spec',
+ get_mlp_module_spec_wrapper)
+ from mindspeed.core.transformer.attention import self_attention_init_tp2d_wrapper
+ aspm.register_patch('megatron.core.transformer.attention.SelfAttention.__init__', self_attention_init_tp2d_wrapper)
+
+
+def megatron_training_adaptation_with_layerzero(aspm, mindspeed_args):
+ '''This function is used to add layerzero feature within mindspeed
+ layerzero manages the paramter in a different manner compared to Megatron Optimizer
+
+ So if layerzero is on, setup_model_and_optimizer will return a module wrapped by layerzero and the Optimizer will be replaced.
+ '''
+ if mindspeed_args.layerzero:
+ from mindspeed.core.distributed.layerzero import (layerzero_setup_model_and_optimizer_wrapper,
+ layerzero_initialize_model_parallel_wrapper,
+ mga_finalize_model_grads_wrapper,
+ save_checkpoint,
+ )
+ aspm.register_patch('megatron.training.training.setup_model_and_optimizer', layerzero_setup_model_and_optimizer_wrapper)
+ aspm.register_patch('megatron.core.parallel_state.initialize_model_parallel', layerzero_initialize_model_parallel_wrapper)
+ aspm.register_patch('megatron.core.distributed.finalize_model_grads', mga_finalize_model_grads_wrapper)
+ aspm.register_patch('megatron.training.checkpointing.save_checkpoint', save_checkpoint)
+
+
+def auto_parallel_mm_adaptation(aspm, mindspeed_args):
+ from mindspeed.core.auto_parallel.mm_search.schedules import backward_step_decorator
+ if mindspeed_args.auto_parallel_mm or mindspeed_args.auto_parallel_profile:
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.backward_step',
+ backward_step_decorator)
+
+
+def dist_train_adaptation(aspm, args):
+ if args.dist_train:
+ from mindspeed.multi_modal import dist_train
+ # pipeline parallel adaption
+ aspm.register_patch('megatron.core.pipeline_parallel.schedules.get_forward_backward_func', dist_train.pipeline_parallel.dist_schedules.get_forward_backward_func_wrapper)
+ aspm.register_patch('megatron.core.pipeline_parallel.p2p_communication._p2p_ops', dist_train.pipeline_parallel.dist_schedules.p2p_ops_wrapper)
+ # parallel state adaption
+ aspm.register_patch('megatron.training.initialize._initialize_distributed', dist_train.training.initialize_distributed_wrapper)
+ aspm.register_patch('megatron.core.mpu.initialize_model_parallel', dist_train.parallel_state.initialize_model_parallel)
+ aspm.register_patch('megatron.core.mpu.is_pipeline_last_stage', dist_train.parallel_state.get_is_pipeline_last_stage_wrapper)
+ aspm.register_patch('megatron.core.mpu.is_pipeline_first_stage', dist_train.parallel_state.get_is_pipeline_first_stage_wrapper)
+ aspm.register_patch('megatron.core.mpu.get_tensor_model_parallel_src_rank', dist_train.parallel_state.get_tensor_model_parallel_src_rank_wrapper)
+ aspm.register_patch('megatron.core.mpu.is_initialized', dist_train.parallel_state.is_initialized)
+ aspm.register_patch('megatron.core.mpu.model_parallel_is_initialized', dist_train.parallel_state.model_parallel_is_initialized)
+ aspm.register_patch('megatron.core.mpu.get_model_parallel_group', dist_train.parallel_state.get_model_parallel_group)
+ aspm.register_patch('megatron.core.mpu.get_tensor_model_parallel_group', dist_train.parallel_state.get_tensor_model_parallel_group)
+ aspm.register_patch('megatron.core.mpu.get_pipeline_model_parallel_group', dist_train.parallel_state.get_pipeline_model_parallel_group)
+ aspm.register_patch('megatron.core.mpu.get_data_parallel_group', dist_train.parallel_state.get_data_parallel_group)
+ aspm.register_patch('megatron.core.mpu.get_data_parallel_group_gloo', dist_train.parallel_state.get_data_parallel_group_gloo)
+ aspm.register_patch('megatron.core.mpu.get_context_parallel_group', dist_train.parallel_state.get_context_parallel_group)
+ aspm.register_patch('megatron.core.mpu.get_context_parallel_global_ranks', dist_train.parallel_state.get_context_parallel_global_ranks)
+ aspm.register_patch('megatron.core.mpu.get_embedding_group', dist_train.parallel_state.get_embedding_group)
+ aspm.register_patch('megatron.core.mpu.get_position_embedding_group', dist_train.parallel_state.get_position_embedding_group)
+ aspm.register_patch('megatron.core.mpu.get_data_modulo_expert_parallel_group_gloo', dist_train.parallel_state.get_data_modulo_expert_parallel_group_gloo)
+ aspm.register_patch('megatron.core.mpu.get_amax_reduction_group', dist_train.parallel_state.get_amax_reduction_group)
+ aspm.register_patch('megatron.core.mpu.get_tensor_and_data_parallel_group', dist_train.parallel_state.get_tensor_and_data_parallel_group)
+ aspm.register_patch('megatron.core.mpu.get_tensor_and_context_parallel_group', dist_train.parallel_state.get_tensor_and_context_parallel_group)
+ aspm.register_patch('megatron.core.mpu.get_expert_model_parallel_group', dist_train.parallel_state.get_expert_model_parallel_group)
+ aspm.register_patch('megatron.core.mpu.get_tensor_and_expert_parallel_group', dist_train.parallel_state.get_tensor_and_expert_parallel_group)
+ aspm.register_patch('megatron.core.mpu.get_data_modulo_expert_parallel_group', dist_train.parallel_state.get_data_modulo_expert_parallel_group)
+ aspm.register_patch('megatron.core.mpu.get_tensor_model_parallel_world_size', dist_train.parallel_state.get_tensor_model_parallel_world_size)
+ aspm.register_patch('megatron.core.mpu.get_pipeline_model_parallel_world_size', dist_train.parallel_state.get_pipeline_model_parallel_world_size)
+ aspm.register_patch('megatron.core.mpu.get_tensor_model_parallel_rank', dist_train.parallel_state.get_tensor_model_parallel_rank)
+ aspm.register_patch('megatron.core.mpu.get_pipeline_model_parallel_rank', dist_train.parallel_state.get_pipeline_model_parallel_rank)
+ aspm.register_patch('megatron.core.mpu.get_pipeline_model_parallel_split_rank', dist_train.parallel_state.get_pipeline_model_parallel_split_rank)
+ aspm.register_patch('megatron.core.mpu.is_rank_in_embedding_group', dist_train.parallel_state.is_rank_in_embedding_group)
+ aspm.register_patch('megatron.core.mpu.is_rank_in_position_embedding_group', dist_train.parallel_state.is_rank_in_position_embedding_group)
+ aspm.register_patch('megatron.core.mpu.get_virtual_pipeline_model_parallel_rank', dist_train.parallel_state.get_virtual_pipeline_model_parallel_rank)
+ aspm.register_patch('megatron.core.mpu.get_virtual_pipeline_model_parallel_world_size', dist_train.parallel_state.get_virtual_pipeline_model_parallel_world_size)
+ aspm.register_patch('megatron.core.mpu.get_data_parallel_src_rank', dist_train.parallel_state.get_data_parallel_src_rank)
+ aspm.register_patch('megatron.core.mpu.get_pipeline_model_parallel_first_rank', dist_train.parallel_state.get_pipeline_model_parallel_first_rank)
+ aspm.register_patch('megatron.core.mpu.get_pipeline_model_parallel_last_rank', dist_train.parallel_state.get_pipeline_model_parallel_last_rank)
+ aspm.register_patch('megatron.core.mpu.get_pipeline_model_parallel_next_rank', dist_train.parallel_state.get_pipeline_model_parallel_next_rank)
+ aspm.register_patch('megatron.core.mpu.get_pipeline_model_parallel_prev_rank', dist_train.parallel_state.get_pipeline_model_parallel_prev_rank)
+ aspm.register_patch('megatron.core.mpu.get_expert_model_parallel_world_size', dist_train.parallel_state.get_expert_model_parallel_world_size)
+ aspm.register_patch('megatron.core.mpu.get_expert_model_parallel_rank', dist_train.parallel_state.get_expert_model_parallel_rank)
+ aspm.register_patch('megatron.core.mpu.get_global_memory_buffer', dist_train.parallel_state.get_global_memory_buffer)
+ aspm.register_patch('megatron.core.mpu.get_moe_layer_wise_logging_tracker', dist_train.parallel_state.get_moe_layer_wise_logging_tracker)
+ # checkpoint
+ aspm.register_patch('megatron.training.checkpointing.get_checkpoint_name', dist_train.checkpointing.get_checkpoint_name_wrapper)
+
+
+def optimizer_selection(aspm, mindspeed_args):
+ if mindspeed_args.optimizer_selection == 'fused_torch_adamw':
+ from .optimizer.adamw import FusedTorchAdamW as AdamW
+ elif mindspeed_args.optimizer_selection == 'fused_adamw':
+ from .optimizer.adamw import AdamW
+ elif mindspeed_args.optimizer_selection == 'fused_ema_adamw':
+ from .optimizer.ema_adamw import FusedEmaAdamW as AdamW
+ aspm.register_patch('apex.optimizers.FusedAdam', AdamW, create_dummy=True)
+
+
+def adaptation_l0(aspm, mindspeed_args):
+ """
+ The minimum patch set for megatron to adapt to NPU
+ """
+ # transformer_engine
+ te_adaptation(aspm)
+ apex_adaptation(aspm)
+ torch_adaptation(aspm)
+ # Need replace transformer_engine modules before import megatron
+ aspm.apply_patches()
+
+ mcore_models_adaptation_l0(aspm)
+ mcore_tensor_parallel_adaptation_l0(aspm)
+ mcore_transformer_adaptation_l0(aspm)
+ mcore_moe_adaptation_l0(aspm)
+ legacy_model_transformer_l0(aspm)
+ megatron_training_adaptation_l0(aspm)
+ # context parallel(ring attention) requires mcore parallel state patch
+ mcore_parallel_state_adaptation(aspm)
+ communication_adaptation(aspm, mindspeed_args)
+
+
+def adaptation_l1(aspm, mindspeed_args):
+ """
+ Affinity optimization (fusion operators, etc.)
+ """
+ # fusion operators
+ mcore_fusions_adaptation(aspm, mindspeed_args)
+ legacy_model_fusions_adaptation(aspm)
+ # affinity optimization
+ mcore_tensor_parallel_adaptation_l1(aspm)
+
+
+def adaptation_l2(aspm, mindspeed_args):
+ """
+ Advanced acceleration algorithm
+ """
+ mcore_models_adaptation(aspm, mindspeed_args)
+ mcore_optimizer_adapation(aspm, mindspeed_args)
+ mcore_pipeline_parallel_adaptation(aspm, mindspeed_args)
+ mcore_multiparam_pipeline_parallel_adaptation(aspm, mindspeed_args)
+ mcore_tensor_parallel_adaptation(aspm, mindspeed_args)
+ mcore_transformer_adaptation(aspm, mindspeed_args)
+
+ # megatron legacy
+ megatron_legacy_adaptation(aspm)
+ legacy_model_transformer(aspm, mindspeed_args)
+ legacy_model_rms_norm_adaptation(aspm)
+
+ megatron_training_adaptation(aspm, mindspeed_args)
+ megatron_training_ema_adaptation(aspm, mindspeed_args)
+ memory_fragmentation_adaptation(aspm, mindspeed_args)
+ coc_adaptation(aspm, mindspeed_args)
+ mcore_moe_adaptation(aspm, mindspeed_args)
+ deepspeed_moe_adaptation(aspm, mindspeed_args)
+ zero3_adaptation(aspm, mindspeed_args)
+ tensor_2d_adaptation(aspm, mindspeed_args)
+ auto_parallel_mm_adaptation(aspm, mindspeed_args)
+ dist_train_adaptation(aspm, mindspeed_args)
+
+
+def delete_lock_file(directory, lock):
+ with lock:
+ flag_lock = False
+ if os.path.exists(directory):
+ for root, dirs, files in os.walk(directory):
+ for name in files:
+ if name.endswith('.lock') or name.endswith('lock'):
+ if os.path.exists(directory):
+ flag_lock = True
+ print(f"Process (PID: {os.getpid()}) is deleting Lock directory")
+ shutil.rmtree(directory)
+ print(f"Process (PID: {os.getpid()}) deleted Lock directory")
+ if flag_lock:
+ break
+ else:
+ print(f"Process (PID: {os.getpid()}) Directory {directory} does not exist.")
+ if flag_lock:
+ break
+
+
+def exe_adaptation():
+ modified_argv_path = os.getenv("OOTB_OPTIMIZER_MODIFIED_ARGV_PATH", None)
+ if modified_argv_path:
+ from mindspeed.auto_tuning.mindspeed_adaptor import MindSpeedAdaptor
+ MindSpeedAdaptor.set_argv(sys.argv, modified_argv_path)
+ print("================OOTB_OPTIMIZER_MODIFIED_ARGV DONE!====================")
+ mindspeed_args = get_mindspeed_args()
+
+ from torch.utils.cpp_extension import _get_build_directory
+ build_directory = _get_build_directory("", True)
+ delete_lock = Lock()
+ delete_lock_file(build_directory, delete_lock)
+ mindspeed_args.adaptive_recompute_enable = mindspeed_args.adaptive_recompute_device_size > 0 or mindspeed_args.adaptive_recompute_device_swap
+ if (mindspeed_args.adaptive_recompute_enable and not mindspeed_args.memory_fragmentation) or mindspeed_args.swap_attention:
+ from .core.memory.adaptive_recomputing.pluggable_allocator_adpator import change_allocator
+ if not mindspeed_args.swap_attention:
+ time.sleep(10)
+ change_allocator()
+ from .patch_utils import MindSpeedPatchesManager as aspm
+
+ if mindspeed_args.optimization_level >= 0:
+ # The minimum patch set for megatron to adapt to NPU
+ optimizer_selection(aspm, mindspeed_args)
+ adaptation_l0(aspm, mindspeed_args)
+
+ if mindspeed_args.optimization_level >= 1:
+ # Affinity optimization (fusion operators, etc.)
+ adaptation_l1(aspm, mindspeed_args)
+
+ if mindspeed_args.optimization_level >= 2:
+ # Advanced acceleration algorithm
+ adaptation_l2(aspm, mindspeed_args)
+
+ if mindspeed_args.layerzero:
+ # layerzero features
+ megatron_training_adaptation_with_layerzero(aspm, mindspeed_args)
+
+ aspm.apply_patches()
+
+ # New features structure
+ for feature in FEATURES_LIST:
+ if getattr(mindspeed_args, feature.feature_name, None) or feature.default_patches:
+ feature.register_patches(aspm, mindspeed_args)
+
+ aspm.apply_patches()
+
+ # accelerate package will check TE on sys.modules,so we need remove this patch
+ del sys.modules['transformer_engine']
+
+
+exe_adaptation()
diff --git a/model/train/yoco_moe/mindspeed/model/__init__.py b/model/train/yoco_moe/mindspeed/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/model/alibi_mask.py b/model/train/yoco_moe/mindspeed/model/alibi_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd749b7626524b7dfb18773e8de63d41b251e502
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/model/alibi_mask.py
@@ -0,0 +1,79 @@
+import math
+import torch
+
+from megatron.core import parallel_state
+
+
+def get_slopes(n):
+ def get_slopes_power_of_2(n):
+ start = (2 ** (-2 ** -(math.log2(n) - 3)))
+ ratio = start
+ return [start * ratio ** i for i in range(n)]
+
+ if math.log2(n).is_integer():
+ return get_slopes_power_of_2(n)
+ else:
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
+ :n - closest_power_of_2]
+
+
+class AlibiForFusionAttnSingleton:
+ _alibi_tensor_args = None
+ _alibi_tensor = None
+
+ _alibi_slopes_headnum = None
+ _alibi_slopes = None
+
+ @classmethod
+ def get_alibi_tensor_for_fusion_attn(cls, max_seq_len, num_attention_heads, dtype, neg_diagonal_opposite=False,
+ last_k=1024):
+ if cls._alibi_tensor is None or cls._alibi_tensor_args != (
+ max_seq_len, num_attention_heads, neg_diagonal_opposite, last_k):
+ if last_k > max_seq_len:
+ last_k = max_seq_len
+
+ tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
+ current_head_num = num_attention_heads // tp_world_size
+ slopes = AlibiForFusionAttnSingleton.get_alibi_slopes_for_fusion_attn(num_attention_heads)
+
+ position_point = torch.arange(max_seq_len) - max_seq_len + 1
+ diag = torch.diag(torch.diag(position_point)).unsqueeze(0).unsqueeze(0)
+
+ position_point = position_point.unsqueeze(0).unsqueeze(0).expand(current_head_num, last_k, -1)
+ position_point = position_point - diag.transpose(-1, -2)[:, -last_k:, :].expand(current_head_num, last_k,
+ max_seq_len)
+
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point.npu()
+
+ if not neg_diagonal_opposite:
+ alibi = -torch.abs(alibi)
+
+ alibi = alibi.unsqueeze(0)
+ alibi = torch.Tensor(alibi).npu()
+ if dtype == torch.float16:
+ alibi = alibi.to(torch.float16)
+ elif dtype == torch.bfloat16:
+ alibi = alibi.to(torch.bfloat16)
+
+ cls._alibi_tensor = alibi
+ cls._alibi_tensor_args = (max_seq_len, num_attention_heads, neg_diagonal_opposite, last_k)
+
+ return cls._alibi_tensor
+
+ @classmethod
+ def get_alibi_slopes_for_fusion_attn(cls, n):
+ if cls._alibi_slopes is None or cls._alibi_slopes_headnum != n:
+ slopes = get_slopes(n)
+
+ tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
+ tp_index = parallel_state.get_tensor_model_parallel_rank()
+
+ current_head_num = n // tp_world_size
+ slopes = torch.Tensor(slopes[tp_index * current_head_num: tp_index * current_head_num + current_head_num]).npu()
+
+ cls._alibi_slopes = slopes
+ cls._alibi_slopes_headnum = n
+ return cls._alibi_slopes
+
+ return cls._alibi_slopes
diff --git a/model/train/yoco_moe/mindspeed/model/language_model.py b/model/train/yoco_moe/mindspeed/model/language_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..79368b24831202f43fd860487643a89370c359cb
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/model/language_model.py
@@ -0,0 +1,161 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import wraps
+import torch
+from megatron.training import get_args
+from megatron.core import mpu, tensor_parallel
+from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm, TPYCollectiveComm
+from mindspeed.core.tensor_parallel.comm_autograd_function import (auto_grad_sync_gather_along_first_dim,
+ auto_grad_sync_gather_along_last_dim,
+ auto_grad_scatter_along_first_dim_then_last_dim)
+from mindspeed.moe.utils import get_slice_indices_from_disorder_to_order, get_slice_indices_from_order_to_disorder
+
+
+def parallel_lm_logits(
+ input_,
+ word_embeddings_weight,
+ parallel_output,
+ bias=None
+):
+ args = get_args()
+ # Parallel logits.
+ if args.async_tensor_model_parallel_allreduce or\
+ args.sequence_parallel:
+ input_parallel = input_
+ model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
+ async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
+ model_parallel and not args.sequence_parallel
+ else:
+ input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
+ async_grad_allreduce = False
+
+ if args.use_nd_matmul:
+ input_parallel = tensor_parallel.gather_from_tensor_model_parallel_region(input_parallel)
+
+ if args.tp_2d:
+ input_parallel = auto_grad_sync_gather_along_first_dim(input_parallel, TPXCollectiveComm)
+ input_parallel = auto_grad_sync_gather_along_last_dim(input_parallel, TPYCollectiveComm)
+
+ # Matrix multiply.
+ logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
+ input=input_parallel,
+ weight=word_embeddings_weight,
+ bias=bias,
+ gradient_accumulation_fusion=args.gradient_accumulation_fusion,
+ async_grad_allreduce=async_grad_allreduce,
+ sequence_parallel=args.sequence_parallel)
+ # Gather if needed.
+ if parallel_output:
+ return logits_parallel
+
+ return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
+
+
+def embedding_forward_wrapper(forward):
+ @wraps(forward)
+ def wrapper(self, *args, **kwargs):
+ encoder_input = forward(self, *args, **kwargs)
+ if get_args().use_nd_matmul:
+ encoder_input = tensor_parallel.scatter_to_tensor_model_parallel_region(encoder_input)
+ if get_args().tp_2d:
+ encoder_input = auto_grad_scatter_along_first_dim_then_last_dim(
+ encoder_input, TPXCollectiveComm, TPYCollectiveComm
+ )
+ return encoder_input
+ return wrapper
+
+
+class AmpipeEmbeddingRearrange(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, embeddings, ampipe_degree):
+ seqlen = embeddings.size(0)
+ new_indices = get_slice_indices_from_disorder_to_order(seqlen, ampipe_degree, device=embeddings.device)
+ embeddings = torch.index_select(embeddings, dim=0, index=new_indices)
+ ctx.ampipe_degree = ampipe_degree
+ return embeddings
+
+ @staticmethod
+ def backward(ctx, grad_input):
+ seqlen = grad_input.size(0)
+ new_indices = get_slice_indices_from_order_to_disorder(seqlen, ctx.ampipe_degree, device=grad_input.device)
+ grad_input = torch.index_select(grad_input, dim=0, index=new_indices)
+ return grad_input, None
+
+
+def embedding_forward_ampipe(self, input_ids, position_ids, tokentype_ids=None):
+ # Embeddings.
+ words_embeddings = self.word_embeddings(input_ids)
+ if self.add_position_embedding:
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings = words_embeddings + position_embeddings
+ else:
+ embeddings = words_embeddings
+
+ if tokentype_ids is not None:
+ assert self.tokentype_embeddings is not None
+ embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
+ else:
+ assert self.tokentype_embeddings is None
+
+ # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
+ embeddings = embeddings.transpose(0, 1).contiguous()
+
+ # If the input flag for fp32 residual connection is set, convert for float.
+ if self.fp32_residual_connection:
+ embeddings = embeddings.float()
+
+ # Dropout.
+ if self.sequence_parallel:
+ ampipe_degree = get_args().ampipe_degree
+ if ampipe_degree > 1:
+ embeddings = AmpipeEmbeddingRearrange.apply(embeddings, ampipe_degree)
+ embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
+ # `scatter_to_sequence_parallel_region` returns a view, which prevents
+ # the original tensor from being garbage collected. Clone to facilitate GC.
+ # Has a small runtime cost (~0.5%).
+ if self.clone_scatter_output_in_embedding:
+ embeddings = embeddings.clone()
+ with tensor_parallel.get_cuda_rng_tracker().fork():
+ embeddings = self.embedding_dropout(embeddings)
+ else:
+ embeddings = self.embedding_dropout(embeddings)
+
+ return embeddings
+
+
+def model_parallel_config_post_init_wrapper(init_func):
+ @wraps(init_func)
+ def wrapper(self, *args, **kwargs):
+ if self.pipeline_model_parallel_size > 1:
+ if self.pipeline_dtype is None:
+ raise ValueError(
+ "When using pipeline parallelism, pipeline_dtype must be specified"
+ )
+
+ if self.autocast_dtype is None:
+ self.autocast_dtype = self.params_dtype
+
+ if self.defer_embedding_wgrad_compute and self.pipeline_model_parallel_size == 1:
+ raise ValueError(
+ "Cannot defer embedding wgrad compute when pipeline model parallel is not used"
+ )
+
+ if self.defer_embedding_wgrad_compute and not self.gradient_accumulation_fusion:
+ raise ValueError(
+ "Cannot defer embedding wgrad compute when gradient accumulation fusion is not used"
+ )
+
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/model/transformer.py b/model/train/yoco_moe/mindspeed/model/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1dd298a6cc1e0b51fb64d258d3d0b940a16f4ca
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/model/transformer.py
@@ -0,0 +1,1814 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import math
+import threading
+from functools import wraps
+from contextlib import nullcontext
+from typing import Any
+import warnings
+
+import torch
+import torch_npu
+import torch.nn.functional as F
+import numpy as np
+from megatron.training import get_args, get_tokenizer
+from megatron.core import parallel_state, mpu, tensor_parallel
+from megatron.core.num_microbatches_calculator import get_num_microbatches
+from megatron.core.utils import make_viewless_tensor
+from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
+from megatron.legacy.model.utils import openai_gelu, erf_gelu, get_norm
+from megatron.legacy.model.transformer import ParallelMLP, ParallelTransformer, ParallelTransformerLayer, CoreAttention, \
+ FlashSelfAttention, ParallelAttention
+from megatron.core.enums import ModelType
+from megatron.legacy.model.enums import AttnType, AttnMaskType, LayerType
+from megatron.legacy.model.transformer import (_get_num_layers,
+ _get_layer_type,
+ bias_dropout_add_fused_train,
+ bias_dropout_add_fused_inference,
+ get_bias_dropout_add)
+from megatron.legacy.model.fused_bias_gelu import bias_gelu_impl
+from megatron.core.tensor_parallel.mappings import _split_along_first_dim
+from megatron.core.transformer.module import MegatronModule
+
+from mindspeed.core.context_parallel.ulysses_context_parallel import UlyssesContextAttention
+from mindspeed.core.context_parallel.ring_context_parallel import ringattn_context_parallel
+from mindspeed.core.parallel_state import (get_context_parallel_group_for_hybrid_ulysses,
+ get_context_parallel_group_for_hybrid_ring,
+ get_context_parallel_for_hybrid_ring_world_size,
+ get_context_parallel_for_hybrid_ring_rank,
+ get_context_parallel_for_hybrid_ring_global_ranks,
+ get_ring_ranks_for_intra_window,
+ get_ring_ranks_for_inter_window_kv,
+ get_ring_ranks_for_inter_window_dkv,
+ get_ring_group_for_intra_window,
+ get_ring_group_for_intra_window_send_recv_overlap)
+from mindspeed.core.fusions.fused_bias_swiglu import fused_swiglu
+from mindspeed.core.parallel_state import get_tensor_model_parallel_world_size_for_nd1_dim1
+from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm
+from mindspeed.core.tensor_parallel.comm_group_api import TPXOverlapCollectiveComm
+from mindspeed.core.tensor_parallel.comm_group_api import TPYCollectiveComm
+from mindspeed.core.tensor_parallel.comm_group_api import TPYOverlapCollectiveComm
+from mindspeed.core.tensor_parallel.tp_2d.parallel_linear_2d import ParallelLinear2D
+from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput
+from mindspeed.core.tensor_parallel_y_union_cp import TensorParallelYUnionCP
+from mindspeed.moe.ampipe.ampipe import AttMoEPipe
+from mindspeed.ops.fusion_attention_v2 import npu_fusion_attention
+from mindspeed.core.tensor_parallel.layers import Nd_ParallelLinear
+from mindspeed.core.tensor_parallel.checkpoint_manager import get_pipeline_checkpoint_manager
+from mindspeed.model.alibi_mask import AlibiForFusionAttnSingleton, get_slopes
+from mindspeed.moe.ampipe.ampipe_args import ForwardArgs
+from mindspeed.moe.utils import (get_slice_indices_from_order_to_disorder,
+ get_slice_indices_from_disorder_to_order,
+ all_gather_along_first_dim)
+from mindspeed.core.context_parallel.adaptive_context_parallel import adaptive_attn_context_parallel
+from mindspeed.core.context_parallel.utils import get_scheduling_info
+
+try:
+ from einops import rearrange
+except ImportError:
+ rearrange = None
+
+_GLOBAL_ATTN_MASK = None
+
+
+class Alibi:
+ _instance = None
+ alibi = None
+ matmul_result = None
+ output_size = None
+ lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ if cls._instance:
+ return cls._instance
+ else:
+ with cls.lock:
+ cls._instance = super().__new__(cls)
+ return cls._instance
+
+
+def _get_inverted_mask(attention_mask, alibi):
+ inverted_mask = attention_mask.to(alibi.dtype)
+ inverted_mask = inverted_mask.masked_fill(
+ inverted_mask.to(torch.bool), float("-inf")
+ )
+ return inverted_mask.to(alibi.device) + alibi.unsqueeze(0)
+
+
+def _build_alibi_tensor(max_seq_len, num_attention_heads, square_alibi_mask, fill_neg_inf):
+ def _fill_with_neg_inf(t):
+ """FP16-compatible function that fills a tensor with -inf."""
+ return t.float().fill_(float("-inf")).type_as(t)
+
+ def _buffered_future_mask(maxpos, alibi, attn_heads):
+ _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
+ _future_mask = _future_mask.unsqueeze(0) + alibi
+ return _future_mask[:attn_heads, :maxpos, :maxpos]
+
+ slopes = torch.Tensor(get_slopes(num_attention_heads))
+ if square_alibi_mask:
+ position_point = torch.arange(max_seq_len) - max_seq_len + 1
+ position_point = position_point.unsqueeze(0).unsqueeze(0).expand(num_attention_heads, max_seq_len, -1)
+ diag = torch.diag(position_point[0])
+ position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
+ else:
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand(
+ num_attention_heads, -1, -1)
+
+ # Select the part of the tensor that corresponds to our tensor parallel index.
+ tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
+ tp_index = parallel_state.get_tensor_model_parallel_rank()
+ alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index]
+
+ if fill_neg_inf:
+ return _buffered_future_mask(max_seq_len, alibi, num_attention_heads)
+
+ return alibi
+
+
+def core_attention_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *arg, **kwargs):
+ fn(self, *arg, **kwargs)
+
+ args = get_args()
+ self.hidden_size_per_partition = self.hidden_size_per_partition // arg[1].context_parallel_size
+ self.square_alibi_mask = args.square_alibi_mask
+ self.fill_neg_inf = args.fill_neg_inf
+ self.beta = 1.0
+ self.config = arg[1]
+ if self.apply_query_key_layer_scaling:
+ self.beta = 1.0 / self.layer_number
+ if args.position_embedding_type == 'alibi':
+ self.alibi = Alibi()
+ alibi = _build_alibi_tensor(args.seq_length,
+ self.config.num_attention_heads,
+ args.square_alibi_mask,
+ args.fill_neg_inf
+ ).to(torch.cuda.current_device())
+ if self.config.params_dtype == torch.float16:
+ alibi = alibi.to(torch.float16)
+ elif self.config.params_dtype == torch.bfloat16:
+ alibi = alibi.to(torch.bfloat16)
+ self.alibi.alibi = alibi
+ else:
+ self.alibi = None
+
+ return wrapper
+
+
+def core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
+ # ===================================
+ # Raw attention scores. [b, np, s, s]
+ # ===================================
+
+ # [b, np, sq, sk]
+ output_size = (query_layer.size(1),
+ query_layer.size(2),
+ query_layer.size(0),
+ key_layer.size(0))
+
+ # [sq, b, np, hn] -> [sq, b * np, hn]
+ query_layer = query_layer.reshape(output_size[2],
+ output_size[0] * output_size[1], -1)
+ # [sk, b, np, hn] -> [sk, b * np, hn]
+ key_layer = key_layer.view(output_size[3],
+ output_size[0] * output_size[1], -1)
+
+ if self.alibi is None:
+ matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
+ (output_size[0] * output_size[1], output_size[2], output_size[3]),
+ query_layer.dtype, "mpu")
+
+ matmul_result = torch.baddbmm(
+ matmul_input_buffer,
+ query_layer.transpose(0, 1),
+ key_layer.transpose(0, 1).transpose(1, 2),
+ beta=0.0, alpha=(1.0 / self.norm_factor))
+ else:
+ if self.alibi.matmul_result is None or self.alibi.output_size != output_size:
+ args = get_args()
+
+ self.alibi.output_size = output_size
+ alibi = _build_alibi_tensor(args.seq_length,
+ self.config.num_attention_heads,
+ args.square_alibi_mask,
+ args.fill_neg_inf
+ ).to(torch.cuda.current_device())
+ if self.config.params_dtype == torch.float16:
+ alibi = alibi.to(torch.float16)
+ elif self.config.params_dtype == torch.bfloat16:
+ alibi = alibi.to(torch.bfloat16)
+ self.alibi.alibi = alibi
+
+ if self.fill_neg_inf:
+ _alibi = self.alibi.alibi[:, :output_size[3], :output_size[3]]
+ attention_mask = attention_mask.repeat(output_size[0], 1, 1, 1)[:output_size[0], :, :, :]
+ self.alibi.matmul_result = _get_inverted_mask(attention_mask, _alibi).view(-1, output_size[2],
+ output_size[2]).contiguous()
+ else:
+ self.alibi.matmul_result = self.alibi.alibi[:, :, :output_size[3]].repeat(output_size[0], 1, 1)
+
+ q_trans = query_layer.transpose(0, 1).contiguous()
+ k_trans = key_layer.transpose(0, 1).transpose(1, 2).contiguous()
+ matmul_result = self.beta * self.alibi.matmul_result + torch.bmm(q_trans, k_trans) * (1.0 / self.norm_factor)
+
+ # change view to [b, np, sq, sk]
+ attention_scores = matmul_result.view(*output_size)
+
+ # ===========================
+ # Attention probs and dropout
+ # ===========================
+
+ # attention scores and attention mask [b, np, sq, sk]
+ if self.square_alibi_mask:
+ attention_scores = torch.max(
+ attention_scores, torch.tensor(torch.finfo(attention_scores.dtype).min)
+ )
+ attention_probs = torch.nn.functional.softmax(attention_scores, -1)
+ else:
+ attention_probs = self.scale_mask_softmax(attention_scores,
+ attention_mask)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ if not self.sequence_parallel:
+ with tensor_parallel.get_cuda_rng_tracker().fork():
+ attention_probs = self.attention_dropout(attention_probs)
+ else:
+ attention_probs = self.attention_dropout(attention_probs)
+
+ # =========================
+ # Context layer. [sq, b, hp]
+ # =========================
+
+ # value_layer -> context layer.
+ # [sk, b, np, hn] --> [b, np, sq, hn]
+
+ # context layer shape: [b, np, sq, hn]
+ output_size = (value_layer.size(1),
+ value_layer.size(2),
+ query_layer.size(0),
+ value_layer.size(3))
+
+ # change view [sk, b * np, hn]
+ value_layer = value_layer.view(value_layer.size(0),
+ output_size[0] * output_size[1], -1)
+
+ # change view [b * np, sq, sk]
+ attention_probs = attention_probs.view(output_size[0] * output_size[1],
+ output_size[2], -1)
+
+ # matmul: [b * np, sq, hn]
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+
+ # change view [b, np, sq, hn]
+ context_layer = context_layer.view(*output_size)
+
+ # [b, np, sq, hn] --> [sq, b, np, hn]
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+ # [sq, b, np, hn] --> [sq, b, hp]
+ new_context_layer_shape = context_layer.size()[:-2] + \
+ (self.hidden_size_per_partition,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ return context_layer
+
+
+class NoopTransformerLayer(MegatronModule):
+ def __init__(self, layer_number):
+ super().__init__(None)
+ self.layer_number = layer_number
+
+ def forward(self, hidden_states, *args, **kwargs):
+ return hidden_states.clone()
+
+
+def parallel_transformer_init(self, config,
+ model_type, layer_type=LayerType.encoder,
+ self_attn_mask_type=AttnMaskType.padding,
+ post_norm=True,
+ pre_process=True,
+ post_process=True,
+ drop_path_rate=0.0):
+ super(ParallelTransformer, self).__init__()
+
+ args = get_args()
+
+ self.layer_type = layer_type
+ self.model_type = model_type
+ self.bf16 = config.bf16
+ self.fp32_residual_connection = config.fp32_residual_connection
+ self.post_norm = post_norm
+ self.pre_process = pre_process
+ self.post_process = post_process
+ self.input_tensor = None
+ self.drop_path_rate = drop_path_rate
+ self.transformer_impl = args.transformer_impl
+ self.retro_add_retriever = args.retro_add_retriever
+
+ # Store activation checkpoiting flag.
+ self.recompute_granularity = config.recompute_granularity
+ self.recompute_method = config.recompute_method
+ self.recompute_num_layers = config.recompute_num_layers
+ self.distribute_saved_activations = \
+ config.distribute_saved_activations and not config.sequence_parallel
+
+ self.sequence_parallel = config.sequence_parallel
+
+ # Transformer Engine Init.
+ self.transformer_engine_v_0_10 = False
+ self.transformer_engine_v_0_11 = False
+ self.transformer_engine_v_0_8 = False
+ if self.transformer_impl == 'transformer_engine':
+ global transformer_engine
+ import transformer_engine
+ from importlib.metadata import version
+ from pkg_resources import packaging
+
+ te_version = packaging.version.Version(version("transformer-engine"))
+ if te_version >= packaging.version.Version("0.8.0"):
+ self.transformer_engine_v_0_8 = True
+ if te_version >= packaging.version.Version("0.10.0"):
+ self.transformer_engine_v_0_10 = True
+ if te_version >= packaging.version.Version("0.11.0"):
+ self.transformer_engine_v_0_11 = True
+
+ del version, packaging
+
+ assert not args.squared_relu, "TransformerEngine does not support squared relu activation."
+
+ self.use_fp8 = config.fp8 is not None
+ self.fp8_recipe = None
+ self.fp8_group = None
+ if self.use_fp8:
+ assert args.transformer_impl == 'transformer_engine', \
+ 'transformer-engine required for fp8 training and inference'
+ self.fp8_group = mpu.get_amax_reduction_group()
+ if config.fp8 == "e4m3":
+ fp8_format = transformer_engine.common.recipe.Format.E4M3
+ elif config.fp8 == "hybrid":
+ fp8_format = transformer_engine.common.recipe.Format.HYBRID
+ else:
+ raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
+ self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
+ margin=config.fp8_margin,
+ interval=config.fp8_interval,
+ fp8_format=fp8_format,
+ amax_history_len=config.fp8_amax_history_len,
+ amax_compute_algo=config.fp8_amax_compute_algo,
+ override_linear_precision=(False, False, not config.fp8_wgrad),
+ )
+
+ self.num_microbatches_in_previous_step = -1
+ self.microbatch_count = 0
+ self.checkpoint_core_attention = config.recompute_granularity == 'selective'
+
+ # Number of layers.
+ self.num_layers = _get_num_layers(args, model_type,
+ layer_type==LayerType.decoder)
+
+ self.drop_path_rates = [
+ rate.item() for rate in
+ torch.linspace(0, self.drop_path_rate, config.num_layers)]
+
+ self.retro_layer_numbers = None
+ if model_type == ModelType.retro_decoder:
+ retro_layer_start = 6 if config.num_layers <= 15 else 9
+ self.retro_layer_numbers = \
+ np.arange(retro_layer_start, config.num_layers + 1, 3).tolist()
+ if model_type == ModelType.retro_encoder:
+ self.retro_layer_numbers = [1]
+
+ # Transformer layers.
+ if args.retro_add_retriever:
+ assert self.recompute_granularity != 'full', \
+ "Full recompute not supported for Retro."
+ assert args.transformer_impl == 'local', \
+ "Transformer engine does not support Retro layers."
+ def build_layer(layer_number):
+ if args.transformer_impl == 'local':
+ if (hasattr(args, 'noop_layers') and isinstance(args.noop_layers, set)
+ and layer_number - 1 in args.noop_layers):
+ return NoopTransformerLayer(layer_number)
+
+ current_layer_type = _get_layer_type(
+ model_type, layer_type, self.retro_layer_numbers,
+ layer_number)
+ return ParallelTransformerLayer(
+ config,
+ layer_number,
+ layer_type=current_layer_type,
+ self_attn_mask_type=self_attn_mask_type,
+ drop_path_rate=self.drop_path_rates[layer_number - 1])
+ else:
+ # This argument is only available from TE v0.10 onwards.
+ extra_transformer_engine_kwargs = {}
+ if self.transformer_engine_v_0_8:
+ extra_transformer_engine_kwargs["bias"] = config.add_bias_linear
+ if self.transformer_engine_v_0_10:
+ extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu"
+ if self.transformer_engine_v_0_11:
+ extra_transformer_engine_kwargs["normalization"] = config.normalization
+ assert config.attention_softmax_in_fp32, "TransformerEngine only supports softmax compute in FP32."
+ assert (
+ (bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and config.fp16) == config.apply_query_key_layer_scaling
+ ), "Unsupported config for apply_query_key_layer_scaling in TransformerEngine."
+ return transformer_engine.pytorch.TransformerLayer(
+ config.hidden_size,
+ config.ffn_hidden_size,
+ config.num_attention_heads,
+ layernorm_epsilon=config.layernorm_epsilon,
+ hidden_dropout=config.hidden_dropout,
+ attention_dropout=config.attention_dropout,
+ init_method=config.init_method,
+ output_layer_init_method=config.output_layer_init_method,
+ layer_number=layer_number,
+ kv_channels=config.kv_channels,
+ self_attn_mask_type=self_attn_mask_type.name,
+ tp_group=mpu.get_tensor_model_parallel_group(),
+ get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
+ fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
+ seq_length=args.seq_length,
+ micro_batch_size=args.micro_batch_size,
+ sequence_parallel=config.sequence_parallel,
+ params_dtype=config.params_dtype,
+ apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm,
+ output_layernorm=False,
+ layer_type="encoder",
+ drop_path_rate=self.drop_path_rates[layer_number - 1],
+ set_parallel_mode=True,
+ fuse_qkv_params=True,
+ **extra_transformer_engine_kwargs)
+
+ if config.virtual_pipeline_model_parallel_size is not None:
+ assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \
+ 'num_layers_per_stage must be divisible by ' \
+ 'virtual_pipeline_model_parallel_size'
+ assert args.model_type != ModelType.encoder_and_decoder
+ # Number of layers in each model chunk is the number of layers in the stage,
+ # divided by the number of model chunks in a stage.
+ self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
+ # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
+ # layers to stages like (each list is a model chunk):
+ # Stage 0: [0] [2] [4] [6]
+ # Stage 1: [1] [3] [5] [7]
+ # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
+ # layers to stages like (each list is a model chunk):
+ # Stage 0: [0, 1] [4, 5]
+ # Stage 1: [2, 3] [6, 7]
+ offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
+ config.num_layers // config.virtual_pipeline_model_parallel_size) + \
+ (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
+ else:
+ # Each stage gets a contiguous set of layers.
+ if args.model_type == ModelType.encoder_and_decoder and \
+ mpu.get_pipeline_model_parallel_world_size() > 1:
+ pipeline_rank = mpu.get_pipeline_model_parallel_rank()
+ if layer_type == LayerType.encoder:
+ offset = pipeline_rank * self.num_layers
+ else:
+ num_ranks_in_enc = config.pipeline_model_parallel_split_rank
+ offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
+ else:
+ offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
+
+ if self.num_layers == 0:
+ # When a standalone embedding stage is used (e.g.,
+ # args.standalone_embedding_stage == True), virtual pipeline ranks
+ # on pipeline rank 0 will have zero transformer layers assigned to
+ # them. This results in the model's input and output tensors to be
+ # the same, which will cause failure for certain output tensor
+ # optimizations (e.g., pipeline output deallocation). To remedy
+ # this, we assign a 'no-op' layer on these ranks, which will
+ # disconnect the input tensor from the output tensor.
+ self.num_layers = 1
+ self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
+ else:
+ self.layers = torch.nn.ModuleList(
+ [build_layer(i + 1 + offset) for i in range(self.num_layers)])
+
+ # Update dropout rate for Retro encoder.
+ if model_type == ModelType.retro_encoder:
+ for layer in self.layers:
+ if layer.self_attention.use_flash_attn:
+ layer.self_attention.core_attention_flash.dropout_p = \
+ torch.nn.Dropout(args.retro_encoder_attention_dropout)
+ else:
+ layer.self_attention.core_attention.attention_dropout.p =\
+ args.retro_encoder_attention_dropout
+ layer.hidden_dropout = args.retro_encoder_hidden_dropout
+
+ if self.post_process and self.post_norm:
+ # Final layer norm before output.
+ self.final_norm = get_norm(config)
+
+
+def parallel_transformer_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ def build_layer(model_type, config, layer_number, layer_type=LayerType.encoder, self_attn_mask_type=AttnMaskType.padding):
+ current_layer_type = _get_layer_type(
+ model_type, layer_type, self.retro_layer_numbers,
+ layer_number)
+ return ParallelTransformerLayer(
+ config,
+ layer_number,
+ layer_type=current_layer_type,
+ self_attn_mask_type=self_attn_mask_type,
+ drop_path_rate=self.drop_path_rates[layer_number - 1])
+ fn(self, *args, **kwargs)
+
+ argument = get_args()
+ if argument.automated_pipeline and argument.num_layer_list and argument.virtual_pipeline_model_parallel_size is None:
+ start_layer_num = 1
+ self.layers = torch.nn.ModuleList()
+ for idx, value in enumerate(argument.num_layer_list):
+ if parallel_state.get_pipeline_model_parallel_rank() == idx:
+ self.num_layers = value
+ for layer_num in range(start_layer_num, start_layer_num + value):
+ self.layers.append(build_layer(kwargs['model_type'], args[0], layer_num, self_attn_mask_type=kwargs['self_attn_mask_type']))
+ start_layer_num += value
+ self.layers = torch.nn.ModuleList(self.layers)
+
+ # Update dropout rate for Retro encoder.
+ if kwargs['model_type'] == ModelType.retro_encoder:
+ for layer in self.layers:
+ if layer.self_attention.use_flash_attn:
+ layer.self_attention.core_attention_flash.dropout_p = \
+ torch.nn.Dropout(argument.retro_encoder_attention_dropout)
+ else:
+ layer.self_attention.core_attention.attention_dropout.p = \
+ argument.retro_encoder_attention_dropout
+ layer.hidden_dropout = argument.retro_encoder_hidden_dropout
+ return wrapper
+
+
+def set_attention_mask(attn_mask):
+ global _GLOBAL_ATTN_MASK
+ _GLOBAL_ATTN_MASK = attn_mask
+
+
+def generate_attention_mask(compress, device):
+ global _GLOBAL_ATTN_MASK
+ args = get_args()
+ if not args.use_flash_attn:
+ warnings.warn("Flash Attention is highly recommended")
+ _GLOBAL_ATTN_MASK = (torch.tril(torch.ones([args.micro_batch_size, 1, args.seq_length, args.seq_length], dtype=bool, device=device), diagonal=-(args.pre_tockens + 1)) \
+ + torch.triu(torch.ones([args.micro_batch_size, 1, args.seq_length, args.seq_length], dtype=bool, device=device), diagonal=args.next_tockens + 1))
+ return
+
+ if compress:
+ seq_len = 2048
+ else:
+ seq_len = args.seq_length
+
+ _GLOBAL_ATTN_MASK = torch.triu(
+ torch.ones((seq_len, seq_len),
+ device=device, dtype=torch.bool), diagonal=1)
+
+
+def get_attention_mask():
+ global _GLOBAL_ATTN_MASK
+ if _GLOBAL_ATTN_MASK is not None:
+ return _GLOBAL_ATTN_MASK
+
+ args = get_args()
+ should_generate_mask = False
+ device = 'npu'
+
+ if args.attention_mask_type == 'causal':
+ args.sparse_mode = 2
+ should_generate_mask = True
+ compress = True
+
+ # ampipe开启在同时不开cp时需要生成全量mask,开cp时生成causal mask
+ if args.ampipe_degree > 1 and args.context_parallel_size <= 1:
+ args.sparse_mode = 0
+ should_generate_mask = True
+ compress = False
+
+ # EoD 模式 Ring Attention的实现
+ # general 为基线方案,causal 为加速方案
+ # 如果 cp > 1 且使用了Ring Attention 并行(包括Hybrid并行)。则Mask为动态生成的,不需要额外的Mask
+ if args.reset_attention_mask:
+ if args.attention_mask_type == 'general':
+ args.sparse_mode = 2
+ if args.context_parallel_size == 1 or args.context_parallel_algo == 'ulysses_cp_algo':
+ should_generate_mask = True
+ compress = True
+ else:
+ args.sparse_mode = 1
+ should_generate_mask = False
+ else:
+ should_generate_mask = True
+ compress = True
+
+
+ if args.attention_mask_on_cpu:
+ device = 'cpu'
+
+ if should_generate_mask:
+ generate_attention_mask(compress, device)
+
+ return _GLOBAL_ATTN_MASK
+
+
+def parallel_transformer_forward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, hidden_states, attention_mask, **kwargs):
+ args = get_args()
+ if attention_mask is None:
+ attention_mask = get_attention_mask()
+ return fn(self, hidden_states, attention_mask, **kwargs)
+ return wrapper
+
+
+def parallel_transformer_forward_ampipe(self, hidden_states, attention_mask,
+ encoder_output=None, enc_dec_attn_mask=None,
+ retriever_input=None,
+ retriever_output=None,
+ retriever_attn_mask=None,
+ inference_params=None,
+ rotary_pos_emb=None):
+ # hidden_states: [s, b, h]
+
+ # Checks.
+ if inference_params:
+ assert self.recompute_granularity is None, \
+ 'inference does not work with activation checkpointing'
+
+ if not self.pre_process:
+ # See set_input_tensor()
+ hidden_states = self.input_tensor
+
+ # Viewless tensor.
+ # - We only need to create a viewless tensor in the case of micro batch
+ # size (mbs) == 1, since in this case, 'hidden_states.transpose()'
+ # above creates a view tensor, and '.contiguous()' is a pass-through.
+ # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
+ # the need to make it viewless.
+ #
+ # However, we don't explicitly check mbs == 1 here because
+ # make_viewless_tensor() has negligible overhead when its input
+ # is already viewless.
+ #
+ # - For the 'else' case above, calling make_viewless_tensor() here is
+ # likely redundant, since p2p_communication.py (likely originator)
+ # already creates viewless tensors. That said, make_viewless_tensor()
+ # is called here to be future-proof and corner-case-proof.
+ hidden_states = make_viewless_tensor(
+ hidden_states,
+ requires_grad=True,
+ keep_graph=True,
+ )
+
+ # RNG context.
+ if self.sequence_parallel:
+ rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
+ else:
+ rng_context = nullcontext()
+
+ # Forward layers.
+ with rng_context:
+ # Determine if the current iteration is first microbatch
+ if self.num_microbatches_in_previous_step != get_num_microbatches():
+ self.microbatch_count = 0 # Reset count on new batch size rampup interval
+ self.num_microbatches_in_previous_step = get_num_microbatches()
+ is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0
+
+ # Forward pass.
+ if self.recompute_granularity == 'full':
+ hidden_states = self._checkpointed_forward(hidden_states,
+ attention_mask,
+ encoder_output,
+ enc_dec_attn_mask,
+ rotary_pos_emb,
+ is_first_microbatch)
+ else:
+ forward_kwargs = {
+ 'encoder_output': encoder_output,
+ 'enc_dec_attn_mask': enc_dec_attn_mask,
+ 'inference_params': inference_params,
+ }
+
+ forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
+ forward_kwargs['retriever_input'] = retriever_input
+ forward_kwargs['retriever_output'] = retriever_output
+ forward_kwargs['retriever_attn_mask'] = retriever_attn_mask
+
+ for index in range(self.num_layers):
+ layer = self._get_layer(index)
+
+ hidden_states = layer(
+ hidden_states,
+ attention_mask,
+ **forward_kwargs)
+
+ # First Retro decoder layer returns both hidden_states
+ # and retriever_output. Make retriever_output available
+ # to subsequence Retro layers.
+ if isinstance(hidden_states, tuple):
+ assert len(hidden_states) == 2
+ hidden_states, retriever_output = hidden_states
+ forward_kwargs["retriever_output"] = retriever_output
+ if self.sequence_parallel:
+ ampipe_degree = get_args().ampipe_degree
+ if ampipe_degree > 1:
+ hidden_states = AmpipeLastTransformerLayerRearrange.apply(hidden_states, ampipe_degree)
+ # Skip counter update for eval and activation checkpointing
+ if torch.is_grad_enabled() and self.training:
+ self.microbatch_count += 1
+
+ # Final layer norm.
+ if self.post_process and self.post_norm:
+ hidden_states = self.final_norm(hidden_states)
+
+ return hidden_states
+
+
+class AmpipeLastTransformerLayerRearrange(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, hidden_states, ampipe_degree) -> Any:
+ ag_hidden_states = all_gather_along_first_dim(hidden_states, True)
+ hidden_states.untyped_storage().resize_(0)
+ seqlen = ag_hidden_states.size(0)
+
+ new_indices = get_slice_indices_from_order_to_disorder(seqlen, ampipe_degree, device=torch.npu.current_device())
+ select_hidden_states = torch.index_select(ag_hidden_states, dim=0, index=new_indices)
+ hidden_states_chunk = _split_along_first_dim(select_hidden_states)
+ hidden_states_chunk = hidden_states_chunk.clone()
+ select_hidden_states.untyped_storage().resize_(0)
+ ctx.ampipe_degree = ampipe_degree
+ return hidden_states_chunk
+
+ @staticmethod
+ def backward(ctx, grad_input) -> Any:
+ ag_grad_input = all_gather_along_first_dim(grad_input, True)
+ grad_input.untyped_storage().resize_(0)
+ seqlen = ag_grad_input.size(0)
+
+ new_indices = get_slice_indices_from_disorder_to_order(seqlen, ctx.ampipe_degree, device=torch.npu.current_device())
+ select_grad_input = torch.index_select(ag_grad_input, dim=0, index=new_indices)
+ grad_output_chunk = _split_along_first_dim(select_grad_input)
+ grad_output_chunk = grad_output_chunk.clone()
+ select_grad_input.untyped_storage().resize_(0)
+ return grad_output_chunk, None
+
+
+def parallel_mlp_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ fn(self, *args, **kwargs)
+ self.layer_number = None
+ _args = get_args()
+ if _args.swiglu and _args.use_fused_swiglu:
+ self.activation_func = fused_swiglu
+
+ config = args[0]
+ is_expert = kwargs.get('is_expert') if 'is_expert' in kwargs.keys() else False
+
+ ffn_hidden_size = config.ffn_hidden_size
+ if config.gated_linear_unit:
+ ffn_hidden_size *= 2
+ if _args.use_nd_matmul:
+ self.dense_h_to_4h = Nd_ParallelLinear(
+ config.hidden_size,
+ ffn_hidden_size,
+ config=config,
+ init_method=config.init_method,
+ bias=self.add_bias,
+ skip_bias_add=True,
+ input_is_parallel=True,
+ is_expert=is_expert,
+ matmul_id=1
+ )
+ self.dense_4h_to_h = Nd_ParallelLinear(
+ config.ffn_hidden_size,
+ config.hidden_size,
+ config=config,
+ init_method=config.output_layer_init_method,
+ bias=self.add_bias,
+ skip_bias_add=True,
+ input_is_parallel=True,
+ is_expert=is_expert,
+ matmul_id=2
+ )
+ elif _args.tp_2d:
+ self.dense_h_to_4h = ParallelLinear2D(
+ config.hidden_size,
+ ffn_hidden_size,
+ config=config,
+ init_method=config.init_method,
+ add_bias=self.add_bias,
+ skip_bias_add=True,
+ is_expert=is_expert,
+ ag_comm_intf=TPXCollectiveComm,
+ ag_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
+ rs_comm_intf=TPYCollectiveComm,
+ rs_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
+ enable_overlap_ag_with_matmul=False,
+ enable_overlap_matmul_with_rs=_args.enable_overlap_matmul_with_rs,
+ partition_dim=0,
+ enable_backward_overlap_ag_with_matmul=_args.enable_backward_overlap_ag_with_matmul)
+ self.dense_4h_to_h = ParallelLinear2D(
+ config.ffn_hidden_size,
+ config.hidden_size,
+ config=config,
+ init_method=config.output_layer_init_method,
+ add_bias=self.add_bias,
+ skip_bias_add=True,
+ ag_comm_intf=TPYCollectiveComm,
+ ag_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
+ rs_comm_intf=TPXCollectiveComm,
+ rs_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
+ enable_overlap_ag_with_matmul=_args.enable_overlap_ag_with_matmul,
+ enable_overlap_matmul_with_rs=False,
+ partition_dim=1,
+ enable_backward_overlap_ag_with_matmul=_args.enable_backward_overlap_ag_with_matmul)
+ else:
+ self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
+ config.hidden_size,
+ ffn_hidden_size,
+ config=config,
+ init_method=config.init_method,
+ bias=self.add_bias,
+ gather_output=False,
+ skip_bias_add=True,
+ is_expert=is_expert
+ )
+ self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
+ config.ffn_hidden_size,
+ config.hidden_size,
+ config=config,
+ init_method=config.output_layer_init_method,
+ bias=self.add_bias,
+ skip_bias_add=True,
+ input_is_parallel=True,
+ is_expert=is_expert
+ )
+ if _args.use_nanopipe and parallel_state.get_pipeline_model_parallel_world_size() > 1 \
+ and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+ setattr(self.dense_h_to_4h, "in_nano", True)
+ setattr(self.dense_4h_to_h, "in_nano", True)
+ # use dynamic property assignment to ADD pipe_experts attribution
+ if not _args.swiglu:
+ self.dense_h_to_4h.pipe_experts = _args.use_pipe_experts
+ self.dense_4h_to_h.pipe_experts = _args.use_pipe_experts
+ if _args.ampipe_degree > 1:
+ setattr(self.dense_h_to_4h, "ampipe_degree", _args.ampipe_degree)
+ setattr(self.dense_4h_to_h, "ampipe_degree", _args.ampipe_degree)
+ return wrapper
+
+
+def should_recompute(args, layer_number, num_recompute):
+ vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()
+ vpp_size = args.virtual_pipeline_model_parallel_size
+ pp_size = args.transformer_pipeline_model_parallel_size
+
+ if vpp_size is not None:
+ layer_per_chunk = args.num_layers_per_virtual_pipeline_stage
+ elif pp_size is not None:
+ layer_per_chunk = args.num_layers // pp_size
+ else:
+ layer_per_chunk = args.num_layers
+
+ if vpp_rank is None or not args.enable_recompute_layers_per_pp_rank:
+ vpp_rank = 0
+ if vpp_size is None or not args.enable_recompute_layers_per_pp_rank:
+ vpp_size = 1
+ recompute_priority = ((layer_number - 1) % layer_per_chunk) * vpp_size + vpp_rank
+ full_recompute_layers = args.recompute_num_layers
+
+ if full_recompute_layers:
+ if recompute_priority < full_recompute_layers:
+ # Do full recomputation
+ return False
+ elif num_recompute is None:
+ return True
+ elif recompute_priority < full_recompute_layers + num_recompute:
+ return True
+ else:
+ return False
+
+ if num_recompute is None:
+ return True
+ else:
+ return recompute_priority < num_recompute
+
+
+def should_recompute_activation(layer_number):
+ args = get_args()
+ if not args.recompute_activation_function or layer_number is None:
+ return False
+
+ if args.recompute_in_bubble or args.recompute_in_advance:
+ pipeline_checkpoint_manager = get_pipeline_checkpoint_manager(args.virtual_pipeline_model_parallel_size)
+ if pipeline_checkpoint_manager.chunk_do_recompute:
+ return False
+ elif args.recompute_in_bubble:
+ return True
+
+ if args.recompute_activation_function_num_layers is not None:
+ if args.recompute_activation_function_num_layers < 0:
+ raise AssertionError('--recompute-activation-function-num-layers cannot be less than 0.')
+ elif args.recompute_activation_function_num_layers > args.num_layers:
+ raise AssertionError('--recompute-activation-function-num-layers cannot be greater than the number of layers.')
+ return should_recompute(args, layer_number, args.recompute_activation_function_num_layers)
+
+
+def should_recompute_norm(self):
+ args = get_args()
+ if not args.recompute_norm or self.layer_number is None:
+ return False
+ return should_recompute(args, self.layer_number, args.recompute_norm_num_layers)
+
+
+def parallel_mlp_forward(self, hidden_states):
+ self.layer_number = getattr(self, "layer_number", None)
+ is_recompute_activation = should_recompute_activation(self.layer_number)
+ args = get_args()
+
+ def activation_function(*function_args):
+ intermediate, bias = function_args
+
+ if self.bias_gelu_fusion:
+ assert self.add_bias is True
+ assert self.activation_func == F.gelu
+ intermediate = bias_gelu_impl(intermediate, bias)
+ else:
+ if bias is not None:
+ intermediate = intermediate + bias
+ intermediate = self.activation_func(intermediate)
+ return intermediate
+
+ if not is_recompute_activation:
+ # [s, b, 4hp]
+ intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
+ if not args.use_pipe_experts and args.ampipe_degree > 1 and args.ampipe_tp_sp_comm_overlap:
+ from mindspeed.moe.async_comm_utils import get_fw_ar_rs_output_ampipe, async_all_to_all
+ last_chunk_output = get_fw_ar_rs_output_ampipe(args.sequence_parallel)
+ if last_chunk_output is not None:
+ a2a_output, handle = async_all_to_all(last_chunk_output)
+
+ if self.bias_gelu_fusion:
+ assert self.add_bias is True
+ assert self.activation_func == F.gelu
+ intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
+ else:
+ if bias_parallel is not None:
+ intermediate_parallel = intermediate_parallel + bias_parallel
+ intermediate_parallel = self.activation_func(intermediate_parallel)
+
+ # [s, b, h]
+ output, output_bias = self.dense_4h_to_h(intermediate_parallel)
+ if not args.use_pipe_experts and args.ampipe_degree > 1 and args.ampipe_tp_sp_comm_overlap:
+ if last_chunk_output is not None:
+ handle.wait()
+ return output, output_bias, a2a_output
+ else:
+ if not args.use_pipe_experts and args.ampipe_degree > 1 and args.ampipe_tp_sp_comm_overlap:
+ from mindspeed.moe.async_comm_utils import (get_fw_ar_rs_output_ampipe,
+ async_all_to_all)
+ last_chunk_output = get_fw_ar_rs_output_ampipe(args.sequence_parallel)
+ if last_chunk_output is not None:
+ a2a_output, handle = async_all_to_all(last_chunk_output)
+
+ intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
+ self.activation_checkpoint_manager = CheckpointWithoutOutput()
+ intermediate_parallel = self.activation_checkpoint_manager.checkpoint(activation_function,
+ False,
+ intermediate_parallel,
+ bias_parallel)
+ # [s, b, h]
+ output, output_bias = self.dense_4h_to_h(intermediate_parallel)
+
+ # discard the output of the activation function,
+ # which will be restored by recomputation during backward.
+ self.activation_checkpoint_manager.discard_output()
+
+ # when backward to output of dense_4h_to_h,
+ # recompute and restore the output of activation function.
+ if output.requires_grad:
+ output.register_hook(self.activation_checkpoint_manager.recompute)
+ if not args.use_pipe_experts and args.ampipe_degree > 1 and args.ampipe_tp_sp_comm_overlap:
+ if last_chunk_output is not None:
+ handle.wait()
+ return output, output_bias, a2a_output
+ return output, output_bias
+
+
+def flash_self_attention_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *arg, **kwargs):
+ fn(self, *arg, **kwargs)
+ args = get_args()
+
+ self.pse = None
+ self.pse_type = args.alibi_fusion_attn_type
+
+ if self.pse_type is None:
+ self.pse_type = 1 # not use pse
+ elif self.pse_type == 0:
+ alibi = AlibiForFusionAttnSingleton.get_alibi_tensor_for_fusion_attn(args.seq_length,
+ args.num_attention_heads,
+ args.params_dtype,
+ args.alibi_diagonal_opposite,
+ 1024)
+ self.pse = alibi
+
+ elif self.pse_type == 2 or self.pse_type == 3:
+ self.pse = AlibiForFusionAttnSingleton.get_alibi_slopes_for_fusion_attn(args.num_attention_heads)
+
+ return wrapper
+
+
+def flash_self_attention_init_add_config_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *arg, **kwargs):
+ if 'config' in kwargs:
+ self.config = kwargs.pop('config')
+ fn(self, *arg, **kwargs)
+
+ return wrapper
+
+
+def flash_self_attention_forward(self, q, k, v, attention_mask):
+ """Implements the multihead softmax attention.
+ Arguments
+ ---------
+ q, k, v: The tensor containing the query, key, and value. (S, B, H, D)
+ """
+ args = get_args()
+ seq_length, _, head_num, head_dim = q.shape[0], q.shape[1], q.shape[2], q.shape[3]
+
+ q, k, v = [rearrange(x, 's b h d -> s b (h d)') for x in [q, k, v]]
+
+ try:
+ scale = 1.0 / math.sqrt(head_dim) if self.softmax_scale is None else self.softmax_scale
+ except Exception as e:
+ raise ValueError('Invalid head_dim: {}'.format(head_dim)) from e
+ cp_expanded_by_2d_tp = args.tp_2d and args.tp_y > 1
+ if cp_expanded_by_2d_tp:
+ tp_y_cp_sz = TensorParallelYUnionCP().get_parallel_group_world_size()
+ else:
+ tp_y_cp_sz = args.context_parallel_size
+ if tp_y_cp_sz > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo',
+ 'adaptive_cp_algo', 'hybrid_adaptive_cp_algo']:
+ in_hybrid_mode = False
+ if get_context_parallel_group_for_hybrid_ring(check_initialized=False) is not None:
+ in_hybrid_mode = True
+
+ if not in_hybrid_mode:
+ if cp_expanded_by_2d_tp:
+ tp_y_cp = TensorParallelYUnionCP()
+ cp_group = tp_y_cp.group
+ cp_size = tp_y_cp.get_parallel_group_world_size()
+ rank = tp_y_cp.get_parallel_rank()
+ cp_global_ranks = tp_y_cp.global_ranks
+ else:
+ cp_group = mpu.get_context_parallel_group()
+ cp_size = mpu.get_context_parallel_world_size()
+ rank = mpu.get_context_parallel_rank()
+ cp_global_ranks = mpu.get_context_parallel_global_ranks()
+ else:
+ cp_group = get_context_parallel_group_for_hybrid_ring()
+ cp_size = get_context_parallel_for_hybrid_ring_world_size()
+ rank = get_context_parallel_for_hybrid_ring_rank()
+ cp_global_ranks = get_context_parallel_for_hybrid_ring_global_ranks()
+
+ cp_para = dict()
+ if hasattr(self, 'config'):
+ cp_para['megatron_cp_in_bnsd'] = self.config.megatron_cp_in_bnsd
+ cp_para['causal'] = args.attention_mask_type == 'causal'
+ cp_para['cp_group'] = cp_group
+ cp_para['cp_size'] = cp_size
+ cp_para['rank'] = rank
+
+ if args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo']:
+ cp_para['cp_global_ranks'] = cp_global_ranks
+ if args.use_cp_send_recv_overlap:
+ if cp_expanded_by_2d_tp:
+ cp_para['cp_group_for_send_recv_overlap'] = tp_y_cp.overlap_group
+ else:
+ cp_para['cp_group_for_send_recv_overlap'] = mpu.get_context_parallel_group_for_send_recv_overlap()
+ else:
+ cp_para['cp_group_for_send_recv_overlap'] = None
+ cp_para['pse'] = self.pse
+ cp_para['pse_type'] = self.pse_type
+ if args.context_parallel_size > 1 and not args.tp_2d:
+ cp_para['cp_inner_ranks'] = get_ring_ranks_for_intra_window()
+ cp_para['cp_outer_ranks'] = get_ring_ranks_for_inter_window_kv()
+ cp_para['cp_dkv_outer_ranks'] = get_ring_ranks_for_inter_window_dkv()
+ cp_para['cp_group_for_intra_window'] = get_ring_group_for_intra_window()
+ cp_para['cp_group_for_intra_window_send_recv_overlap'] = get_ring_group_for_intra_window_send_recv_overlap()
+ output = ringattn_context_parallel(q, k, v, head_num, cp_para, scale, attention_mask, self.dropout_p)
+ else:
+ cp_para['scheduling_info'] = get_scheduling_info()
+ output = adaptive_attn_context_parallel(q, k, v, head_num, cp_para, scale, attention_mask, self.dropout_p)
+ else:
+ if args.use_fusion_attn_v2:
+ output = npu_fusion_attention(
+ q, k, v, head_num, args.shape_order,
+ pse=self.pse,
+ padding_mask=None,
+ atten_mask=attention_mask,
+ scale=scale,
+ pse_type=self.pse_type,
+ pre_tokens=args.pre_tockens,
+ next_tokens=args.next_tockens,
+ keep_prob=1 - self.dropout_p,
+ inner_precise=0,
+ sparse_mode=args.sparse_mode
+ )[0]
+ else:
+ output = torch_npu.npu_fusion_attention(
+ q, k, v, head_num, args.shape_order,
+ pse=None,
+ padding_mask=None,
+ atten_mask=attention_mask,
+ scale=scale,
+ pre_tockens=args.pre_tockens,
+ next_tockens=args.next_tockens,
+ keep_prob=1 - self.dropout_p,
+ inner_precise=0,
+ sparse_mode=args.sparse_mode
+ )[0]
+ return output
+
+
+def parallel_attention_init(self, config, layer_number,
+ attention_type=AttnType.self_attn,
+ attn_mask_type=AttnMaskType.padding):
+ super(ParallelAttention, self).__init__()
+ args = get_args()
+ self.layer_number = max(1, layer_number)
+ self.attention_type = attention_type
+ self.attn_mask_type = attn_mask_type
+ self.params_dtype = config.params_dtype
+ self.sequence_parallel = config.sequence_parallel
+ self.config = config
+ self.group_query_attention = args.group_query_attention
+ self.num_query_groups = config.num_query_groups
+
+ query_projection_size = config.kv_channels * config.num_attention_heads
+ if self.group_query_attention:
+ kv_projection_size = config.kv_channels * config.num_query_groups
+ else:
+ kv_projection_size = config.kv_channels * config.num_attention_heads
+
+ self.use_flash_attn = args.use_flash_attn \
+ and attention_type == AttnType.self_attn \
+ and self.attn_mask_type == AttnMaskType.causal
+ if self.use_flash_attn:
+ try:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+ except ImportError:
+ try:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
+ except ImportError:
+ flash_attn_unpadded_func = None
+ if flash_attn_unpadded_func is None:
+ raise ImportError('FlashAttention is not installed, please install with '
+ 'pip install flash-attn')
+ assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
+ 'self-attention for now')
+ assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
+ 'supports causal mask for now')
+ if rearrange is None:
+ raise ImportError('einops is not installed, please install with pip install einops')
+
+ # Per attention head and per partition values.
+ from megatron import core
+ self.hidden_size_per_attention_head = core.utils.divide(
+ query_projection_size, config.num_attention_heads)
+
+ # Strided linear layer.
+ if attention_type == AttnType.self_attn:
+ self.query_key_value = tensor_parallel.ColumnParallelLinear(
+ config.hidden_size,
+ query_projection_size + 2 * kv_projection_size,
+ config=config,
+ init_method=config.init_method,
+ bias=config.add_bias_linear or config.add_qkv_bias,
+ gather_output=False)
+ else:
+ assert attention_type == AttnType.cross_attn
+
+ if self.group_query_attention:
+ raise NotImplementedError("Grouped query attention not implemented for cross-attention.")
+ assert query_projection_size == kv_projection_size
+
+ self.query = tensor_parallel.ColumnParallelLinear(
+ config.hidden_size,
+ query_projection_size,
+ config=config,
+ init_method=config.init_method,
+ bias=config.add_bias_linear,
+ gather_output=False)
+
+ self.key_value = tensor_parallel.ColumnParallelLinear(
+ config.hidden_size,
+ 2 * kv_projection_size,
+ config=config,
+ init_method=config.init_method,
+ bias=config.add_bias_linear,
+ gather_output=False)
+
+ self.core_attention = CoreAttention(self.layer_number, config,
+ self.attn_mask_type)
+ self.checkpoint_core_attention = config.recompute_granularity == 'selective'
+
+ if self.use_flash_attn:
+ self.core_attention_flash = FlashSelfAttention(
+ causal=True, attention_dropout=config.attention_dropout, config=config
+ )
+
+ # Output.
+ self.dense = tensor_parallel.RowParallelLinear(
+ query_projection_size,
+ config.hidden_size,
+ config=config,
+ init_method=config.output_layer_init_method,
+ bias=config.add_bias_linear,
+ input_is_parallel=True,
+ skip_bias_add=True)
+ # patch for attention
+ patch_for_attention(config, self)
+
+
+def patch_for_attention(config, self):
+ _args = get_args()
+ attn_heads_split_num = (
+ get_tensor_model_parallel_world_size_for_nd1_dim1()
+ if _args.tp_2d
+ else mpu.get_tensor_model_parallel_world_size()
+ )
+ # Per attention head and per partition values.
+ self.num_attention_heads_per_partition = config.num_attention_heads // attn_heads_split_num
+ if self.group_query_attention:
+ if config.num_query_groups % attn_heads_split_num != 0:
+ raise NotImplementedError(
+ "Currently the num_query_groups should be a multiple of the tensor parallel size"
+ )
+ self.num_query_groups_per_partition = config.num_query_groups // attn_heads_split_num
+ else:
+ self.num_query_groups_per_partition = self.num_attention_heads_per_partition
+ query_projection_size = config.kv_channels * config.num_attention_heads
+ if _args.group_query_attention:
+ kv_projection_size = config.kv_channels * config.num_query_groups
+ else:
+ kv_projection_size = config.kv_channels * config.num_attention_heads
+ # qkv bias
+ bias = config.add_qkv_bias or config.add_bias_linear
+ cp = config.context_parallel_size
+ if _args.tp_2d:
+ tp_y_cp_sz = cp * _args.tp_y
+ else:
+ tp_y_cp_sz = cp
+ if tp_y_cp_sz > 1 and _args.context_parallel_algo in ['ulysses_cp_algo', 'hybrid_cp_algo',
+ 'hybrid_adaptive_cp_algo']:
+ if _args.tp_2d:
+ tp_y_cp = TensorParallelYUnionCP()
+ ulysses_group = tp_y_cp.group
+ else:
+ ulysses_group = mpu.get_context_parallel_group()
+ if _args.context_parallel_algo == 'hybrid_cp_algo' or _args.context_parallel_algo == 'hybrid_adaptive_cp_algo':
+ ulysses_group = get_context_parallel_group_for_hybrid_ulysses()
+ if self.use_flash_attn:
+ self.core_attention_flash = UlyssesContextAttention(self.core_attention_flash, ulysses_group)
+ else:
+ self.core_attention = UlyssesContextAttention(self.core_attention, ulysses_group)
+ if _args.use_nd_matmul:
+ self.query_key_value = Nd_ParallelLinear(
+ config.hidden_size,
+ query_projection_size + 2 * kv_projection_size,
+ config=config,
+ init_method=config.init_method,
+ bias=bias,
+ skip_bias_add=True,
+ input_is_parallel=True,
+ matmul_id=1
+ )
+ elif _args.tp_2d:
+ self.query_key_value = ParallelLinear2D(
+ config.hidden_size,
+ query_projection_size + 2 * kv_projection_size,
+ config=config,
+ init_method=config.init_method,
+ add_bias=bias,
+ skip_bias_add=True,
+ ag_comm_intf=TPXCollectiveComm,
+ ag_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
+ rs_comm_intf=TPYCollectiveComm,
+ rs_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
+ enable_overlap_ag_with_matmul=False,
+ enable_overlap_matmul_with_rs=False,
+ partition_dim=0,
+ enable_backward_overlap_ag_with_matmul=False)
+ else:
+ self.query_key_value = tensor_parallel.ColumnParallelLinear(
+ config.hidden_size,
+ query_projection_size + 2 * kv_projection_size,
+ config=config,
+ init_method=config.init_method,
+ bias=bias,
+ gather_output=False)
+ # dense bias
+ bias = _args.add_dense_bias or config.add_bias_linear
+ skip_bias_add = _args.skip_bias_add
+ # Output.
+ if _args.use_nd_matmul:
+ self.dense = Nd_ParallelLinear(
+ query_projection_size,
+ config.hidden_size,
+ config=config,
+ init_method=config.output_layer_init_method,
+ bias=bias,
+ skip_bias_add=True,
+ input_is_parallel=True,
+ matmul_id=2
+ )
+ elif _args.tp_2d:
+ self.dense = ParallelLinear2D(
+ query_projection_size,
+ config.hidden_size,
+ config=config,
+ init_method=config.output_layer_init_method,
+ add_bias=bias,
+ skip_bias_add=True,
+ ag_comm_intf=TPYCollectiveComm,
+ ag_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
+ rs_comm_intf=TPXCollectiveComm,
+ rs_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
+ enable_overlap_ag_with_matmul=False,
+ enable_overlap_matmul_with_rs=False,
+ partition_dim=1,
+ enable_backward_overlap_ag_with_matmul=_args.enable_backward_overlap_ag_with_matmul)
+ else:
+ self.dense = tensor_parallel.RowParallelLinear(
+ query_projection_size,
+ config.hidden_size,
+ config=config,
+ init_method=config.output_layer_init_method,
+ bias=bias,
+ input_is_parallel=True,
+ skip_bias_add=skip_bias_add)
+ if _args.use_nanopipe and parallel_state.get_pipeline_model_parallel_world_size() > 1 \
+ and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+ setattr(self.query_key_value, "in_nano", True)
+ setattr(self.dense, "in_nano", True)
+ if _args.ampipe_degree > 1:
+ setattr(self.query_key_value, 'ampipe_degree', _args.ampipe_degree)
+ setattr(self.query_key_value, 'is_dense_h_to_3h', True)
+
+
+def parallel_attention_forward(self, hidden_states, attention_mask,
+ encoder_output=None, inference_params=None,
+ rotary_pos_emb=None):
+ # hidden_states: [sq, b, h]
+
+ # =================================================
+ # Pre-allocate memory for key-values for inference.
+ # =================================================
+ is_first_step = False
+ if inference_params:
+ if self.layer_number not in inference_params.key_value_memory_dict:
+ inf_max_seq_len = inference_params.max_sequence_length
+ inf_max_batch_size = inference_params.max_batch_size
+ inference_key_memory = self._allocate_memory(
+ inf_max_seq_len, inf_max_batch_size,
+ self.num_query_groups_per_partition)
+ inference_value_memory = self._allocate_memory(
+ inf_max_seq_len, inf_max_batch_size,
+ self.num_query_groups_per_partition)
+
+ inference_params.key_value_memory_dict[self.layer_number] = (
+ inference_key_memory, inference_value_memory)
+ is_first_step = True
+ else:
+ inference_key_memory, inference_value_memory = \
+ inference_params.key_value_memory_dict[self.layer_number]
+
+ # =====================
+ # Query, Key, and Value
+ # =====================
+ if self.attention_type == AttnType.self_attn:
+
+ # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
+ mixed_x_layer, _ = self.query_key_value(hidden_states)
+
+ # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
+ new_tensor_shape = mixed_x_layer.size()[:-1] + (
+ self.num_query_groups_per_partition,
+ (
+ (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
+ * self.hidden_size_per_attention_head
+ ),
+ )
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+ # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
+ (query_layer,
+ key_layer,
+ value_layer) = torch.split(
+ mixed_x_layer,
+ [
+ (
+ self.num_attention_heads_per_partition // self.num_query_groups_per_partition
+ * self.hidden_size_per_attention_head
+ ),
+ self.hidden_size_per_attention_head,
+ self.hidden_size_per_attention_head
+ ],
+ dim=3)
+
+ # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] -
+ query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
+ else:
+ # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
+ mixed_kv_layer, _ = self.key_value(encoder_output)
+
+ # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
+ new_tensor_shape = mixed_kv_layer.size()[:-1] + \
+ (self.num_attention_heads_per_partition,
+ 2 * self.hidden_size_per_attention_head)
+ mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
+
+ # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
+ (key_layer,
+ value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
+
+ # Attention head [sq, b, h] --> [sq, b, hp]
+ query_layer, _ = self.query(hidden_states)
+ # [sq, b, hp] --> [sq, b, np, hn]
+ new_tensor_shape = query_layer.size()[:-1] + \
+ (self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head)
+ query_layer = query_layer.view(*new_tensor_shape)
+
+ # ==================================
+ # Adjust key and value for inference
+ # ==================================
+
+ # duplicate the pos_emb for self attention
+ if rotary_pos_emb is not None:
+ if isinstance(rotary_pos_emb, tuple):
+ rotary_pos_emb = rotary_pos_emb
+ else:
+ rotary_pos_emb = ((rotary_pos_emb,) * 2)
+
+ if inference_params:
+ batch_start = inference_params.batch_size_offset
+ batch_end = batch_start + key_layer.size(1)
+ assert batch_end <= inference_key_memory.size(1)
+ sequence_start = inference_params.sequence_len_offset
+ sequence_end = sequence_start + key_layer.size(0)
+ assert sequence_end <= inference_key_memory.size(0)
+ # Copy key and values.
+ inference_key_memory[sequence_start:sequence_end,
+ batch_start:batch_end, ...] = key_layer
+ inference_value_memory[sequence_start:sequence_end,
+ batch_start:batch_end, ...] = value_layer
+ key_layer = inference_key_memory[
+ :sequence_end, batch_start:batch_end, ...]
+ value_layer = inference_value_memory[
+ :sequence_end, batch_start:batch_end, ...]
+
+
+ # adjust the key rotary positional embedding
+ if rotary_pos_emb is not None:
+ q_pos_emb, k_pos_emb = rotary_pos_emb
+ # need to cross check this condition during inference
+ # if not set_inference_key_value_memory:
+ if not is_first_step:
+ # In inference, we compute one token at a time.
+ # Select the correct positional embedding
+ # (only the last token in the sequence)
+ q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
+ else:
+ # In the first forward pass of inference,
+ # we use the entire provided prefix.
+ # q_pos_emb here has the rope embeddings of the entire
+ # prefix + to-be-generated output so
+ # we slice to just the prefix.
+ q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
+ k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
+ rotary_pos_emb = (q_pos_emb, k_pos_emb)
+
+ # ==================================
+ # core attention computation
+ # ==================================
+
+ # apply relative positional encoding (rotary embedding)
+ if rotary_pos_emb is not None:
+ q_pos_emb, k_pos_emb = rotary_pos_emb
+ query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.config)
+ key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.config)
+
+ if not self.use_flash_attn:
+ if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
+ key_layer = key_layer.repeat_interleave(
+ self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2)
+ value_layer = value_layer.repeat_interleave(
+ self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2)
+ if self.checkpoint_core_attention:
+ context_layer = self._checkpointed_attention_forward(
+ query_layer, key_layer, value_layer, attention_mask)
+ else:
+ context_layer = self.core_attention(
+ query_layer, key_layer, value_layer, attention_mask)
+ else:
+ if get_args().ampipe_degree > 1:
+ return query_layer, key_layer, value_layer
+ if not self.sequence_parallel:
+ with tensor_parallel.get_cuda_rng_tracker().fork():
+ context_layer = self.core_attention_flash(query_layer, key_layer, value_layer, attention_mask)
+ else:
+ context_layer = self.core_attention_flash(query_layer, key_layer, value_layer, attention_mask)
+
+ # =================
+ # Output. [sq, b, h]
+ # =================
+
+ output, bias = self.dense(context_layer)
+
+ return output, bias
+
+
+def switch_mlp_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ global_args = get_args()
+ if global_args.moe_model_type == 'megatron_moe':
+ fn(self, *args, **kwargs)
+ return
+ from megatron.legacy.model.transformer import SwitchMLP
+ super(SwitchMLP, self).__init__()
+ config = args[0]
+ layer_number = args[1] if len(args) > 1 else None
+ from megatron.core.parallel_state import get_expert_model_parallel_group
+ from mindspeed.moe.moe import MoE
+ from mindspeed.moe.mixtral_parallel_mlpbm import MixtralParallelMLPBM
+ try:
+ expert_parallel_group = get_expert_model_parallel_group()
+ except AttributeError:
+ expert_parallel_group = None
+
+ if layer_number is None:
+ self.block = MoE(
+ config.hidden_size,
+ MixtralParallelMLPBM(config, ) if global_args.swiglu else ParallelMLP(config, is_expert=False),
+ num_experts=global_args.num_experts,
+ ep_size=config.expert_model_parallel_size,
+ k=config.moe_router_topk,
+ capacity_factor=global_args.moe_train_capacity_factor,
+ eval_capacity_factor=global_args.moe_train_capacity_factor,
+ aux_loss_coef=config.moe_aux_loss_coeff,
+ ep_group=expert_parallel_group,
+ noisy_gate_policy=global_args.noisy_gate_policy,
+ no_drop=global_args.moe_no_drop,
+ dynamic_padding=global_args.moe_dynamic_padding,
+ use_sinkhorn=global_args.moe_use_sinkhorn,
+ sequence_parallel=config.sequence_parallel
+ )
+ else:
+ if layer_number % global_args.expert_interval == 0:
+ self.block = MoE(
+ config.hidden_size,
+ MixtralParallelMLPBM(config, ) if global_args.swiglu else ParallelMLP(config, is_expert=False),
+ num_experts=global_args.num_experts,
+ ep_size=config.expert_model_parallel_size,
+ k=config.moe_router_topk,
+ capacity_factor=global_args.moe_train_capacity_factor,
+ eval_capacity_factor=global_args.moe_train_capacity_factor,
+ aux_loss_coef=config.moe_aux_loss_coeff,
+ ep_group=expert_parallel_group,
+ noisy_gate_policy=global_args.noisy_gate_policy,
+ no_drop=global_args.moe_no_drop,
+ dynamic_padding=global_args.moe_dynamic_padding,
+ use_sinkhorn=global_args.moe_use_sinkhorn,
+ sequence_parallel=config.sequence_parallel
+ )
+ else:
+ self.block = ParallelMLP(config)
+ return
+ return wrapper
+
+
+def switch_mlp_forward_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ global_args = get_args()
+ if global_args.moe_model_type == 'megatron_moe':
+ return fn(self, *args, **kwargs)
+ hidden_states = args[0]
+ used_token = args[1] if len(args) > 1 else None
+ output = self.block(hidden_states, used_token)
+ return output[0], None
+ return wrapper
+
+
+def parallel_transformer_layer_init_wrapper(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ from megatron.legacy.model.transformer import SwitchMLP
+ super(ParallelTransformerLayer, self).__init__()
+ global_args = get_args()
+ fn(self, *args, **kwargs)
+ self.pipe_degree = global_args.ampipe_degree
+ self.ampipe_enabled = global_args.ampipe_degree > 1
+ if self.mlp.__class__ is SwitchMLP:
+ experts_modules = self.mlp.block.moe_layer.experts.experts if global_args.moe_model_type == 'deepspeed_moe' \
+ else self.mlp.local_experts
+ for expert in experts_modules:
+ expert.layer_number = self.layer_number
+ else:
+ self.mlp.layer_number = self.layer_number
+
+ return wrapper
+
+
+def parallel_transformer_layer_forward_ampipe(self, hidden_states, attention_mask,
+ encoder_output=None, enc_dec_attn_mask=None,
+ retriever_input=None,
+ retriever_output=None,
+ retriever_attn_mask=None,
+ inference_params=None,
+ rotary_pos_emb=None):
+
+ # Update the params in case the retro param changes during inference
+ # TODO: better redesign with inference param
+ args = get_args()
+ if args.retro_add_retriever:
+ self.retro_num_neighbors = args.retro_num_neighbors
+ self.retro_chunk_length = args.retro_chunk_length
+ self.retro_retrieved_length = \
+ args.retro_num_retrieved_chunks * args.retro_chunk_length
+
+ # hidden_states: [s, b, h]
+
+ # Layer norm at the beginning of the transformer layer.
+ norm_output = self.input_norm(hidden_states)
+
+ if self.ampipe_enabled:
+ q, k, v = self.self_attention(
+ norm_output,
+ attention_mask,
+ inference_params=inference_params,
+ rotary_pos_emb=rotary_pos_emb)
+ # release memory to reduce peak memory usage
+ del norm_output
+ dense_layer = self.self_attention.dense
+ ln = self.post_attention_norm
+ k, v = [rearrange(x, 's b n d -> s b (n d)') for x in [k, v]]
+
+ ampipe_forward_args = ForwardArgs(
+ dense_layer, bias_dropout_add_fused_train, ln, self.mlp.block, self.hidden_dropout
+ )
+ out_mlp, residual = AttMoEPipe.apply(q, k, v, hidden_states, attention_mask, ampipe_forward_args)
+
+ with self.bias_dropout_add_exec_handler():
+ output = bias_dropout_add_fused_train(
+ out_mlp,
+ None,
+ residual,
+ self.hidden_dropout)
+
+ output = make_viewless_tensor(inp=output, requires_grad=output.requires_grad, keep_graph=True)
+
+ return output
+
+ # Self attention.
+ attention_output, attention_bias = \
+ self.self_attention(
+ norm_output,
+ attention_mask,
+ inference_params=inference_params,
+ rotary_pos_emb=rotary_pos_emb)
+
+ # Residual connection.
+ if self.apply_residual_connection_post_norm:
+ residual = norm_output
+ else:
+ residual = hidden_states
+
+ if self.drop_path is None:
+ # jit scripting for a nn.module (with dropout) is not
+ # trigerring the fusion kernel. For now, we use two
+ # different nn.functional routines to account for varying
+ # dropout semantics during training and inference phases.
+ if self.bias_dropout_fusion:
+ if self.training:
+ bias_dropout_add_func = bias_dropout_add_fused_train
+ else:
+ bias_dropout_add_func = bias_dropout_add_fused_inference
+ else:
+ bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+ if attention_bias is not None:
+ attention_bias = attention_bias.expand_as(residual)
+ with self.bias_dropout_add_exec_handler():
+ norm_input = bias_dropout_add_func(
+ attention_output,
+ attention_bias,
+ residual,
+ self.hidden_dropout)
+ else:
+ out = torch.nn.functional.dropout(attention_output + attention_bias,
+ p=self.hidden_dropout,
+ training=self.training)
+ norm_input = residual + self.drop_path(out)
+
+ # Layer norm post the self attention.
+ norm_output = self.post_attention_norm(norm_input)
+
+ # Cross attention.
+ if self.layer_type == LayerType.encoder:
+ pass
+ elif self.layer_type == LayerType.decoder:
+ norm_input, norm_output = \
+ self.default_decoder_cross_attention(
+ encoder_output,
+ enc_dec_attn_mask,
+ norm_input,
+ norm_output,
+ bias_dropout_add_func)
+ elif self.layer_type == LayerType.retro_encoder:
+ norm_input, norm_output = \
+ self.retro_encoder_cross_attention(
+ retriever_output,
+ norm_input,
+ norm_output,
+ bias_dropout_add_func)
+ elif self.layer_type in (LayerType.retro_decoder,
+ LayerType.retro_decoder_with_retriever):
+ retriever_output, norm_input, norm_output = \
+ self.retro_decoder_cross_attention(
+ retriever_input,
+ retriever_output,
+ retriever_attn_mask,
+ norm_input,
+ norm_output,
+ inference_params,
+ bias_dropout_add_func)
+ else:
+ raise Exception("Unsupported layer type, '%s'." %
+ self.layer_type.name)
+
+ # MLP.
+ mlp_output, mlp_bias = self.mlp(norm_output)
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_norm:
+ residual = norm_output
+ else:
+ residual = norm_input
+
+ if self.drop_path is None:
+ if mlp_bias is not None:
+ mlp_bias = mlp_bias.expand_as(residual)
+ with self.bias_dropout_add_exec_handler():
+ output = bias_dropout_add_func(
+ mlp_output,
+ mlp_bias,
+ residual,
+ self.hidden_dropout)
+
+ # Jit compiled function creates 'view' tensor. This tensor
+ # potentially gets saved in the MPU checkpoint function context,
+ # which rejects view tensors. While making a viewless tensor here
+ # won't result in memory savings (like the data loader, or
+ # p2p_communication), it serves to document the origin of this
+ # 'view' tensor.
+ output = make_viewless_tensor(inp=output, requires_grad=output.requires_grad, keep_graph=True)
+
+ else:
+ if mlp_bias is not None:
+ mlp_output = mlp_output + mlp_bias
+ out = torch.nn.functional.dropout(mlp_output,
+ p=self.hidden_dropout,
+ training=self.training)
+ output = residual + self.drop_path(out)
+
+ if self.layer_type == LayerType.retro_decoder_with_retriever:
+ return output, retriever_output
+ else:
+ return output
diff --git a/model/train/yoco_moe/mindspeed/moe/__init__.py b/model/train/yoco_moe/mindspeed/moe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/moe/ampipe/__init__.py b/model/train/yoco_moe/mindspeed/moe/ampipe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe.py b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7addb594ca3a33c04da35146d1f8339a56e4f0d8
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe.py
@@ -0,0 +1,327 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import itertools
+from collections import namedtuple
+
+import torch
+from einops import rearrange
+
+from megatron.training import get_args
+from mindspeed.moe.ampipe.ampipe_args import (ForwardArgs, FlashAttentionFwdArgs, FwdCommArgs, BiasDropoutAddNormArgs,
+ MLPFwdArgs, PostMLPArgs, BwdCommArgs, FlashAttentionBwdArgs, MLPBwdArgs)
+from mindspeed.moe.ampipe.ampipe_async_communication import AsyncCommunication
+from mindspeed.moe.ampipe.ampipe_bias_dropout_add_ln_computer import BiasDropoutAddNormComputer
+from mindspeed.moe.ampipe.ampipe_fa_computer import FlashAttentionComputer
+from mindspeed.moe.ampipe.ampipe_moe_gating_computer import MoEGatingComputer
+from mindspeed.moe.ampipe.ampipe_moe_mlp_computer import MoEMLPComputer
+from mindspeed.moe.ampipe.ampipe_post_mlp_computer import MoEPostMLPComputer
+from mindspeed.moe.async_comm_utils import get_async_comm_utils_data_instance
+
+
+class AttMoEPipe(torch.autograd.Function):
+ """
+ Ampipe autograd.Function Class
+
+ Include FlashAttention & LayerNorm & MoE Layer
+ Args:
+ q: query
+ k: key
+ v: value
+ hidden_states: hidden_states before transformer layer used as residual.
+ attention_mask: global attention mask.
+ attention_dense: post attention dense layer object.
+ bias_dropout_add_func: bias dropout add function
+ post_attention_norm: post attention norm object.
+ moe: moe layer object.
+ hidden_dropout: dropout prob.
+ """
+ @staticmethod
+ def forward(ctx, q, k, v, hidden_states, attention_mask, ampipe_forward_args: ForwardArgs):
+ attention_dense = ampipe_forward_args.attention_dense
+ bias_dropout_add_func = ampipe_forward_args.bias_dropout_add_func
+ post_attention_norm = ampipe_forward_args.post_attention_norm
+ moe = ampipe_forward_args.moe
+ hidden_dropout = ampipe_forward_args.hidden_dropout
+
+ global_args = get_args()
+ pipe_degree = global_args.ampipe_degree
+ AttMoEPipe.save_args_to_ctx(ctx, ampipe_forward_args, global_args)
+
+ # 初始化反向保存tensor列表
+ flash_tensor_list = []
+ dense_tensor_list = []
+ bdal_tensor_list = []
+ gate_tensor_list = []
+ mlp_tensor_list = []
+ post_mlp_tensor_list = []
+
+ # 初始化临时列表
+ ln_input_list = []
+ moe_output_list = []
+ weights_list = [None] * pipe_degree
+ token_ec_idx_list = [None] * pipe_degree
+ mlp_inputs, a2a_inputs, a2a_events, ag_events = AttMoEPipe._init_fwd_comm_list()
+
+ # 初始化attention相关变量
+ q_shape = q.shape
+ ctx.head = q_shape[2]
+ q = rearrange(q, "s b n d -> s b (n d)")
+ fa_fwd_args = AttMoEPipe._init_attention_args(pipe_degree, q_shape, attention_dense, flash_tensor_list)
+ # 切分残差以及bias
+ hidden_states_chunks = hidden_states.chunk(pipe_degree, dim=0)
+ bias_chunks = attention_dense.bias.chunk(pipe_degree, dim=0) if attention_dense.bias is not None else None
+ ln_seq_len = hidden_states.shape[0]
+ ctx.fa_computer = fa_computer = FlashAttentionComputer(fa_fwd_args)
+ for c in range(pipe_degree):
+ # Attention(FA)
+ fa_fwd_args.cur_degree = c
+ fwd_comm_args = FwdCommArgs(c, mlp_inputs, a2a_inputs, a2a_events, ag_events)
+ ctx.async_comm = async_comm = AsyncCommunication(fwd_comm_args)
+ detach_attn_out, attn_out, attn_bias = fa_computer.forward(ctx, q, k, v, attention_mask)
+ fa_fwd_args.q_token_start_idx += fa_fwd_args.chunk_len
+
+ # Bias + Dropout + Add + LN
+ bias_chunk = bias_chunks[c] if attention_dense.bias is not None else None
+ bdal_fwd_args = BiasDropoutAddNormArgs(bias_dropout_add_func, post_attention_norm,
+ hidden_states_chunks[c], bias_chunk, hidden_dropout)
+ ctx.bdal_computer = bdal_computer = BiasDropoutAddNormComputer(bdal_tensor_list, bdal_fwd_args)
+ ln_output, ln_input = bdal_computer.forward(ctx, attn_out)
+ attn_out.untyped_storage().resize_(0)
+ dense_tensor_list.append(detach_attn_out)
+ dense_tensor_list.append(attn_out)
+ ln_input_list.append(ln_input)
+
+ # MoE Gating以及token重排
+ ctx.gate_computer = gate_computer = MoEGatingComputer(moe, gate_tensor_list)
+ gate_output = gate_computer.forward(ln_output)
+ if global_args.enable_token_rearrange_opt:
+ dispatched_input, l_aux, token_ec_idx_list[c], weights_list[c] = gate_output
+ else:
+ dispatched_input, l_aux, weights_list[c] = gate_output
+ ln_output.untyped_storage().resize_(0)
+ bdal_tensor_list.append(ln_output)
+
+ # mlp前第一次all2all以及allgather通信
+ mlp_inputs = async_comm.comm_before_moe_mlp_fwd(ctx, dispatched_input)
+ dispatched_input.untyped_storage().resize_(0)
+ gate_tensor_list.append(dispatched_input)
+
+ # MoE MLP
+ mlp_fwd_args = MLPFwdArgs(a2a_events, ag_events)
+ ctx.mlp_computer = mlp_computer = MoEMLPComputer(moe, mlp_tensor_list, mlp_fwd_args)
+ mlp_outputs = mlp_computer.forward(ctx, mlp_inputs, a2a_inputs)
+
+ # token反重排
+ post_mlp_fwd_args = PostMLPArgs(ln_seq_len // pipe_degree, a2a_events,
+ moe_output_list, weights_list, token_ec_idx_list)
+ ctx.post_mlp_computer = post_mlp_computer = MoEPostMLPComputer(post_mlp_tensor_list, post_mlp_fwd_args)
+ moe_output_list = post_mlp_computer.forward(ctx, mlp_outputs)
+ AttMoEPipe.save_tensors_for_bwd(ctx, [flash_tensor_list, dense_tensor_list, bdal_tensor_list,
+ gate_tensor_list, mlp_tensor_list, post_mlp_tensor_list])
+ ret = torch.cat(moe_output_list), torch.cat(ln_input_list)
+ return ret
+
+ @staticmethod
+ def backward(ctx, grad_moe_outs, grad_ln_ins):
+ global_args = get_args()
+ pipe_degree = ctx.pipe_degree
+ context_parallel = global_args.context_parallel_size > 1
+ sequence_parallel = global_args.sequence_parallel
+
+ # 取前向保存的tensor
+ saved_tensors_list = list(ctx.saved_tensors)
+ (flash_tensor_list_len, dense_tensor_list_len,
+ bdal_tensor_list_len, gate_tensor_list_len,
+ mlp_tensor_list_len, post_mlp_tensor_list_len) = ctx.tensor_list_length
+ start_index = 0
+ segments = []
+
+ for length in ctx.tensor_list_length:
+ end_index = start_index + length
+ segments.append(saved_tensors_list[start_index:end_index])
+ start_index = end_index
+ (flash_tensor_list, dense_tensor_list,
+ bdal_tensor_list, gate_tensor_list,
+ mlp_tensor_list, post_mlp_tensor_list) = segments
+
+ # 切分传入backward的grad
+ grad_moe_out_list = grad_moe_outs.chunk(pipe_degree)
+ grad_ln_ins_list = grad_ln_ins.chunk(pipe_degree)
+ # 初始化临时变量
+ grad_hidden, grad_q, grad_k, grad_v = [], [], None, None
+ grad_mlp_input_list, grad_a2a_input_list, a2a_events, ag_events = AttMoEPipe._init_bwd_comm_list(ctx)
+
+ for c in range(pipe_degree - 1, -1, -1):
+ # 计算token反重排的反向
+ grad_moe_out_chunk = grad_moe_out_list[c].view(-1, ctx.hidden_size)
+ post_mlp_list_slice_len = post_mlp_tensor_list_len // pipe_degree
+ grad_post_mlp = ctx.post_mlp_computer.backward(
+ post_mlp_tensor_list[c * post_mlp_list_slice_len:(c + 1) * post_mlp_list_slice_len],
+ grad_moe_out_chunk
+ )
+ # 反向第一次all2all以及allgather通信
+ bwd_comm_args = BwdCommArgs(c, grad_mlp_input_list, grad_a2a_input_list, a2a_events, ag_events)
+ ctx.async_comm.bwd_args = bwd_comm_args
+ grad_mlp_input_list = ctx.async_comm.comm_before_moe_mlp_bwd(ctx, grad_post_mlp)
+ del post_mlp_tensor_list[c * post_mlp_list_slice_len:(c + 1) * post_mlp_list_slice_len]
+ # 手动清理ctx中computer保存的tensor,以减少峰值内存
+ ctx.post_mlp_computer = None
+ ctx.async_comm = None
+ # 专家mlp反向计算
+ bwd_mlp_args = MLPBwdArgs(sequence_parallel, mlp_tensor_list_len, a2a_events, ag_events, mlp_tensor_list)
+ if ctx.pipe_experts:
+ bwd_mlp_args.second_a2a_events = []
+ ctx.mlp_computer.mlp_bwd_args = bwd_mlp_args
+ mlp_bwd_grads = ctx.mlp_computer.backward(ctx, grad_mlp_input_list, grad_a2a_input_list)
+ # 手动清理ctx中computer保存的tensor,以减少峰值内存
+ ctx.mlp_computer = None
+
+ fa_bwd_args = FlashAttentionBwdArgs(grad_q, grad_k, grad_v, flash_tensor_list, dense_tensor_list,
+ flash_tensor_list_len=flash_tensor_list_len,
+ dense_tensor_list_len=dense_tensor_list_len)
+ if context_parallel:
+ fa_bwd_args.kv_list = []
+ fa_bwd_args.dkv_list = []
+ fa_bwd_args.dout_list = []
+ else:
+ fa_bwd_args.v = flash_tensor_list.pop()
+ fa_bwd_args.k = flash_tensor_list.pop()
+ ctx.fa_computer.fa_bwd_args = fa_bwd_args
+ for c in range(pipe_degree - 1, -1, -1):
+ # 反向等待最后一次all2all
+ grad_mlp = AttMoEPipe.bwd_second_all2all_wait_last(ctx, c, mlp_bwd_grads, a2a_events, bwd_mlp_args)
+ # gating&token重排反向
+ gate_list_slice_len = gate_tensor_list_len // pipe_degree
+ grad_ln_out = ctx.gate_computer.backward(
+ gate_tensor_list[c * gate_list_slice_len:(c + 1) * gate_list_slice_len],
+ grad_mlp
+ )
+ del gate_tensor_list[c * gate_list_slice_len:(c + 1) * gate_list_slice_len]
+
+ # bias dropout add ln 反向
+ bdal_list_slice_len = bdal_tensor_list_len // pipe_degree
+ bdal_list_slice = bdal_tensor_list[c * bdal_list_slice_len:(c + 1) * bdal_list_slice_len]
+ grad_dense, d_hidden_grad, d_bias_grad = ctx.bdal_computer.backward(ctx, bdal_list_slice,
+ grad_ln_out, grad_ln_ins_list[c])
+ grad_hidden.insert(0, d_hidden_grad)
+ del bdal_list_slice
+ del bdal_tensor_list[c * bdal_list_slice_len:(c + 1) * bdal_list_slice_len]
+
+ # fa反向
+ fa_bwd_args.cur_degree = c
+ grad_q, grad_k, grad_v = ctx.fa_computer.backward(ctx, grad_dense)
+ # 手动清理ctx中computer保存的tensor,以减少峰值内存
+ ctx.gate_computer = None
+ ctx.bdal_computer = None
+ ctx.fa_computer = None
+ if not context_parallel:
+ grad_q = torch.cat(grad_q, dim=0)
+ grad_q = rearrange(grad_q, "s b (n d) -> s b n d", n=ctx.head)
+ return grad_q, grad_k, grad_v, torch.cat(grad_hidden), None, None
+
+ @staticmethod
+ def save_args_to_ctx(ctx, ampipe_forward_args, global_args):
+ ctx.ampipe_forward_args = ampipe_forward_args
+ ctx.sequence_parallel = global_args.sequence_parallel
+ ctx.num_experts = global_args.num_experts
+ ctx.num_local_experts = global_args.num_experts // global_args.expert_model_parallel_size
+ ctx.ep_size = global_args.expert_model_parallel_size
+ ctx.hidden_size = global_args.hidden_size
+ ctx.pipe_degree = global_args.ampipe_degree
+ ctx.ampipe_tp_sp_comm_overlap = global_args.ampipe_tp_sp_comm_overlap
+ ctx.pipe_experts = global_args.use_pipe_experts
+ ctx.pipe_experts_multi_data = global_args.pipe_experts_multi_data
+ ctx.pipe_experts_multi_stream = global_args.pipe_experts_multi_stream
+ ctx.flash_args = []
+ ctx.mlp_args = []
+
+ @staticmethod
+ def save_tensors_for_bwd(ctx, tensor_list):
+ flat_list = itertools.chain.from_iterable(tensor_list)
+ ctx.save_for_backward(*flat_list)
+ ctx.tensor_list_length = [len(x) for x in tensor_list]
+ for lst in tensor_list:
+ lst.clear()
+
+ @staticmethod
+ def _init_attention_args(pipe_degree, q_shape, attention_dense, flash_tensor_list):
+ seqlen, batch_size, head_num, head_dim = q_shape
+ chunk_len = seqlen // pipe_degree
+ softmax_scale = head_dim ** (-0.5)
+ return FlashAttentionFwdArgs(flash_tensor_list, attention_dense, head_num, softmax_scale, chunk_len)
+
+ @staticmethod
+ def bwd_second_all2all_wait_last(ctx, cur_degree, mlp_bwd_grads, a2a_events, mlp_bwd_args):
+ grad_mlp_last = mlp_bwd_grads[cur_degree]
+ if ctx.use_ampipe_with_pipe_expert and cur_degree == 0:
+ mlp_bwd_args.second_a2a_events[-1].wait()
+ grad_combine = torch.cat([torch.cat(i, dim=1) for i in grad_mlp_last], dim=1)
+ grad_mlp_last = grad_combine.reshape(ctx.num_experts, -1, ctx.hidden_size)
+ elif ctx.ampipe_tp_sp_comm_overlap and cur_degree == 0:
+ a2a_events[-1].wait()
+ grad_combine = torch.cat(grad_mlp_last, dim=1)
+ grad_mlp_last = grad_combine.reshape(ctx.num_experts, -1, ctx.hidden_size)
+
+ if not ctx.ampipe_tp_sp_comm_overlap:
+ a2a_events[cur_degree].wait()
+ return grad_mlp_last
+
+ @staticmethod
+ def _init_fwd_comm_list():
+ global_args = get_args()
+ pipe_degree = global_args.ampipe_degree
+ num_local_experts = global_args.num_experts // global_args.expert_model_parallel_size
+ pipe_experts_multi_data = global_args.pipe_experts_multi_data
+ pipe_experts_multi_stream = global_args.pipe_experts_multi_stream
+ a2a_inputs = []
+ ag_events = []
+
+ if not global_args.ampipe_tp_sp_comm_overlap:
+ mlp_inputs = [None] * pipe_degree
+ a2a_events = []
+ elif not global_args.use_pipe_experts or pipe_experts_multi_data <= pipe_degree:
+ mlp_inputs = [None] * (pipe_degree * num_local_experts)
+ a2a_events = [None] * (pipe_degree * num_local_experts)
+ else:
+ mlp_inputs = [None] * (pipe_experts_multi_data * num_local_experts)
+ a2a_events = [None] * (pipe_experts_multi_data * num_local_experts)
+
+ if pipe_experts_multi_stream:
+ ag_events = [None] * (pipe_experts_multi_data * num_local_experts)
+ get_async_comm_utils_data_instance().fw_ag_output = [None] * (pipe_experts_multi_data * num_local_experts)
+ CommList = namedtuple("CommList", ["mlp_inputs", "a2a_inputs", "a2a_events", "ag_events"])
+ comm_list = CommList(mlp_inputs, a2a_inputs, a2a_events, ag_events)
+ return comm_list
+
+ @staticmethod
+ def _init_bwd_comm_list(ctx):
+ if not ctx.ampipe_tp_sp_comm_overlap:
+ grad_mlp_input_list = [None] * ctx.pipe_degree
+ grad_a2a_input_list = [None] * ctx.pipe_degree
+ a2a_events = []
+ elif not ctx.pipe_experts or ctx.pipe_experts_multi_data <= ctx.pipe_degree:
+ grad_mlp_input_list = [None] * (ctx.pipe_degree * ctx.num_local_experts)
+ grad_a2a_input_list = [None] * (ctx.pipe_degree * ctx.num_local_experts)
+ a2a_events = [None] * (ctx.pipe_degree * ctx.num_local_experts)
+ else:
+ grad_mlp_input_list = [None] * (ctx.pipe_experts_multi_data * ctx.num_local_experts)
+ grad_a2a_input_list = [None] * (ctx.pipe_experts_multi_data * ctx.num_local_experts)
+ a2a_events = [None] * (ctx.pipe_experts_multi_data * ctx.num_local_experts)
+
+ ag_events = []
+ if ctx.pipe_experts_multi_stream:
+ ag_events = [None] * (ctx.pipe_experts_multi_data * ctx.num_local_experts)
+ CommList = namedtuple("CommList", ["mlp_inputs", "a2a_inputs", "a2a_events", "ag_events"])
+ comm_list = CommList(grad_mlp_input_list, grad_a2a_input_list, a2a_events, ag_events)
+ return comm_list
diff --git a/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_args.py b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..883c476a37610f3b0e0c6de15e78f09e572d4181
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_args.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass, field
+from typing import Union, Callable, Optional, List
+
+import torch
+from torch import Tensor
+
+from megatron.core import tensor_parallel
+from megatron.legacy.model import LayerNorm, RMSNorm
+
+
+@dataclass
+class ForwardArgs:
+ attention_dense: tensor_parallel.RowParallelLinear
+ bias_dropout_add_func: Callable
+ post_attention_norm: Union[LayerNorm, RMSNorm]
+ moe: torch.nn.Module
+ hidden_dropout: float
+
+
+@dataclass
+class FlashAttentionFwdArgs:
+ flash_tensor_list: List[Tensor]
+ attention_dense: tensor_parallel.RowParallelLinear
+ head_num: int
+ softmax_scale: float
+ chunk_len: int
+ q_token_start_idx: int = 0
+ sparse_mode: int = 0
+ cur_degree: int = 0
+ kv_list: List[Tensor] = field(default_factory=list)
+ o_max_sum_list: List[Tensor] = field(default_factory=list)
+
+
+@dataclass
+class FACpFwdArgs:
+ q: Tensor
+ k: Tensor
+ v: Tensor
+
+
+@dataclass
+class FlashAttentionSaveForBwdArgs:
+ n: int = 0
+ rank: int = 0
+ keep_prob: float = 0.0
+ cp_size: int = 0
+ prev_rank: int = 0
+ next_rank: int = 0
+ softmax_scale: float = 0.0
+ next_tokens: int = 0
+ cp_group: torch.distributed.ProcessGroup = None
+ cp_group_for_send_recv_overlap: torch.distributed.ProcessGroup = None
+ rng_states_qa_kva: List = field(default_factory=list)
+ rng_states_qb_kva: List = field(default_factory=list)
+ rng_states_qb_kvb: List = field(default_factory=list)
+
+
+@dataclass
+class FlashAttentionBwdArgs:
+ grad_q: List
+ grad_k: Optional[Tensor]
+ grad_v: Optional[Tensor]
+ flash_tensor_list: List[Tensor]
+ dense_tensor_list: List[Tensor]
+ attn_out_all: Tensor = None
+ k: Tensor = None
+ v: Tensor = None
+ cur_degree: int = 0
+ flash_tensor_list_len: int = 0
+ dense_tensor_list_len: int = 0
+ kv_list: List[Tensor] = field(default_factory=list)
+ dkv_list: List[Tensor] = field(default_factory=list)
+ dout_list: List[Tensor] = field(default_factory=list)
+
+
+@dataclass
+class BiasDropoutAddNormArgs:
+ bias_dropout_add_func: Callable
+ post_attention_norm: Union[LayerNorm, RMSNorm]
+ residual: Tensor
+ bias: Optional[Tensor]
+ prob: float
+
+
+@dataclass
+class FwdCommArgs:
+ cur_degree: int
+ mlp_inputs: List[Tensor]
+ a2a_inputs: List[Tensor]
+ a2a_events: List
+ ag_events: List
+
+
+@dataclass
+class BwdCommArgs:
+ cur_degree: int
+ grad_mlp_input_list: List[Tensor]
+ grad_a2a_input_list: List[Tensor]
+ a2a_events: List
+ ag_events: List
+
+
+@dataclass
+class MLPFwdArgs:
+ a2a_events: List = field(default_factory=list)
+ ag_events: List = field(default_factory=list)
+
+
+@dataclass
+class MLPSaveForBwdArgs:
+ ampipe_degree: int = 0
+ num_local_experts: int = 0
+ ep_size: int = 0
+ hidden_size: int = 0
+ sequence_parallel: bool = False
+ multi_data: int = 0
+ multi_stream: bool = False
+ input_list_before_expert: List[Tensor] = field(default_factory=list)
+
+
+@dataclass
+class MLPBwdArgs:
+ sequence_parallel: bool
+ mlp_tensor_list_len: int
+ a2a_events: List
+ ag_events: List
+ mlp_tensor_list: List[Tensor]
+ second_a2a_events: List = field(default_factory=list)
+
+
+@dataclass
+class PostMLPArgs:
+ seqlen: int = 0
+ a2a_events: List = field(default_factory=list)
+ moe_output_list: List[Tensor] = field(default_factory=list)
+ weights_list: List[Tensor] = field(default_factory=list)
+ token_ec_idx_list: List[Tensor] = field(default_factory=list)
diff --git a/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_async_communication.py b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_async_communication.py
new file mode 100644
index 0000000000000000000000000000000000000000..d29492f00422cfd0df63db57085468e91f6bfb57
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_async_communication.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from megatron.training import get_args
+from mindspeed.moe.async_comm_utils import async_fw_all_gather, async_all_to_all, async_all_gather
+
+
+class AsyncCommunication:
+ def __init__(self, fwd_args, bwd_args=None):
+ self.bwd_args = bwd_args
+ self.fwd_args = fwd_args
+
+ def comm_before_moe_mlp_fwd(self, ctx, dispatched_input):
+ cur_degree = self.fwd_args.cur_degree
+ a2a_events = self.fwd_args.a2a_events
+ mlp_inputs = self.fwd_args.mlp_inputs
+ a2a_inputs = self.fwd_args.a2a_inputs
+ args = get_args()
+ pipe_experts = args.use_pipe_experts
+ num_experts = args.num_experts
+ num_local_experts = num_experts // args.expert_model_parallel_size
+
+ # 不开启ampipe_tp_sp_comm_overlap时,不切分专家维度,直接做全量专家的all2all
+ if not args.ampipe_tp_sp_comm_overlap:
+ a2a_tokens, a2a_handle = async_all_to_all(dispatched_input)
+ a2a_events.append(a2a_handle)
+ mlp_inputs[cur_degree] = a2a_tokens
+ return mlp_inputs
+
+ # 开启ampipe_tp_sp_comm_overlap时,按照专家切分token后再all2all
+ chunk_list = dispatched_input.chunk(num_experts)
+ for exp_index in range(num_local_experts):
+ chunks = chunk_list[exp_index:num_experts:num_local_experts]
+ a2a_tokens = torch.cat(chunks)
+ # pipe-experts适配
+ if pipe_experts:
+ comm_result = self._pipe_expert_comm_before_moe_mlp_fwd(ctx, exp_index, a2a_tokens)
+ if comm_result is not None:
+ continue
+ # 不开启pipe_experts或者pipe_experts_multi_data < ampipe_degree时不再切分token,直接all2all
+ output, a2a_handle = async_all_to_all(a2a_tokens)
+ index = cur_degree * num_local_experts + exp_index
+ mlp_inputs[index] = output
+ a2a_events[index] = a2a_handle
+ # 不提前析构通信tensor,保证正常释放通信后tensor内存
+ a2a_inputs.append(a2a_tokens)
+ return mlp_inputs
+
+ def comm_before_moe_mlp_bwd(self, ctx, grad_moe_out_chunk):
+ cur_degree = self.bwd_args.cur_degree
+ a2a_events = self.bwd_args.a2a_events
+ grad_mlp_input_list = self.bwd_args.grad_mlp_input_list
+ grad_a2a_input_list = self.bwd_args.grad_a2a_input_list
+ # 反向第一次all2all
+ # 纯ep通信隐藏
+ if not ctx.ampipe_tp_sp_comm_overlap:
+ grad_mlp_input_list[cur_degree], a2a_handle = async_all_to_all(grad_moe_out_chunk)
+ a2a_events.insert(0, a2a_handle)
+ return grad_mlp_input_list
+
+ # tp-sp域&ep域通信隐藏适配
+ chunk_list = grad_moe_out_chunk.chunk(ctx.num_experts)
+ for exp_index in range(ctx.num_local_experts):
+ chunks = chunk_list[exp_index:ctx.num_experts:ctx.num_local_experts]
+ grad_mlp_tokens = torch.cat(chunks)
+ # pipe-experts适配
+ if ctx.pipe_experts:
+ comm_result = self._pipe_expert_comm_before_moe_mlp_bwd(ctx, exp_index, grad_mlp_tokens)
+ if comm_result is not None:
+ continue
+ # 不开启pipe_experts或者pipe_experts_multi_data < ampipe_degree时不再切分token,直接all2all
+ grad_a2a_tokens, a2a_handle = async_all_to_all(grad_mlp_tokens)
+ index = (ctx.pipe_degree - 1 - cur_degree) * ctx.num_local_experts + exp_index
+ grad_mlp_input_list[index] = grad_a2a_tokens
+ a2a_events[index] = a2a_handle
+ # 不提前析构通信tensor,保证正常释放通信后tensor内存
+ grad_a2a_input_list[index] = grad_mlp_tokens
+ return grad_mlp_input_list
+
+ def _pipe_expert_comm_before_moe_mlp_fwd(self, ctx, exp_index, input_tokens):
+ cur_degree = self.fwd_args.cur_degree
+ a2a_events = self.fwd_args.a2a_events
+ mlp_inputs = self.fwd_args.mlp_inputs
+ a2a_inputs = self.fwd_args.a2a_inputs
+ ag_events = self.fwd_args.ag_events
+ args = get_args()
+ pipe_degree = args.ampipe_degree
+ pipe_experts_multi_data = args.pipe_experts_multi_data
+ pipe_experts_multi_stream = args.pipe_experts_multi_stream
+ # pipe_experts_multi_data > ampipe_degree时, 对token的C维度再切分
+ ctx.slice_size = slice_size = pipe_experts_multi_data // pipe_degree
+ a2a_token_chunk = input_tokens.chunk(slice_size, dim=1)
+ # 多流场景下pipe_experts_multi_data必须大于等于ampipe_degree
+ if pipe_experts_multi_data >= pipe_degree and pipe_experts_multi_stream:
+ for i in range(slice_size):
+ # 计算列表中索引适配pipe_experts
+ index = cur_degree * slice_size + exp_index * pipe_experts_multi_data + i
+ if (cur_degree + exp_index + i) == 0 and args.sequence_parallel:
+ a2a_token, a2a_handle = async_all_to_all(a2a_token_chunk[i])
+ else:
+ a2a_token, a2a_handle = async_all_to_all(a2a_token_chunk[i], ag_events[index])
+ a2a_events[index] = a2a_handle
+ mlp_inputs[index] = a2a_token
+ if args.sequence_parallel:
+ ag_token, ag_handle = async_fw_all_gather(a2a_token, a2a_handle, ampipe_with_mlp_multistream=True,
+ index=index)
+ ag_events[index] = ag_handle
+ mlp_inputs[index] = ag_token
+ return mlp_inputs
+ # 非多流场景下pipe_experts_multi_data必须大于ampipe_degree
+ elif pipe_experts_multi_data > pipe_degree and not pipe_experts_multi_stream:
+ for i in range(slice_size):
+ a2a_token, a2a_handle = async_all_to_all(a2a_token_chunk[i])
+ index = cur_degree * slice_size + exp_index * pipe_experts_multi_data + i
+ a2a_events[index] = a2a_handle
+ mlp_inputs[index] = a2a_token
+ a2a_inputs.append(a2a_token_chunk[i])
+ return mlp_inputs
+ return None
+
+ def _pipe_expert_comm_before_moe_mlp_bwd(self, ctx, exp_index, grad_tokens):
+ cur_degree = self.bwd_args.cur_degree
+ a2a_events = self.bwd_args.a2a_events
+ grad_mlp_input_list = self.bwd_args.grad_mlp_input_list
+ ag_events = self.bwd_args.ag_events
+ args = get_args()
+ pipe_degree = args.ampipe_degree
+ grad_token_list = grad_tokens.chunk(ctx.slice_size, dim=1)
+ # 多流场景下pipe_experts_multi_data必须大于等于ampipe_degree
+ if ctx.pipe_experts_multi_data >= pipe_degree and ctx.pipe_experts_multi_stream:
+ for i in range(ctx.slice_size):
+ # 计算列表中索引适配pipe_experts
+ index = (pipe_degree - 1 - cur_degree) * ctx.slice_size + exp_index * ctx.pipe_experts_multi_data + i
+ if cur_degree == pipe_degree - 1 and (exp_index + i) == 0 and args.sequence_parallel:
+ a2a_token, a2a_handle = async_all_to_all(grad_token_list[i])
+ else:
+ a2a_token, a2a_handle = async_all_to_all(grad_token_list[i], ag_events[index])
+ a2a_events[index] = a2a_handle
+ grad_mlp_input_list[index] = a2a_token
+ if args.sequence_parallel:
+ ag_token, ag_handle = async_all_gather(a2a_token, a2a_handle, is_bwd=True)
+ ag_events[index] = ag_handle
+ grad_mlp_input_list[index] = ag_token
+ return grad_mlp_input_list
+ # 非多流场景下pipe_experts_multi_data必须大于ampipe_degree
+ elif ctx.pipe_experts_multi_data > pipe_degree and not ctx.pipe_experts_multi_stream:
+ for i in range(ctx.slice_size):
+ a2a_token, a2a_handle = async_all_to_all(grad_token_list[i])
+ index = (pipe_degree - 1 - cur_degree) * ctx.slice_size + exp_index * ctx.pipe_experts_multi_data + i
+ a2a_events[index] = a2a_handle
+ grad_mlp_input_list[index] = a2a_token
+ return grad_mlp_input_list
+ return None
+
+ def fw_all_gather_not_multistream(self):
+ self.fwd_args.a2a_events[0].wait()
+ # 释放通信内存
+ self.fwd_args.a2a_inputs.pop()
+ _, ag_handle = async_fw_all_gather(self.fwd_args.mlp_inputs[0])
+ self.fwd_args.ag_events.append(ag_handle)
diff --git a/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_bias_dropout_add_ln_computer.py b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_bias_dropout_add_ln_computer.py
new file mode 100644
index 0000000000000000000000000000000000000000..369628456537fdc5b925904f2a8ab72804b154e0
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_bias_dropout_add_ln_computer.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+class BiasDropoutAddNormComputer:
+ def __init__(self, bdal_tensor_list, fwd_args):
+ super().__init__()
+ self.bdal_tensor_list = bdal_tensor_list
+ self.fwd_args = fwd_args
+
+ def forward(self, ctx, input_tensor):
+ residual = self.fwd_args.residual
+ bias = self.fwd_args.bias
+ prob = self.fwd_args.prob
+
+ input_tensor = input_tensor.detach()
+ residual = residual.detach()
+ input_tensor.requires_grad = True
+ residual.requires_grad = True
+ ctx.bias = False
+ if isinstance(bias, torch.Tensor):
+ bias = bias.detach()
+ bias.requires_grad = True
+ self.bdal_tensor_list.append(bias)
+ ctx.bias = True
+
+ with torch.enable_grad():
+ ln_input = self.fwd_args.bias_dropout_add_func(input_tensor, bias, residual, prob)
+ detach_ln_input = ln_input.detach()
+ detach_ln_input.requires_grad = True
+ output = self.fwd_args.post_attention_norm(detach_ln_input)
+ self.bdal_tensor_list.extend([ln_input, detach_ln_input, input_tensor, residual])
+ return output, ln_input
+
+ def backward(self, ctx, saved_tensor_list, grad_ln_outs, grad_ln_ins):
+ if ctx.bias:
+ bias = saved_tensor_list.pop(0)
+ ln_input, detach_ln_input, input_tensor, residual, output = saved_tensor_list
+ output.backward(grad_ln_outs)
+ grad_ln = detach_ln_input.grad
+ ln_input.backward(grad_ln + grad_ln_ins)
+ input_grad = input_tensor.grad
+ residual_grad = residual.grad
+ bias_grad = bias.grad if ctx.bias else None
+ return input_grad, residual_grad, bias_grad
diff --git a/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_fa.py b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_fa.py
new file mode 100644
index 0000000000000000000000000000000000000000..bae3cb275eb465a663eb79c5255c5b922df0ca6e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_fa.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch_npu
+from megatron.training import get_args
+
+
+def flash_attn_forward(qkvn, attn_mask=None, softmax_scale=1.0, sparse_mode=0, next_tokens=0):
+ """FlashAttention forward"""
+ args = get_args()
+ q, k, v, n = qkvn
+ output = torch_npu.npu_fusion_attention(
+ q, k, v, n, "SBH",
+ pse=None,
+ padding_mask=None,
+ atten_mask=attn_mask,
+ scale=softmax_scale,
+ pre_tockens=args.pre_tockens,
+ next_tockens=next_tokens,
+ keep_prob=1.0 - args.attention_dropout,
+ inner_precise=0,
+ sparse_mode=sparse_mode
+ )
+ return output
+
+
+def flash_attn_backward(qkvn, dy, softmax_max, softmax_sum, attn_out,
+ attn_mask=None, softmax_scale=1.0, sparse_mode=0, next_tokens=0):
+ """FlashAttention backward"""
+ q, k, v, n = qkvn
+ output = torch_npu.npu_fusion_attention_grad(
+ q, k, v, dy, n,
+ "SBH",
+ pse=None,
+ padding_mask=None,
+ atten_mask=attn_mask,
+ softmax_max=softmax_max,
+ softmax_sum=softmax_sum,
+ attention_in=attn_out,
+ scale_value=softmax_scale,
+ pre_tockens=k.shape[0],
+ next_tockens=next_tokens,
+ sparse_mode=sparse_mode
+ )
+ return output
diff --git a/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_fa_computer.py b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_fa_computer.py
new file mode 100644
index 0000000000000000000000000000000000000000..47846823d5be7e127fb463de75dc3ab9264a9751
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_fa_computer.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from megatron.training import get_args
+from mindspeed.core.context_parallel.ring_context_parallel_for_ampipe import (attn_with_cp_for_ampipe_forward,
+ attn_with_cp_for_ampipe_backward)
+from mindspeed.moe.ampipe.ampipe_args import FlashAttentionSaveForBwdArgs, FACpFwdArgs
+from mindspeed.moe.ampipe.ampipe_fa import flash_attn_forward, flash_attn_backward
+
+
+class FlashAttentionComputer:
+ def __init__(self, fa_fwd_args, fa_bwd_args=None):
+ super().__init__()
+ self.fa_bwd_args = fa_bwd_args
+ self.fa_fwd_args = fa_fwd_args
+ self.context_parallel = get_args().context_parallel_size > 1
+
+ def forward(self, ctx, q, k, v, attention_mask):
+ global_args = get_args()
+ flash_tensor_list = self.fa_fwd_args.flash_tensor_list
+ cur_degree = self.fa_fwd_args.cur_degree
+
+ if self.context_parallel:
+ if cur_degree == 0:
+ flash_args_save_for_bwd = FlashAttentionSaveForBwdArgs()
+ ctx.flash_args.append(flash_args_save_for_bwd)
+ fa_cp_fwd_args = FACpFwdArgs(q, k, v)
+ cur_attn_out = attn_with_cp_for_ampipe_forward(ctx.flash_args[0],
+ fa_cp_fwd_args=fa_cp_fwd_args,
+ fa_fwd_args=self.fa_fwd_args)
+ else:
+ flash_args_save_for_bwd = FlashAttentionSaveForBwdArgs()
+ q_token_start_idx = self.fa_fwd_args.q_token_start_idx
+ q_token_end_idx = q_token_start_idx + self.fa_fwd_args.chunk_len
+ next_tokens = q_token_start_idx
+ q_use = q[q_token_start_idx:q_token_end_idx]
+ cur_attn_mask = attention_mask[q_token_start_idx:q_token_end_idx]
+ output_chunk = flash_attn_forward((q_use, k, v, self.fa_fwd_args.head_num),
+ attn_mask=cur_attn_mask,
+ softmax_scale=self.fa_fwd_args.softmax_scale,
+ sparse_mode=self.fa_fwd_args.sparse_mode,
+ next_tokens=next_tokens)
+ cur_attn_out, cur_softmax_max, cur_softmax_sum = output_chunk[0], output_chunk[1], output_chunk[2]
+ flash_tensor_list.extend([q_use, cur_attn_mask, cur_softmax_max, cur_softmax_sum])
+ flash_args_save_for_bwd.next_tokens = next_tokens
+ ctx.flash_args.append(flash_args_save_for_bwd)
+ # 内存优化
+ self._optimize_attn_memory(k, v)
+ # 提前做一次mlp的allgather
+ should_do_allgather_in_attention = (
+ cur_degree == global_args.ampipe_degree - 1
+ and global_args.sequence_parallel
+ and global_args.ampipe_tp_sp_comm_overlap
+ and not global_args.pipe_experts_multi_stream
+ )
+ if should_do_allgather_in_attention:
+ ctx.async_comm.fw_all_gather_not_multistream()
+ # attention后的matmul (RowParallelLinear)
+ detach_attn_out = cur_attn_out.detach()
+ detach_attn_out.requires_grad = True
+ with torch.enable_grad():
+ attn_dense_out, attn_bias = self.fa_fwd_args.attention_dense(detach_attn_out)
+ return detach_attn_out, attn_dense_out, attn_bias
+
+ def backward(self, ctx, grad_output):
+ # attention dense 反向
+ c = self.fa_bwd_args.cur_degree
+ dense_list_slice_len = self.fa_bwd_args.dense_tensor_list_len // ctx.pipe_degree
+ cur_attn_out, attn_dense_out = self.fa_bwd_args.dense_tensor_list[
+ c * dense_list_slice_len:(c + 1) * dense_list_slice_len
+ ]
+ if self.context_parallel and c == ctx.pipe_degree - 1:
+ next_attn_out = self.fa_bwd_args.dense_tensor_list[0]
+ attn_out_all = torch.cat((next_attn_out.unsqueeze(0), cur_attn_out.unsqueeze(0)), dim=0)
+ self.fa_bwd_args.attn_out_all = attn_out_all
+ attn_dense_out.backward(grad_output)
+ grad_flash = cur_attn_out.grad
+ del self.fa_bwd_args.dense_tensor_list[c * dense_list_slice_len:(c + 1) * dense_list_slice_len]
+
+ # FA反向
+ flash_tensor_list = self.fa_bwd_args.flash_tensor_list
+ if self.context_parallel:
+ self.fa_bwd_args.cur_degree = ctx.pipe_degree - 1 - c
+ grad_attention = attn_with_cp_for_ampipe_backward(
+ ctx.flash_args[0], self.fa_bwd_args.attn_out_all, flash_tensor_list, grad_flash,
+ self.fa_bwd_args
+ )
+ grad_q, grad_k, grad_v = grad_attention[0], grad_attention[1], grad_attention[2]
+ else:
+ grad_q, grad_k, grad_v = self.fa_bwd_args.grad_q, self.fa_bwd_args.grad_k, self.fa_bwd_args.grad_v
+ fa_list_slice_len = (self.fa_bwd_args.flash_tensor_list_len - 2) // ctx.pipe_degree
+ q, cur_attn_mask, cur_softmax_max, cur_softmax_sum = flash_tensor_list[
+ c * fa_list_slice_len:(c + 1) * fa_list_slice_len
+ ]
+ softmax_scale = self.fa_fwd_args.softmax_scale
+ grad_attention = flash_attn_backward(
+ (q, self.fa_bwd_args.k, self.fa_bwd_args.v, ctx.head), grad_flash,
+ cur_softmax_max, cur_softmax_sum, cur_attn_out, cur_attn_mask, softmax_scale,
+ next_tokens=ctx.flash_args[c].next_tokens
+ )
+ d_q, d_k, d_v = grad_attention[0], grad_attention[1], grad_attention[2]
+ grad_k = grad_k + d_k if grad_k is not None else d_k
+ grad_v = grad_v + d_v if grad_v is not None else d_v
+ grad_q.insert(0, d_q)
+ self.fa_bwd_args.grad_q, self.fa_bwd_args.grad_k, self.fa_bwd_args.grad_v = grad_q, grad_k, grad_v
+ return grad_q, grad_k, grad_v
+
+ def _optimize_attn_memory(self, k, v):
+ if self.fa_fwd_args.cur_degree == get_args().ampipe_degree - 1:
+ if self.context_parallel:
+ for i, kv in enumerate(self.fa_fwd_args.kv_list):
+ if i < len(self.fa_fwd_args.kv_list) - 1:
+ kv.untyped_storage().resize_(0)
+ k.untyped_storage().resize_(0)
+ v.untyped_storage().resize_(0)
+ self.fa_fwd_args.kv_list.clear()
+ self.fa_fwd_args.o_max_sum_list.clear()
+ else:
+ self.fa_fwd_args.flash_tensor_list.append(k)
+ self.fa_fwd_args.flash_tensor_list.append(v)
diff --git a/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_moe_gating_computer.py b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_moe_gating_computer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f958613a7a3b79d78e917ee27d827f27b0e7f8
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_moe_gating_computer.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) Microsoft Corporation.
+#
+# This source code is licensed under the Apache license found in the
+# LICENSE file in the root directory of this source tree.
+
+# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py
+# reworked/refactored some parts to make it run.
+from collections import namedtuple
+
+import torch
+
+from megatron.training import get_args
+from mindspeed.moe.utils import einsum
+
+
+class MoEGatingComputer:
+ def __init__(self, moe, gate_tensor_list):
+ super().__init__()
+ self.gate_tensor_list = gate_tensor_list
+ self.moe = moe
+
+ def forward(self, logits):
+ detach_logits = logits.detach()
+ detach_logits.requires_grad = True
+
+ d_model = logits.shape[-1]
+ with torch.enable_grad():
+ reshaped_input = detach_logits.reshape(-1, d_model)
+ global_args = get_args()
+ if not global_args.enable_token_rearrange_opt:
+ l_aux, combine_weights, dispatch_mask = self.moe.moe_layer.gate(reshaped_input)
+ dispatch_mask = dispatch_mask.type_as(logits)
+ dispatched_input = einsum("sec,sm->ecm", dispatch_mask, reshaped_input)
+ self.gate_tensor_list.append(detach_logits)
+ return dispatched_input, l_aux, combine_weights
+ else:
+ l_aux, (token_ec_idx, token_weights, expert_select_token_idx) = self.moe.moe_layer.gate(reshaped_input)
+ org_dtype = reshaped_input.dtype
+ if org_dtype == torch.bfloat16: # 规避算子性能劣化问题, 解决后可删除
+ rearranged_input = torch.index_select(
+ reshaped_input.to(torch.float32), dim=0, index=expert_select_token_idx
+ ).to(org_dtype)
+ else:
+ rearranged_input = torch.index_select(
+ reshaped_input, dim=0, index=expert_select_token_idx
+ )
+ capacity = expert_select_token_idx.size(0) // self.moe.num_experts
+ dispatched_input = rearranged_input.reshape(self.moe.num_experts, capacity, d_model).contiguous()
+ self.gate_tensor_list.append(detach_logits)
+ GatingComputerRet = namedtuple('GatingComputerRet',
+ ['dispatched_input', 'l_aux',
+ 'token_ec_idx', 'token_weights'])
+ gating_computer_ret = GatingComputerRet(dispatched_input=dispatched_input, l_aux=l_aux,
+ token_ec_idx=token_ec_idx, token_weights=token_weights)
+ return gating_computer_ret
+
+ def backward(self, saved_tensor_list, grad_output):
+ logits, dispatched_input = saved_tensor_list
+ dispatched_input.backward(grad_output)
+ grad_logits = logits.grad
+ logits.untyped_storage().resize_(0)
+ return grad_logits
diff --git a/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_moe_mlp_computer.py b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_moe_mlp_computer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4ed2edfa3af738b0f2e317ebe170538e5f65547
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_moe_mlp_computer.py
@@ -0,0 +1,229 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from megatron.training import get_args
+from mindspeed.moe.ampipe.ampipe_args import MLPSaveForBwdArgs
+from mindspeed.moe.async_comm_utils import async_all_to_all, async_all_gather, async_fw_all_gather, \
+ async_fw_all_reduce_scatter_ampipe, get_fw_ar_rs_output_ampipe
+from mindspeed.moe.pipe_experts import PipeExpert
+
+
+class MoEMLPComputer:
+ def __init__(self, moe, save_tensor_list, mlp_fwd_args, mlp_bwd_args=None):
+ super().__init__()
+ self.mlp_bwd_args = mlp_bwd_args
+ self.mlp_fwd_args = mlp_fwd_args
+ self.save_tensor_list = save_tensor_list
+ self.moe = moe
+
+ def forward(self, ctx, mlp_inputs, a2a_inputs):
+ global_args = get_args()
+ mlp_save_for_bwd_args = MLPSaveForBwdArgs()
+ a2a_events = self.mlp_fwd_args.a2a_events
+ ag_events = self.mlp_fwd_args.ag_events
+ pipe_degree = global_args.ampipe_degree
+ sequence_parallel = global_args.sequence_parallel
+ num_local_experts = global_args.num_experts // global_args.expert_model_parallel_size
+ ep_size = global_args.expert_model_parallel_size
+ hidden_size = global_args.hidden_size
+ pipe_experts = global_args.use_pipe_experts
+ multi_data = global_args.pipe_experts_multi_data
+ multi_stream = global_args.pipe_experts_multi_stream
+
+ ctx.use_ampipe_with_pipe_expert = (pipe_experts and
+ (multi_data >= pipe_degree and multi_stream)
+ or (multi_data > pipe_degree and not multi_stream))
+ if ctx.use_ampipe_with_pipe_expert:
+ second_a2a_event = []
+ pipe_expert_args = [mlp_inputs, ep_size, num_local_experts, sequence_parallel, multi_data, multi_stream,
+ a2a_events, second_a2a_event, ag_events, hidden_size, self.save_tensor_list]
+ mlp_outputs = PipeExpert.forward(mlp_save_for_bwd_args, self.moe.moe_layer.experts, *pipe_expert_args)
+ ctx.mlp_args = mlp_save_for_bwd_args
+ elif global_args.ampipe_tp_sp_comm_overlap:
+ mlp_outputs = self.ampipe_experts_forward(mlp_save_for_bwd_args, mlp_inputs, a2a_inputs)
+ ctx.mlp_args = mlp_save_for_bwd_args
+ else:
+ mlp_outputs = []
+ for c in range(pipe_degree):
+ a2a_events.pop(0).wait()
+ expert_input = mlp_inputs[c].reshape(ep_size, num_local_experts, -1, hidden_size)
+ detach_expert_input = expert_input.detach()
+ detach_expert_input.requires_grad = True
+ with torch.enable_grad():
+ expert_output = self.moe.moe_layer.experts(detach_expert_input)
+ self.save_tensor_list.extend([detach_expert_input, expert_output])
+ mlp_inputs[c] = expert_output
+ a2a_tokens, a2a_handle = async_all_to_all(expert_output)
+ a2a_events.append(a2a_handle)
+ mlp_outputs.append(a2a_tokens)
+ return mlp_outputs
+
+ def backward(self, ctx, grad_mlp_input_list, grad_a2a_input_list):
+ a2a_events = self.mlp_bwd_args.a2a_events
+ ag_events = self.mlp_bwd_args.ag_events
+ mlp_tensor_list = self.mlp_bwd_args.mlp_tensor_list
+ mlp_bwd_grads = []
+ multi_stream = ctx.pipe_experts_multi_stream
+ # 适配pipe-experts
+ if ctx.use_ampipe_with_pipe_expert:
+ if self.mlp_bwd_args.sequence_parallel and not multi_stream:
+ a2a_events[0].wait()
+ grad_a2a_input_list.pop(0)
+ grad_mlp_input_list[0], ag_handle = async_all_gather(grad_mlp_input_list[0], is_bwd=True)
+ ag_events.append(ag_handle)
+ mlp_bwd_grads = PipeExpert.backward(ctx.mlp_args, grad_mlp_input_list, a2a_events, ag_events,
+ self.mlp_bwd_args.second_a2a_events, mlp_tensor_list)
+ # mlp反向tp-sp&ep通信隐藏流水实现
+ elif ctx.ampipe_tp_sp_comm_overlap:
+ if self.mlp_bwd_args.sequence_parallel:
+ a2a_events[0].wait()
+ grad_a2a_input_list.pop(0)
+ grad_mlp_input_list[0], ag_handle = async_all_gather(grad_mlp_input_list[0], is_bwd=True)
+ ag_events.append(ag_handle)
+ mlp_bwd_grads = self.ampipe_experts_backward(ctx.mlp_args, mlp_tensor_list, grad_mlp_input_list,
+ grad_a2a_input_list, a2a_events, ag_events)
+ # mlp反向纯ep通信隐藏流水实现
+ else:
+ mlp_list_slice_len = self.mlp_bwd_args.mlp_tensor_list_len // ctx.pipe_degree
+ for c in range(ctx.pipe_degree - 1, -1, -1):
+ a2a_events.pop().wait()
+ expert_input, expert_output = mlp_tensor_list[c * mlp_list_slice_len:(c + 1) * mlp_list_slice_len]
+ expert_output.backward(grad_mlp_input_list[c])
+ grad_mlp_input = expert_input.grad.reshape(self.moe.num_experts, -1, self.moe.hidden_size)
+ a2a_grad_mlp_input, a2a_handle = async_all_to_all(grad_mlp_input)
+ mlp_bwd_grads.insert(0, a2a_grad_mlp_input)
+ a2a_events.insert(0, a2a_handle)
+ mlp_tensor_list.clear()
+ return mlp_bwd_grads
+
+ def ampipe_experts_forward(self, ctx, inputs, a2a_inputs):
+ ctx.ampipe_degree = pipe_degree = get_args().ampipe_degree
+ ctx.ep_size = ep_size = get_args().expert_model_parallel_size
+ ctx.num_local_experts = num_local_experts = get_args().num_experts // ep_size
+ ctx.hidden_size = hidden_size = get_args().hidden_size
+ ctx.sequence_parallel = sequence_parallel = get_args().sequence_parallel
+ ag_events = self.mlp_fwd_args.ag_events
+ a2a_events = self.mlp_fwd_args.a2a_events
+
+ output_list = []
+ before_exp_input_list = []
+ after_exp_out_list = []
+
+ for c in range(pipe_degree):
+ for i in range(num_local_experts):
+ cur_index = c * num_local_experts + i
+ # pre expert process
+ if sequence_parallel:
+ ag_events[cur_index].wait()
+ if cur_index < num_local_experts * pipe_degree - 1:
+ a2a_events[cur_index + 1].wait()
+ a2a_inputs.pop()
+ _, ag_handle = async_fw_all_gather(inputs[cur_index + 1],
+ is_use_global_memory_buffer=False)
+ ag_events.append(ag_handle)
+ else:
+ a2a_events[cur_index].wait()
+ a2a_inputs.pop()
+ # expert compute
+ detach_input_chunk = inputs[cur_index].detach()
+ detach_input_chunk.requires_grad = True
+ before_exp_input_list.append(detach_input_chunk)
+ with torch.enable_grad():
+ out = self.moe.moe_layer.experts.experts[i](detach_input_chunk)
+ if isinstance(out, tuple):
+ if cur_index > 0:
+ out, last_chunk_out = out[0], out[-1]
+ else:
+ out = out[0] # Ignore the bias term for now
+
+ # post expert comm
+ async_fw_all_reduce_scatter_ampipe(out, sequence_parallel)
+ after_exp_out_list.append(out)
+ if cur_index > 0:
+ after_exp_out_list[cur_index - 1].untyped_storage().resize_(0)
+ output_list.append(last_chunk_out)
+ if cur_index == pipe_degree * num_local_experts - 1:
+ ar_rs_out = get_fw_ar_rs_output_ampipe(sequence_parallel)
+ a2a_out, a2a2_handle = async_all_to_all(ar_rs_out)
+ a2a2_handle.wait()
+ output_list.append(a2a_out)
+
+ for t in after_exp_out_list:
+ t.untyped_storage().resize_(0)
+ self.save_tensor_list.extend(before_exp_input_list)
+ self.save_tensor_list.extend(after_exp_out_list)
+ outputs = []
+ for c in range(pipe_degree):
+ cur_pipe_out_list = output_list[c * num_local_experts:(c + 1) * num_local_experts]
+ cur_pipe_out = torch.cat(cur_pipe_out_list, dim=1)
+ cur_pipe_out = cur_pipe_out.reshape((num_local_experts * ep_size), -1, hidden_size)
+ outputs.append(cur_pipe_out)
+ return outputs
+
+ def ampipe_experts_backward(self, ctx, saved_tensor_list, *args):
+ pipe_degree = ctx.ampipe_degree
+ num_local_experts = ctx.num_local_experts
+ ep_size = ctx.ep_size
+ hidden_size = ctx.hidden_size
+ sequence_parallel = ctx.sequence_parallel
+
+ before_exp_input_list = saved_tensor_list[:num_local_experts * pipe_degree]
+ after_exp_out_list = saved_tensor_list[num_local_experts * pipe_degree:]
+ grad_output_list, grad_a2a_input_list, a2a_event, ag_events = args
+ grad_a2a2_input_list = []
+ output_list = []
+
+ for c in range(pipe_degree - 1, -1, -1):
+ for i in range(num_local_experts):
+ reversed_index = c * num_local_experts + i
+ normal_index = (pipe_degree - c - 1) * num_local_experts + i
+ # pre expert process
+ if sequence_parallel:
+ ag_events[normal_index].wait()
+ if normal_index < num_local_experts * pipe_degree - 1:
+ a2a_event[normal_index + 1].wait()
+ grad_a2a_input_list.pop(0)
+ grad_output = grad_output_list[normal_index + 1]
+ ag_grad_output, ag_handle = async_all_gather(grad_output, is_bwd=True)
+ grad_output_list[normal_index + 1] = ag_grad_output
+ ag_events.append(ag_handle)
+ else:
+ a2a_event[normal_index].wait()
+ grad_a2a_input_list.pop(0)
+ # expert backward compute
+ mlp_grad_output = grad_output_list[normal_index]
+ after_exp_out_list[reversed_index].backward(mlp_grad_output)
+ grad_input = before_exp_input_list[reversed_index].grad
+ mlp_grad_output.untyped_storage().resize_(0)
+ before_exp_input_list[reversed_index].untyped_storage().resize_(0)
+ # post expert process
+ a2a_grad_input, a2a1_handle = async_all_to_all(grad_input)
+ output_list.append(a2a_grad_input)
+ grad_a2a2_input_list.append(grad_input)
+ if normal_index > 0:
+ a2a_event[-1].wait()
+ grad_a2a2_input_list.pop(0)
+ a2a_event.append(a2a1_handle)
+
+ outputs = []
+ for c in range(pipe_degree):
+ cur_pipe_out_list = output_list[c * num_local_experts:(c + 1) * num_local_experts]
+ if c == pipe_degree - 1:
+ outputs.insert(0, cur_pipe_out_list)
+ continue
+ cur_pipe_out = torch.cat(cur_pipe_out_list, dim=1)
+ cur_pipe_out = cur_pipe_out.reshape((num_local_experts * ep_size), -1, hidden_size)
+ outputs.insert(0, cur_pipe_out)
+ return outputs
diff --git a/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_post_mlp_computer.py b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_post_mlp_computer.py
new file mode 100644
index 0000000000000000000000000000000000000000..742be3cfd88762cae98be75c6de13e3fc6edbfb3
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/ampipe/ampipe_post_mlp_computer.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+# Copyright (c) Microsoft Corporation.
+#
+# This source code is licensed under the Apache license found in the
+# LICENSE file in the root directory of this source tree.
+
+# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py
+# reworked/refactored some parts to make it run.
+import torch
+
+from megatron.training import get_args
+from mindspeed.moe.utils import einsum
+
+
+class MoEPostMLPComputer:
+ def __init__(self, save_tensor_list, fwd_args):
+ super().__init__()
+ self.fwd_args = fwd_args
+ self.save_tensor_list = save_tensor_list
+
+ def forward(self, ctx, mlp_outputs):
+ global_args = get_args()
+ weights_list = self.fwd_args.weights_list
+ token_ec_idx_list = self.fwd_args.token_ec_idx_list
+ moe_output_list = self.fwd_args.moe_output_list
+ for c in range(global_args.ampipe_degree):
+ if not global_args.ampipe_tp_sp_comm_overlap:
+ self.fwd_args.a2a_events[c].wait()
+ detach_exp_out = mlp_outputs[c].detach()
+ detach_exp_out.requires_grad = True
+ with torch.enable_grad():
+ reshape_out = detach_exp_out.reshape(ctx.ep_size * ctx.num_local_experts, -1, ctx.hidden_size)
+ if not global_args.enable_token_rearrange_opt:
+ combine_weights = weights_list[c].type_as(reshape_out)
+ combined_output = einsum("sec,ecm->sm", combine_weights.type_as(reshape_out), reshape_out)
+ else:
+ token_rearranged_ec_idx, token_exp_weights = token_ec_idx_list[c], weights_list[c]
+ E, C, M = reshape_out.shape
+ org_dtype = reshape_out.dtype
+ if org_dtype == torch.bfloat16:
+ valid_expert_out = torch.index_select(
+ reshape_out.view(E * C, M).to(torch.float32), dim=0, index=token_rearranged_ec_idx
+ ).to(org_dtype)
+ else:
+ valid_expert_out = torch.index_select(
+ reshape_out.view(E * C, M), dim=0, index=token_rearranged_ec_idx
+ )
+ combined_output = valid_expert_out * token_exp_weights.unsqueeze(1).type_as(reshape_out)
+ if global_args.moe_router_topk == 2:
+ combined_output = torch.add(*torch.chunk(combined_output, global_args.moe_router_topk, dim=0))
+ clone_out = combined_output.clone()
+ clone_out.untyped_storage().resize_(0)
+ self.save_tensor_list.extend([detach_exp_out, clone_out])
+ moe_out = combined_output.reshape((self.fwd_args.seqlen, -1, ctx.hidden_size))
+ moe_output_list.append(moe_out)
+ return moe_output_list
+
+ def backward(self, saved_tensor_list, grad_moe_out_chunk):
+ exp_out, combined_output = saved_tensor_list
+ combined_output.backward(grad_moe_out_chunk)
+ exp_grad = exp_out.grad
+ exp_out.untyped_storage().resize_(0)
+ return exp_grad
diff --git a/model/train/yoco_moe/mindspeed/moe/async_comm_utils.py b/model/train/yoco_moe/mindspeed/moe/async_comm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..326a6de205e1e725d3cee8fef021b3684c773cd9
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/async_comm_utils.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import threading
+
+import torch
+import torch_npu
+import torch.distributed as dist
+
+from megatron.core.parallel_state import (
+ get_global_memory_buffer,
+ get_tensor_model_parallel_group,
+ get_tensor_model_parallel_world_size,
+ get_expert_model_parallel_group
+)
+from megatron.training import get_args
+from mindspeed.core.weight_grad_store import WeightGradStore
+
+
+class SingletonMeta(type):
+ _instances = {}
+ _lock = threading.Lock()
+
+ def __call__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls not in cls._instances:
+ cls._instances[cls] = super().__call__(*args, **kwargs)
+ return cls._instances[cls]
+
+
+class AsyncCommUtilsDataSingleton(metaclass=SingletonMeta):
+ def __init__(self):
+ self.all2all_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
+ self.tp_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
+ self.fw_rs_output_ampipe = []
+ self.fw_rs_event_ampipe = []
+ self.fw_ar_output_ampipe = []
+ self.fw_ar_event_ampipe = []
+ self.fw_ag_output = []
+
+
+def get_async_comm_utils_data_instance():
+ return AsyncCommUtilsDataSingleton()
+
+
+def get_fw_ag_output():
+ return get_async_comm_utils_data_instance().fw_ag_output
+
+
+def get_fw_ar_rs_output_ampipe(sequence_parallel):
+ if sequence_parallel:
+ output_list = get_async_comm_utils_data_instance().fw_rs_output_ampipe
+ event_list = get_async_comm_utils_data_instance().fw_rs_event_ampipe
+ else:
+ output_list = get_async_comm_utils_data_instance().fw_ar_output_ampipe
+ event_list = get_async_comm_utils_data_instance().fw_ar_event_ampipe
+
+ if not output_list or not event_list:
+ return None
+
+ handle = event_list.pop(0)
+ handle.wait()
+ return output_list.pop(0)
+
+
+def async_fw_all_reduce_scatter_ampipe(input_, sequence_parallel):
+ world_size = get_tensor_model_parallel_world_size()
+ if sequence_parallel:
+ # reduce scatter
+ dim_size = list(input_.size())
+ dim_size[0] = dim_size[0] // world_size
+ output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
+ handle = torch.distributed._reduce_scatter_base(
+ output, input_.contiguous(), group=get_tensor_model_parallel_group(), async_op=True
+ )
+ get_async_comm_utils_data_instance().fw_rs_output_ampipe.append(output)
+ get_async_comm_utils_data_instance().fw_rs_event_ampipe.append(handle)
+ else:
+ # all reduce
+ handle = torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group(), async_op=True)
+ get_async_comm_utils_data_instance().fw_ar_output_ampipe.append(input_)
+ get_async_comm_utils_data_instance().fw_ar_event_ampipe.append(handle)
+
+
+def async_all_gather(input_, a2a_event=None, is_use_global_memory_buffer=False, is_bwd=False, is_save_input=False):
+ world_size = get_tensor_model_parallel_world_size()
+ dim_size = list(input_.size())
+ new_dim_size = dim_size[0] * world_size
+ dim_size[0] = new_dim_size
+ if is_bwd:
+ is_save_input = True
+
+ if is_use_global_memory_buffer:
+ ag_out = get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu")
+ else:
+ ag_out = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
+ input_ = input_.contiguous()
+ if a2a_event:
+ # multi stream wait event
+ if get_async_comm_utils_data_instance().tp_stream is None:
+ get_async_comm_utils_data_instance().tp_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
+ with torch_npu.npu.stream(get_async_comm_utils_data_instance().tp_stream):
+ a2a_event.wait()
+ if is_save_input and get_args().use_nanopipe and WeightGradStore.is_decoupleBlock:
+ WeightGradStore.save_grad_output(input_.clone().detach())
+ handle = torch.distributed._all_gather_base(
+ ag_out, input_, group=get_tensor_model_parallel_group(), async_op=True
+ )
+ else:
+ handle = torch.distributed._all_gather_base(
+ ag_out, input_, group=get_tensor_model_parallel_group(), async_op=True
+ )
+ if is_bwd:
+ return ag_out, handle
+ return input_, ag_out, handle
+
+
+def async_fw_all_gather(input_, a2a_event=None, is_use_global_memory_buffer=False,
+ ampipe_with_mlp_multistream=False, index=0):
+ input_, ag_out, handle = async_all_gather(input_, a2a_event, is_use_global_memory_buffer)
+ if ampipe_with_mlp_multistream:
+ get_async_comm_utils_data_instance().fw_ag_output[index] = ag_out
+ else:
+ get_async_comm_utils_data_instance().fw_ag_output.append(ag_out)
+ return input_, handle
+
+
+def async_all_to_all(input_, event=None):
+ input_ = input_.contiguous()
+ output = torch.empty_like(input_)
+ if event:
+ # multi stream wait event
+ if get_async_comm_utils_data_instance().all2all_stream is None:
+ get_async_comm_utils_data_instance().all2all_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
+ with torch_npu.npu.stream(get_async_comm_utils_data_instance().all2all_stream):
+ event.wait()
+ handle = dist.all_to_all_single(output, input_, group=get_expert_model_parallel_group(), async_op=True)
+ else:
+ handle = dist.all_to_all_single(output, input_, group=get_expert_model_parallel_group(), async_op=True)
+ return output, handle
+
+
+def async_fw_ar_rs(input_, sequence_parallel):
+ world_size = get_tensor_model_parallel_world_size()
+ if sequence_parallel:
+ # reduce scatter
+ dim_size = list(input_.size())
+ dim_size[0] = dim_size[0] // world_size
+ output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
+ handle = torch.distributed._reduce_scatter_base(
+ output, input_.contiguous(), group=get_tensor_model_parallel_group(), async_op=True
+ )
+ return output, handle
+ else:
+ # all reduce
+ handle = torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group(), async_op=True)
+ return input_, handle
diff --git a/model/train/yoco_moe/mindspeed/moe/config.py b/model/train/yoco_moe/mindspeed/moe/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a5429748f745f452a1c3bba612e29eb5ebbee99
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/config.py
@@ -0,0 +1,43 @@
+# coding=utf-8
+# Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved.
+
+import torch
+
+
+class Config:
+ def __init__(self,
+ hidden_size,
+ num_experts=1,
+ ep_size=1,
+ topk=1,
+ capacity_factor=1.,
+ eval_capacity_factor=1.,
+ min_capacity=4,
+ aux_loss_coef=0.0,
+ z_loss_coef=0.0,
+ noisy_gate_policy=None,
+ no_drop=False,
+ dynamic_padding=False,
+ use_sinkhorn=False,
+ sequence_parallel=False,
+ reshape_index_select=None,
+ ):
+ self.hidden_size = hidden_size
+ self.num_experts = num_experts
+ self.ep_size = ep_size
+ self.topk = topk
+ self.capacity_factor = capacity_factor
+ self.eval_capacity_factor = eval_capacity_factor
+ self.min_capacity = min_capacity
+ self.aux_loss_coef = aux_loss_coef
+ self.z_loss_coef = z_loss_coef
+ self.noisy_gate_policy = noisy_gate_policy
+ self.no_drop = no_drop
+ self.dynamic_padding = dynamic_padding
+ self.use_sinkhorn = use_sinkhorn
+ self.dynamic_capacity = None
+ self.sequence_parallel = sequence_parallel
+ self.reshape_index_select = None
+ if reshape_index_select:
+ self.reshape_index_select = torch.tensor(reshape_index_select, dtype=torch.int32,
+ device=torch.npu.current_device())
diff --git a/model/train/yoco_moe/mindspeed/moe/experts.py b/model/train/yoco_moe/mindspeed/moe/experts.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad61789b0a42bbbc8b189bcd4dc41cbd33207357
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/experts.py
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation.
+#
+# This source code is licensed under the Apache license found in the
+# LICENSE file in the root directory of this source tree.
+
+# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
+# reworked/refactored some parts to make it run.
+import copy
+import torch
+
+
+class Experts(torch.nn.Module):
+
+ def __init__(self, expert, num_local_experts=1):
+ super(Experts, self).__init__()
+
+ self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
+ self.num_local_experts = num_local_experts
+
+ # All-reduce moe-expert parameters across data modulo expert parallel nodes
+ for expert in self.experts:
+ for name, param in expert.named_parameters():
+ param.allreduce = False
+
+ def forward(self, inputs):
+ chunks = inputs.chunk(self.num_local_experts, dim=1)
+ expert_outputs = []
+ for chunk, expert in zip(chunks, self.experts):
+ chunk = torch.squeeze(chunk, dim=1).contiguous()
+ out = expert(chunk)
+ if type(out) is tuple:
+ out, bias = out
+ if bias is not None:
+ out = out + bias
+ out = torch.unsqueeze(out, dim=1)
+ expert_outputs += [out]
+
+ expert_output = torch.cat(expert_outputs, dim=1)
+ return expert_output
diff --git a/model/train/yoco_moe/mindspeed/moe/gate.py b/model/train/yoco_moe/mindspeed/moe/gate.py
new file mode 100644
index 0000000000000000000000000000000000000000..f676917eea15f697a367bfd712b06d3706a260d1
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/gate.py
@@ -0,0 +1,306 @@
+# coding=utf-8
+# Copyright (c) Microsoft Corporation.
+#
+# This source code is licensed under the Apache license found in the
+# LICENSE file in the root directory of this source tree.
+
+# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py
+# reworked/refactored some parts to make it run.
+from typing import Callable, Dict, Tuple
+from collections import namedtuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Module
+import torch.distributed as dist
+from megatron.training import get_args
+from megatron.core.transformer.moe.moe_utils import sinkhorn
+
+from .config import Config
+from .utils import gumbel_rsample, _capacity, einsum, _one_hot_to_float, MoEAuxLossAutoScaler
+
+exp_selection_uniform_map: Dict[torch.device, Callable] = {}
+
+
+GatingTokenRearrangeInfo = namedtuple('GatingTokenRearrangeInfo', ['token_rearranged_ec_idx', 'token_exp_weights', 'expert_select_token_idx'])
+
+
+class TopKGate(Module):
+ """Gate module which implements Top2Gating as described in Gshard_.
+ ::
+
+ gate = TopKGate(model_dim, num_experts)
+ l_aux, combine_weights, dispatch_mask = gate(input)
+
+ .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
+
+ Args:
+ model_dim (int):
+ size of model embedding dimension
+ num_experts (ints):
+ number of experts in model
+ """
+
+ weight: torch.nn.Linear
+
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+
+ # Only top-1 and top-2 are supported at the moment.
+ if config.topk != 1 and config.topk != 2:
+ raise ValueError('Only top-1 and top-2 gatings are supported.')
+ self.weight = torch.nn.Linear(config.hidden_size, config.num_experts, bias=False).float()
+ setattr(self.weight, 'sequence_parallel', config.sequence_parallel)
+ self.config = config
+
+ def forward(self, gate_input: torch.Tensor) -> Tuple[Tensor, ...]: # type: ignore
+ input_fp32 = gate_input.float()
+ logits = torch.nn.functional.linear(input_fp32, weight=self.weight.weight.float(), bias=None)
+
+ if self.config.use_sinkhorn:
+ logits = sinkhorn(logits)
+ if self.config.topk == 1:
+ gate_output = top1gating(logits, self.config)
+ else:
+ gate_output = top2gating(logits, self.config)
+
+ return gate_output
+
+
+def top1gating(logits: Tensor, config: Config) -> Tuple[Tensor, ...]:
+ """Implements Top1Gating on logits."""
+ args = get_args()
+ if config.noisy_gate_policy == 'RSample':
+ logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
+ # everything is in fp32 in this function
+ # token_sel_expert_weights: [S, E], 每个token选择每个专家的概率
+ token_sel_expert_weights = F.softmax(logits, dim=1)
+
+ if config.reshape_index_select is not None and args.ampipe_degree <= 1:
+ token_sel_expert_weights = token_sel_expert_weights[:, config.reshape_index_select]
+
+ capacity = _capacity(token_sel_expert_weights,
+ torch.tensor(config.capacity_factor),
+ torch.tensor(config.min_capacity))
+
+ # Create a mask for 1st's expert per token
+ # noisy gating
+ final_logits = logits_w_noise if config.noisy_gate_policy == "RSample" else \
+ token_sel_expert_weights
+ # [S] 每个token对应的专家(取概率最大的)
+ token_sel_expert_idx = torch.argmax(final_logits, dim=1)
+ num_experts = int(token_sel_expert_weights.shape[1])
+ token_sel_expert_mask = F.one_hot(token_sel_expert_idx, num_classes=num_experts)
+
+ # if we don't want to drop any tokens
+ if config.no_drop:
+ # gating decisions
+ exp_counts = torch.sum(token_sel_expert_mask, dim=0).detach()
+ if config.dynamic_padding:
+ new_capacity = torch.max(exp_counts)
+ cur_capacity = new_capacity.item()
+ capacity = config.dynamic_capacity.to(logits.device)
+
+ flag = cur_capacity > capacity
+ dist.reduce(flag, dst=0, op=torch.distributed.ReduceOp.SUM, group=dist.group.WORLD)
+ dist.broadcast(flag, src=0, group=dist.group.WORLD)
+ if flag:
+ dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
+ capacity = new_capacity
+
+ if cur_capacity > logits.shape[0]:
+ capacity = torch.ceil(torch.tensor(logits.shape[0])).to(torch.int64)
+ else:
+ new_capacity = torch.max(exp_counts).to(logits.device)
+ dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
+ capacity = new_capacity
+
+ # Compute l_aux负载均衡aux_loss
+ me = torch.mean(token_sel_expert_weights, dim=0)
+ ce = torch.mean(token_sel_expert_mask.float(), dim=0)
+ l_aux = torch.sum(me * ce) * num_experts
+ all_args = get_args()
+ # Random Token Selection(将token选择专家的掩码0/1矩阵中的1转成0~1之间的权重值)
+ if all_args.use_rts: # default True.
+ uniform = exp_selection_uniform_map.get(logits.device)
+ if uniform is None:
+ uniform = torch.distributions.uniform.Uniform(
+ low=torch.tensor(0.0, device=logits.device),
+ high=torch.tensor(1.0, device=logits.device)).rsample
+ exp_selection_uniform_map[logits.device] = uniform
+ # [S, E]
+ token_sel_expert_score = token_sel_expert_mask * uniform(token_sel_expert_mask.shape)
+ else:
+ token_sel_expert_score = token_sel_expert_mask
+
+ # 通过topC每个专家选择至多C个token,然后和原始的mask1(每个专家可能选择超过C个token)矩阵相乘,
+ # 丢掉超过专家容量的权重低的token,更新得到 token_sel_expert_mask
+ expert_sel_top_c_token_idx = torch.topk(token_sel_expert_score, k=capacity, dim=0)[1]
+ token_sel_expert_mask *= torch.zeros_like(token_sel_expert_mask).scatter_(0, expert_sel_top_c_token_idx, 1)
+
+ # Normalize gate probabilities
+ token_sel_expert_mask_float = token_sel_expert_mask.float()
+ token_sel_expert_weights = token_sel_expert_weights * token_sel_expert_mask_float
+
+ token_idx_in_expert_with_noise = torch.cumsum(token_sel_expert_mask, dim=0) - 1
+ masked_token_idx_in_expert = token_idx_in_expert_with_noise * token_sel_expert_mask
+ token_offset_for_expert = torch.sum(masked_token_idx_in_expert, dim=1)
+ if all_args.enable_token_rearrange_opt:
+ # 重排过程:计算出每个专家选择的token的索引:expert_select_token_idx,shape为: [E*C]
+ # MoE前向过程中根据此索引通过index_select API实现token的重排
+ # shape变化过程:[S, E]->[C, E]->[E, C]->[E*C]
+ expert_sel_top_c_token_idx = torch.topk(token_sel_expert_mask,
+ k=capacity,
+ dim=0,
+ sorted=True)[1]
+ expert_select_token_idx = expert_sel_top_c_token_idx.t().reshape(config.num_experts * capacity)
+ token_exp_weights, token_exp_idx = torch.max(token_sel_expert_weights, dim=1)
+ token_rearranged_ec_idx = (capacity.to(torch.int32) * token_exp_idx.to(torch.int32) +
+ token_offset_for_expert.to(torch.int32))
+ top1_gating_token_infos = GatingTokenRearrangeInfo(token_rearranged_ec_idx=token_rearranged_ec_idx,
+ token_exp_weights=token_exp_weights,
+ expert_select_token_idx=expert_select_token_idx)
+ return l_aux, top1_gating_token_infos
+ else:
+ token_locations_sc = _one_hot_to_float(token_offset_for_expert, capacity)
+ combine_weights = einsum("se,sc->sec", token_sel_expert_weights, token_locations_sc)
+ dispatch_mask = combine_weights.bool()
+ if config.dynamic_padding:
+ return l_aux, combine_weights, dispatch_mask, cur_capacity
+ else:
+ return l_aux, combine_weights, dispatch_mask
+
+
+def apply_aux_loss(config, gates, mask1):
+ num_experts = int(gates.shape[1])
+ me = torch.mean(gates, dim=0)
+ ce = torch.mean(mask1.float(), dim=0)
+ l_aux = torch.mean(me * ce) * num_experts * num_experts
+ if config.aux_loss_coef > 0:
+ l_aux = l_aux * config.aux_loss_coef
+ gates = MoEAuxLossAutoScaler.apply(gates, l_aux)
+ return gates, l_aux
+
+
+def apply_z_loss(config, logits):
+ """Encourages the router's logits to remain small to enhance stability.
+ Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
+
+ Args:
+ logits (torch.Tensor): The logits of the router.
+
+ Returns:
+ torch.Tensor: The logits after applying the z-loss.
+ """
+ if config.z_loss_coef > 0:
+ z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * config.z_loss_coef
+ logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
+ return logits
+
+
+def top2gating(logits: Tensor, config: Config) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ """Implements Top2Gating on logits."""
+ # apply z loss
+ args = get_args()
+ logits = apply_z_loss(config, logits)
+
+ # everything is in fp32 in this function
+ token_sel_expert_weights = F.softmax(logits, dim=1)
+
+ if config.reshape_index_select is not None and args.ampipe_degree <= 1:
+ token_sel_expert_weights = token_sel_expert_weights[:, config.reshape_index_select]
+
+ num_experts = int(token_sel_expert_weights.shape[1])
+
+ capacity = _capacity(token_sel_expert_weights,
+ torch.tensor(config.capacity_factor * 2),
+ torch.tensor(config.min_capacity))
+
+ _, selected_experts = torch.topk(token_sel_expert_weights, config.topk, dim=-1)
+ mask = F.one_hot(selected_experts, num_classes=num_experts)
+ first_expert_mask = mask[:, 0, :]
+ second_expert_mask = mask[:, 1, :]
+
+ # Compute locations in capacity buffer
+ locations_in_first_expert = torch.cumsum(first_expert_mask, dim=0) - 1
+ locations_in_second_expert = torch.cumsum(second_expert_mask, dim=0) - 1
+ # Update 2nd's location by accounting for locations of 1st
+ locations_in_second_expert += torch.sum(first_expert_mask, dim=0, keepdim=True)
+
+ # gating decisions
+ token_sel_expert_weights, l_aux = apply_aux_loss(config, token_sel_expert_weights, first_expert_mask)
+ if config.no_drop:
+ if config.dynamic_padding:
+ new_capacity = torch.max(locations_in_second_expert) + 2
+ cur_capacity = new_capacity.item()
+ capacity = config.dynamic_capacity.to(logits.device)
+
+ flag = cur_capacity > capacity
+ dist.reduce(flag, dst=0, op=torch.distributed.ReduceOp.SUM, group=dist.group.WORLD)
+ dist.broadcast(flag, src=0, group=dist.group.WORLD)
+ if flag:
+ dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
+ capacity = new_capacity
+ if cur_capacity > logits.shape[0]:
+ capacity = torch.ceil(torch.tensor(logits.shape[0])).to(torch.int64)
+ else:
+ new_capacity = torch.max(locations_in_second_expert) + 2
+ dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
+ capacity = new_capacity
+
+ # Remove locations outside capacity from mask
+ first_expert_mask *= torch.lt(locations_in_first_expert, capacity)
+ second_expert_mask *= torch.lt(locations_in_second_expert, capacity)
+
+ # Store the capacity location for each token
+ token_idx_in_first_expert = torch.sum(locations_in_first_expert * first_expert_mask, dim=1)
+ token_idx_in_second_expert = torch.sum(locations_in_second_expert * second_expert_mask, dim=1)
+
+ # Normalize gate probabilities
+ first_expert_mask_float = first_expert_mask.float()
+ second_expert_mask_float = second_expert_mask.float()
+ token_first_exp_weights, token_first_exp_idx = torch.max(token_sel_expert_weights * first_expert_mask_float, dim=1)
+ token_second_exp_weights, token_second_exp_idx = torch.max(token_sel_expert_weights * second_expert_mask_float,
+ dim=1)
+ denom_s = token_first_exp_weights + token_second_exp_weights
+ # Avoid divide-by-zero
+ denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
+ token_first_exp_weights /= denom_s
+ token_second_exp_weights /= denom_s
+ all_args = get_args()
+ if all_args.enable_token_rearrange_opt:
+ token_rearranged_first_ec_idx = token_first_exp_idx.int() * capacity + token_idx_in_first_expert.int()
+ token_rearranged_second_ec_idx = token_second_exp_idx.int() * capacity + token_idx_in_second_expert.int()
+ # 重排过程:计算出每个专家选择的token的索引:expert_select_token_idx,shape为: [E*C]
+ # MoE前向过程中根据此索引通过index_select API实现token的重排
+ # shape变化过程:[S, E]->[C, E]->[E, C]->[E*C]
+ token_sel_first_exp_int_mask = first_expert_mask * 2
+ token_sel_second_exp_int_mask = second_expert_mask
+ expert_sel_top_c_token_idx = torch.topk(token_sel_first_exp_int_mask + token_sel_second_exp_int_mask,
+ k=capacity,
+ dim=0,
+ sorted=True)[1]
+ expert_select_token_idx = expert_sel_top_c_token_idx.t().reshape(num_experts * capacity)
+ token_rearranged_ec_idx = torch.cat([token_rearranged_first_ec_idx, token_rearranged_second_ec_idx], dim=0)
+ token_exp_weights = torch.cat([token_first_exp_weights, token_second_exp_weights], dim=0)
+
+ top2_gating_token_infos = GatingTokenRearrangeInfo(token_rearranged_ec_idx=token_rearranged_ec_idx,
+ token_exp_weights=token_exp_weights,
+ expert_select_token_idx=expert_select_token_idx)
+ return l_aux, top2_gating_token_infos
+ else:
+ # Calculate combine_weights and dispatch_mask
+ gates1 = einsum("s,se->se", token_first_exp_weights, first_expert_mask_float)
+ gates2 = einsum("s,se->se", token_second_exp_weights, second_expert_mask_float)
+ locations1_sc = _one_hot_to_float(token_idx_in_first_expert, capacity)
+ locations2_sc = _one_hot_to_float(token_idx_in_second_expert, capacity)
+ combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
+ combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
+ combine_weights = combine1_sec + combine2_sec
+ dispatch_mask = combine_weights.bool()
+
+ if config.dynamic_padding:
+ return l_aux, combine_weights, dispatch_mask, cur_capacity
+ else:
+ return l_aux, combine_weights, dispatch_mask
diff --git a/model/train/yoco_moe/mindspeed/moe/mixtral_parallel_mlpbm.py b/model/train/yoco_moe/mindspeed/moe/mixtral_parallel_mlpbm.py
new file mode 100644
index 0000000000000000000000000000000000000000..baf96df9ad3a74a766763a1ac744d68e4ae702a9
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/mixtral_parallel_mlpbm.py
@@ -0,0 +1,93 @@
+# coding=utf-8
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn.functional as F
+
+from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
+from megatron.training import get_args
+from megatron.core import parallel_state
+from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput
+from mindspeed.model.transformer import should_recompute_activation
+
+
+class MixtralParallelMLPBM(torch.nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.activation_checkpoint_manager = None
+ self.ffn_dim = config.ffn_hidden_size
+ self.hidden_dim = config.hidden_size
+ self.layer_number = None
+
+ self.w1 = ColumnParallelLinear(
+ config.hidden_size,
+ config.ffn_hidden_size,
+ config=config,
+ init_method=config.init_method,
+ bias=False,
+ gather_output=False,
+ skip_bias_add=True,
+ is_expert=False,
+ pipe_experts=get_args().use_pipe_experts
+ )
+
+ self.w2 = RowParallelLinear(
+ config.ffn_hidden_size,
+ config.hidden_size,
+ config=config,
+ init_method=config.output_layer_init_method,
+ bias=False,
+ skip_bias_add=True,
+ input_is_parallel=True,
+ is_expert=False,
+ pipe_experts=get_args().use_pipe_experts
+ )
+
+ self.w3 = ColumnParallelLinear(
+ config.hidden_size,
+ config.ffn_hidden_size,
+ config=config,
+ init_method=config.init_method,
+ bias=False,
+ gather_output=False,
+ skip_bias_add=True,
+ is_expert=False,
+ pipe_experts=get_args().use_pipe_experts
+ )
+
+ self.act_fn = F.silu
+ if get_args().use_nanopipe and parallel_state.get_pipeline_model_parallel_world_size() > 1 \
+ and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+ setattr(self.w1, "in_nano", True)
+ setattr(self.w2, "in_nano", True)
+ setattr(self.w3, "in_nano", True)
+
+ def forward(self, hidden_states):
+ is_recompute_activation = should_recompute_activation(self.layer_number)
+
+ if is_recompute_activation:
+ self.activation_checkpoint_manager = CheckpointWithoutOutput()
+ act_intermediate_parallel = self.activation_checkpoint_manager.checkpoint(self.act_fn, False, self.w1(hidden_states)[0])
+ current_hidden_states = act_intermediate_parallel * self.w3(hidden_states)[0]
+ self.activation_checkpoint_manager.discard_output()
+ current_hidden_states = self.w2(current_hidden_states)[0]
+ if current_hidden_states.requires_grad:
+ current_hidden_states.register_hook(self.activation_checkpoint_manager.recompute)
+ else:
+ current_hidden_states = self.act_fn(self.w1(hidden_states)[0]) * self.w3(hidden_states)[0]
+ current_hidden_states = self.w2(current_hidden_states)[0]
+
+ return current_hidden_states
diff --git a/model/train/yoco_moe/mindspeed/moe/moe.py b/model/train/yoco_moe/mindspeed/moe/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..f75638d9f2b6170120dfa0eeb7657ce59b943752
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/moe.py
@@ -0,0 +1,105 @@
+# Copyright (c) Microsoft Corporation.
+#
+# This source code is licensed under the Apache license found in the
+# LICENSE file in the root directory of this source tree.
+
+# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/layer.py
+# reworked/refactored some parts to make it run.
+import typing
+
+import torch
+from megatron.training import get_args
+
+from .experts import Experts
+from .gate import TopKGate
+from .moe_layer import MOELayer
+from .config import Config
+from .utils import get_reshape_index_select
+
+
+class MoE(torch.nn.Module):
+ """Initialize an MoE layer.
+
+ Arguments:
+ hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension.
+ expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear).
+ num_experts (int, optional): default=1, the total number of experts per layer.
+ ep_size (int, optional): default=1, number of ranks in the expert parallel world or group.
+ k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
+ capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
+ eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
+ min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
+ aux_loss_coef (int, optional): default=0.0, scaling coefficient for the aux loss.
+ z_loss_coef (int, optional): default=0.0, scaling coefficient for the z loss.
+ noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
+ """
+
+ def __init__(self,
+ hidden_size,
+ expert,
+ num_experts=1,
+ ep_size=1,
+ k=1,
+ capacity_factor=1.,
+ eval_capacity_factor=1.,
+ min_capacity=4,
+ aux_loss_coef=0.0,
+ z_loss_coef=0.0,
+ ep_group=None,
+ noisy_gate_policy: typing.Optional[str] = None,
+ no_drop=False,
+ dynamic_padding=False,
+ use_sinkhorn=False,
+ sequence_parallel=False):
+ super(MoE, self).__init__()
+ args = get_args()
+ pipe_experts = args.use_pipe_experts
+ sequence_parallel = sequence_parallel
+ pipe_experts_multi_data = args.pipe_experts_multi_data
+ pipe_experts_multi_stream = args.pipe_experts_multi_stream
+
+ if num_experts % ep_size != 0:
+ raise AssertionError(f"Number of experts should be divisible by expert parallel size")
+ num_local_experts = num_experts // ep_size
+
+ if pipe_experts:
+ reshape_index_select = get_reshape_index_select(num_local_experts, ep_size)
+ else:
+ reshape_index_select = None
+
+ self.num_experts = num_experts
+ self.ep_size = ep_size
+ self.num_local_experts = num_local_experts
+ self.ep_group = ep_group
+ self.hidden_size = hidden_size
+
+ config = Config(hidden_size=hidden_size,
+ num_experts=num_experts,
+ ep_size=ep_size,
+ topk=k,
+ capacity_factor=capacity_factor,
+ eval_capacity_factor=eval_capacity_factor,
+ min_capacity=min_capacity,
+ aux_loss_coef=aux_loss_coef,
+ z_loss_coef=z_loss_coef,
+ noisy_gate_policy=noisy_gate_policy,
+ no_drop=no_drop,
+ dynamic_padding=dynamic_padding,
+ use_sinkhorn=use_sinkhorn,
+ sequence_parallel=sequence_parallel,
+ reshape_index_select=reshape_index_select
+ )
+ self.moe_layer = MOELayer(TopKGate(config),
+ Experts(expert, num_local_experts),
+ ep_size,
+ num_local_experts,
+ pipe_experts=pipe_experts,
+ sequence_parallel=sequence_parallel,
+ pipe_experts_multi_data=pipe_experts_multi_data,
+ pipe_experts_multi_stream=pipe_experts_multi_stream,
+ )
+ self.moe_layer.set_ep_group(ep_group)
+
+ def forward(self, hidden_states, used_token=None):
+ output = self.moe_layer(hidden_states, used_token)
+ return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
diff --git a/model/train/yoco_moe/mindspeed/moe/moe_layer.py b/model/train/yoco_moe/mindspeed/moe/moe_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5450c626b72e762fd56eda054aa2a68618f2552
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/moe_layer.py
@@ -0,0 +1,148 @@
+# Copyright (c) Microsoft Corporation.
+#
+# This source code is licensed under the Apache license found in the
+# LICENSE file in the root directory of this source tree.
+
+# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py
+# reworked/refactored some parts to make it run.
+from typing import TYPE_CHECKING, Any
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+import torch.distributed as dist
+
+from .utils import _AllToAll, einsum
+from .pipe_experts import PipeExpert
+
+if TYPE_CHECKING:
+ Base = Module[Tensor]
+else:
+ Base = Module
+
+
+class MOELayer(Base):
+ """MOELayer module which implements MixtureOfExperts as described in Gshard_.
+ ::
+
+ gate = TopKGate(model_dim, num_experts)
+ moe = MOELayer(gate, expert)
+ output = moe(input)
+ l_aux = moe.l_aux
+
+ .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
+
+ Args:
+ gate (torch.nn.Module):
+ gate network
+ expert (torch.nn.Module):
+ expert network
+ """
+
+ def __init__(self,
+ gate: Module,
+ experts: Module,
+ ep_size,
+ num_local_experts: int,
+ pipe_experts: bool = False,
+ sequence_parallel: bool = False,
+ pipe_experts_multi_data: int = 1,
+ pipe_experts_multi_stream: bool = False) -> None:
+ super().__init__()
+ self.gate = gate
+ self.experts = experts
+ self.ep_group = None
+ self.ep_size = ep_size
+ self.num_local_experts = num_local_experts
+ self.num_experts = ep_size * num_local_experts
+ self.exp_counts = None
+ self.l_aux = None
+
+ self.cur_index_window = 0
+ self.capacity_window_size = 20
+ self.capacity_history_window = []
+ self.gate.config.dynamic_capacity = torch.ceil(torch.tensor(256)).to(torch.int64)
+
+ self.pipe_experts = pipe_experts
+ self.sequence_parallel = sequence_parallel
+ self.pipe_experts_multi_data = pipe_experts_multi_data
+ self.pipe_experts_multi_stream = pipe_experts_multi_stream
+
+ def set_ep_group(self, ep_group):
+ self.ep_group = ep_group
+
+ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
+ d_model = input[0].shape[-1]
+ reshaped_input = input[0].reshape(-1, d_model)
+ from megatron.training import get_args
+ all_args = get_args()
+ # gate
+ if not all_args.enable_token_rearrange_opt:
+ if self.gate.config.dynamic_padding:
+ self.l_aux, combine_weights, dispatch_mask, cur_capacity_cur_rank = self.gate(reshaped_input)
+ self.capacity_history_window.append(cur_capacity_cur_rank)
+ self.cur_index_window += 1
+ if len(self.capacity_history_window) > self.capacity_window_size:
+ self.capacity_history_window.pop(0)
+ if self.cur_index_window == self.capacity_window_size - 1:
+ self.cur_index_window = 0
+ capacity_history_window_tensor = torch.Tensor(self.capacity_history_window[-5:]).to(combine_weights.device)
+ dist.all_reduce(capacity_history_window_tensor, op=torch.distributed.ReduceOp.MAX,
+ group=dist.group.WORLD)
+ self.capacity_history_window = capacity_history_window_tensor.cpu().numpy().tolist()
+
+ if len(self.capacity_history_window) > 0:
+ capacity_next_window = sum(self.capacity_history_window) / len(self.capacity_history_window) + 20
+ else:
+ capacity_next_window = 256
+ self.gate.config.dynamic_capacity = torch.ceil(torch.tensor(capacity_next_window)).to(torch.int64)
+ else:
+ self.l_aux, combine_weights, dispatch_mask = self.gate(reshaped_input)
+ dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)
+ else:
+ self.l_aux, token_rearrange_infos = self.gate(reshaped_input)
+ org_dtype = reshaped_input.dtype
+ if org_dtype == torch.bfloat16: # 规避算子性能劣化问题, 解决后可删除
+ rearranged_input = torch.index_select(
+ reshaped_input.to(torch.float32), dim=0, index=token_rearrange_infos.expert_select_token_idx
+ ).to(org_dtype)
+ else:
+ rearranged_input = torch.index_select(
+ reshaped_input, dim=0, index=token_rearrange_infos.expert_select_token_idx
+ )
+ capacity = token_rearrange_infos.expert_select_token_idx.size(0) // self.num_experts
+ dispatched_input = rearranged_input.reshape(self.num_experts, capacity, d_model).contiguous()
+
+ if self.pipe_experts:
+ expert_output = PipeExpert.apply(self.experts, dispatched_input, self.ep_size, self.num_local_experts,
+ self.sequence_parallel, self.pipe_experts_multi_data,
+ self.pipe_experts_multi_stream)
+ else:
+ # dispatch all2all
+ dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
+
+ # Re-shape after all-to-all: ecm -> gecm
+ dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
+ expert_output = self.experts(dispatched_input)
+
+ # combine all2all
+ expert_output = _AllToAll.apply(self.ep_group, expert_output)
+
+ # Re-shape back: gecm -> ecm
+ expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
+
+ if not all_args.enable_token_rearrange_opt:
+ combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)
+ else:
+ E, C, M = expert_output.shape
+ org_dtype = expert_output.dtype
+ if org_dtype == torch.bfloat16:
+ valid_expert_out = torch.index_select(
+ expert_output.view(E * C, M).to(torch.float32), dim=0, index=token_rearrange_infos.token_rearranged_ec_idx
+ ).to(org_dtype)
+ else:
+ valid_expert_out = torch.index_select(expert_output.view(E * C, M), dim=0, index=token_rearrange_infos.token_rearranged_ec_idx)
+ combined_output = valid_expert_out * token_rearrange_infos.token_exp_weights.unsqueeze(1).type_as(input[0])
+ if all_args.moe_router_topk == 2:
+ combined_output = torch.add(*torch.chunk(combined_output, all_args.moe_router_topk, dim=0))
+ return combined_output.reshape(input[0].shape)
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/moe/pipe_experts.py b/model/train/yoco_moe/mindspeed/moe/pipe_experts.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a48205a065875fdaeca9daccad9697698586c63
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/pipe_experts.py
@@ -0,0 +1,422 @@
+# coding=utf-8
+# Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved.
+
+import torch
+
+from megatron.training import get_args
+from mindspeed.core.weight_grad_store import WeightGradStore
+
+from .async_comm_utils import (async_all_to_all, async_fw_ar_rs, get_fw_ag_output, async_fw_all_gather,
+ async_all_gather)
+
+ASYNC_BW_ALL_GATHER_COUNT = 0
+FLAG_GRAD_REDUCE = True
+
+
+def get_async_bw_all_gather_count():
+ return ASYNC_BW_ALL_GATHER_COUNT
+
+
+class PipeExpertUtil:
+ multi_data = None
+ num_local_experts = None
+ slice_seq_size = None
+ ep_size = None
+
+ first_a2a_event = []
+ second_a2a_event = []
+ fw_ag_event = []
+ bw_ag_event = []
+ ar_rs_event = []
+
+ @classmethod
+ def set_parameters(cls, args, slice_seq_size):
+ cls.multi_data = args[4]
+ cls.num_local_experts = args[2]
+ cls.slice_seq_size = slice_seq_size
+ cls.ep_size = args[1]
+
+ @classmethod
+ def get_first_a2a_event(cls):
+ return cls.first_a2a_event
+
+ @classmethod
+ def get_second_a2a_event(cls):
+ return cls.second_a2a_event
+
+ @classmethod
+ def get_fw_ag_event(cls):
+ return cls.fw_ag_event
+
+ @classmethod
+ def get_bw_ag_event(cls):
+ return cls.bw_ag_event
+
+ @classmethod
+ def get_ar_rs_event(cls):
+ return cls.ar_rs_event
+
+ @classmethod
+ def deal_data(cls, origin_data, output_data):
+ for i in range(cls.num_local_experts):
+ for j in range(cls.multi_data):
+ output_data.append(origin_data[i * cls.ep_size: (i + 1) * cls.ep_size,
+ j * cls.slice_seq_size: (j + 1) * cls.slice_seq_size].clone().contiguous())
+
+ @classmethod
+ def first_a2a_when_not_multi_stream(cls, input_data_list):
+ for i in range(cls.num_local_experts):
+ for j in range(cls.multi_data):
+ input_data_list[j + i * cls.multi_data], handle = async_all_to_all(
+ input_data_list[j + i * cls.multi_data])
+ cls.first_a2a_event.append(handle)
+
+ @classmethod
+ def fw_bw_ag_after_first_a2a_when_not_multi_stream(cls, input_data_list, num_local_experts_index, multi_data_index,
+ is_fw_ag):
+ index = num_local_experts_index * cls.multi_data + multi_data_index
+ if index == 0 and get_args().ampipe_degree <= 1:
+ cls.first_a2a_event[index].wait()
+ if is_fw_ag:
+ input_data_list[index], handle = async_fw_all_gather(input_data_list[index])
+ cls.fw_ag_event.append(handle)
+ else:
+ if get_args().use_nanopipe and WeightGradStore.is_decoupleBlock:
+ WeightGradStore.save_grad_output(input_data_list[num_local_experts_index * cls.multi_data + multi_data_index].clone().detach())
+ input_data_list[index], handle = async_all_gather(input_data_list[index], is_bwd=True)
+ cls.bw_ag_event.append(handle)
+ if index < (cls.num_local_experts * cls.multi_data - 1):
+ cls.first_a2a_event[index + 1].wait()
+ if is_fw_ag:
+ if index == 0 and not get_args().use_nanopipe:
+ input_data_list[index + 1], handle = async_fw_all_gather(input_data_list[index + 1], None, True)
+ else:
+ input_data_list[index + 1], handle = async_fw_all_gather(input_data_list[index + 1])
+ cls.fw_ag_event.append(handle)
+ else:
+ if get_args().use_nanopipe and WeightGradStore.is_decoupleBlock:
+ WeightGradStore.save_grad_output(input_data_list[num_local_experts_index * cls.multi_data + multi_data_index + 1].clone().detach())
+ if index == 0 and not get_args().use_nanopipe:
+ input_data_list[index + 1], handle = async_all_gather(input_data_list[index + 1], None, True, True)
+ else:
+ input_data_list[index + 1], handle = async_all_gather(input_data_list[index + 1], is_bwd=True)
+ cls.bw_ag_event.append(handle)
+
+ @classmethod
+ def fw_bw_ag_after_first_a2a_when_multi_stream(cls, input_data_list, num_local_experts_index, multi_data_index,
+ is_fw_ag):
+ index = num_local_experts_index * cls.multi_data + multi_data_index
+ if index == 0:
+ input_data_list[index], handle = async_all_to_all(input_data_list[index])
+ cls.first_a2a_event.append(handle)
+ if is_fw_ag:
+ input_data_list[index], handle = async_fw_all_gather(
+ input_data_list[index], cls.first_a2a_event[index])
+ cls.fw_ag_event.append(handle)
+ else:
+ input_data_list[index], handle = async_all_gather(
+ input_data_list[index], cls.first_a2a_event[index], is_bwd=True)
+ cls.bw_ag_event.append(handle)
+ if index < (cls.num_local_experts * cls.multi_data - 1):
+ if is_fw_ag:
+ input_data_list[index + 1], handle = async_all_to_all(
+ input_data_list[index + 1], cls.fw_ag_event[index])
+ cls.first_a2a_event.append(handle)
+ if index == 0 and not get_args().use_nanopipe:
+ input_data_list[index + 1], handle = async_fw_all_gather(
+ input_data_list[index + 1], cls.first_a2a_event[index + 1], True)
+ else:
+ input_data_list[index + 1], handle = async_fw_all_gather(
+ input_data_list[index + 1], cls.first_a2a_event[index + 1])
+ cls.fw_ag_event.append(handle)
+ else:
+ input_data_list[index + 1], handle = async_all_to_all(
+ input_data_list[index + 1], cls.bw_ag_event[index])
+ cls.first_a2a_event.append(handle)
+ if index == 0 and not get_args().use_nanopipe:
+ input_data_list[index + 1], handle = async_all_gather(
+ input_data_list[index + 1], cls.first_a2a_event[index + 1], True, True)
+ else:
+ input_data_list[index + 1], handle = async_all_gather(
+ input_data_list[index + 1], cls.first_a2a_event[index + 1], is_bwd=True)
+ cls.bw_ag_event.append(handle)
+
+ @classmethod
+ def fw_a2a_after_ar_rs_when_not_multi_stream(cls, num_local_experts_index, multi_data_index,
+ output_list_for_each_multi_data, outputs_list_for_each_local_expert):
+ if cls.multi_data == 1:
+ if num_local_experts_index > 0:
+ cls.ar_rs_event[num_local_experts_index - 1].wait()
+ outputs_list_for_each_local_expert[num_local_experts_index - 1][0], handle = async_all_to_all(
+ outputs_list_for_each_local_expert[num_local_experts_index - 1][0])
+ cls.second_a2a_event.append(handle)
+ else:
+ if multi_data_index > 0:
+ cls.ar_rs_event[num_local_experts_index * cls.multi_data + multi_data_index - 1].wait()
+ output_list_for_each_multi_data[multi_data_index - 1], handle = async_all_to_all(
+ output_list_for_each_multi_data[multi_data_index - 1])
+ cls.second_a2a_event.append(handle)
+ else:
+ if num_local_experts_index > 0:
+ cls.ar_rs_event[num_local_experts_index * cls.multi_data + multi_data_index - 1].wait()
+ outputs_list_for_each_local_expert[num_local_experts_index - 1][
+ cls.multi_data - 1], handle = async_all_to_all(
+ outputs_list_for_each_local_expert[num_local_experts_index - 1][cls.multi_data - 1])
+ cls.second_a2a_event.append(handle)
+
+ @classmethod
+ def fw_a2a_for_final_data_when_not_multi_stream(cls, outputs_list_for_each_local_expert):
+ cls.ar_rs_event[cls.num_local_experts * cls.multi_data - 1].wait()
+ outputs_list_for_each_local_expert[cls.num_local_experts - 1][
+ cls.multi_data - 1], handle = async_all_to_all(
+ outputs_list_for_each_local_expert[cls.num_local_experts - 1][cls.multi_data - 1])
+ cls.second_a2a_event.append(handle)
+
+
+class PipeExpert(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, Experts, *args):
+ inputs = args[0]
+ ep_size = args[1]
+ num_local_experts = args[2]
+ sequence_parallel = args[3]
+ multi_data = args[4]
+ multi_stream = args[5]
+
+ ctx.num_local_experts = num_local_experts
+ ctx.sequence_parallel = sequence_parallel
+ ctx.multi_data = multi_data
+ ctx.multi_stream = multi_stream
+
+ inputs_list = []
+ ampipe_degree = get_args().ampipe_degree
+ ctx.ampipe_degree = ampipe_degree
+ if ampipe_degree > 1:
+ PipeExpertUtil.first_a2a_event = args[6]
+ PipeExpertUtil.second_a2a_event = args[7]
+ PipeExpertUtil.fw_ag_event = args[8]
+ ctx.hidden_size = hidden_size = args[9]
+ save_tensors_list = args[10]
+ inputs_list = inputs
+ slice_seq_size = 0
+ else:
+ input_shape = list(inputs.size())
+ if multi_data > input_shape[1]:
+ raise ValueError('--pipe-experts-multi-data cannot be greater than experts capacity')
+ slice_seq_size = input_shape[1] // multi_data
+ if input_shape[1] % multi_data != 0:
+ slice_seq_size += 1
+
+ outputs_list_for_each_local_expert = []
+ input_list_before_expert = []
+ output_list_after_expert = []
+ PipeExpertUtil.set_parameters(args, slice_seq_size)
+
+ if ampipe_degree <= 1:
+ PipeExpertUtil.deal_data(inputs, inputs_list)
+ inputs.untyped_storage().resize_(0)
+
+ if not multi_stream and ampipe_degree <= 1:
+ PipeExpertUtil.first_a2a_when_not_multi_stream(inputs_list)
+
+ for i in range(num_local_experts):
+ output_list_for_each_multi_data = []
+ for j in range(multi_data):
+ if sequence_parallel:
+ if not multi_stream:
+ PipeExpertUtil.fw_bw_ag_after_first_a2a_when_not_multi_stream(inputs_list, i, j, True)
+ elif ampipe_degree <= 1:
+ PipeExpertUtil.fw_bw_ag_after_first_a2a_when_multi_stream(inputs_list, i, j, True)
+
+ PipeExpertUtil.get_fw_ag_event()[i * multi_data + j].wait()
+ else:
+ PipeExpertUtil.get_first_a2a_event()[i * multi_data + j].wait()
+
+ input_detach_before_expert = inputs_list[i * multi_data + j].detach()
+ input_detach_before_expert.requires_grad = True
+ input_list_before_expert.append(input_detach_before_expert)
+
+ with torch.enable_grad():
+ output_expert = Experts.experts[i](input_list_before_expert[i * multi_data + j])
+ if sequence_parallel:
+ get_fw_ag_output().pop(0)
+
+ if isinstance(output_expert, tuple):
+ output_expert, bias = output_expert
+ if bias is not None:
+ with torch.enable_grad():
+ output_expert = output_expert + bias
+
+ output_list_after_expert.append(output_expert)
+ output_detach_after_expert = output_expert.detach()
+
+ if not multi_stream:
+ PipeExpertUtil.fw_a2a_after_ar_rs_when_not_multi_stream(i, j, output_list_for_each_multi_data,
+ outputs_list_for_each_local_expert)
+
+ output_detach_after_expert, handle = async_fw_ar_rs(output_detach_after_expert, sequence_parallel)
+ output_list_for_each_multi_data.append(output_detach_after_expert)
+ PipeExpertUtil.get_ar_rs_event().append(handle)
+ else:
+ # all2all allgather wait release memory
+ PipeExpertUtil.get_first_a2a_event()[i * multi_data + j].wait()
+ PipeExpertUtil.get_fw_ag_event()[i * multi_data + j].wait()
+
+ output_detach_after_expert, handle = async_fw_ar_rs(output_detach_after_expert, sequence_parallel)
+ PipeExpertUtil.get_ar_rs_event().append(handle)
+ output_detach_after_expert, handle = async_all_to_all(output_detach_after_expert,
+ PipeExpertUtil.get_ar_rs_event()[
+ i * multi_data + j])
+ output_list_for_each_multi_data.append(output_detach_after_expert)
+ PipeExpertUtil.get_second_a2a_event().append(handle)
+
+ outputs_list_for_each_local_expert.append(output_list_for_each_multi_data)
+
+ if not multi_stream:
+ PipeExpertUtil.fw_a2a_for_final_data_when_not_multi_stream(outputs_list_for_each_local_expert)
+
+ for i in range(num_local_experts):
+ for j in range(multi_data):
+ PipeExpertUtil.get_second_a2a_event()[i * multi_data + j].wait()
+ # reduce scatter
+ PipeExpertUtil.get_ar_rs_event()[i * multi_data + j].wait()
+
+ PipeExpertUtil.get_first_a2a_event().clear()
+ PipeExpertUtil.get_second_a2a_event().clear()
+ PipeExpertUtil.get_fw_ag_event().clear()
+ PipeExpertUtil.get_ar_rs_event().clear()
+
+ for tensor in output_list_after_expert:
+ tensor.untyped_storage().resize_(0)
+
+ ctx.input_list_before_expert = input_list_before_expert
+
+ if 1 < ampipe_degree <= multi_data:
+ save_tensors_list.extend(output_list_after_expert)
+ output_list = []
+ for i in range(num_local_experts):
+ exp_out_list = []
+ for j in range(ampipe_degree):
+ ampipe_tokens = outputs_list_for_each_local_expert[i][
+ j * multi_data // ampipe_degree:(j + 1) * multi_data // ampipe_degree]
+ ampipe_tokens = torch.cat(ampipe_tokens, dim=1)
+ exp_out_list.append(ampipe_tokens)
+ output_list.append(exp_out_list)
+ output_forward = [
+ torch.cat([i[j] for i in output_list], dim=1).reshape(num_local_experts * ep_size, -1, hidden_size) for
+ j in range(ampipe_degree)]
+
+ else:
+ ctx.save_for_backward(*tuple(output_list_after_expert))
+ output_forward = torch.cat([torch.cat((outputs_list_for_each_local_expert[i]), dim=1) for i in range(num_local_experts)], dim=0)
+
+ return output_forward
+
+ @staticmethod
+ def backward(ctx, *args):
+ num_local_experts = ctx.num_local_experts
+ sequence_parallel = ctx.sequence_parallel
+ multi_stream = ctx.multi_stream
+ multi_data = ctx.multi_data
+ ampipe_degree = ctx.ampipe_degree
+
+ grad_outputs = args[0]
+ global ASYNC_BW_ALL_GATHER_COUNT
+ ASYNC_BW_ALL_GATHER_COUNT = 0
+
+ grad_outputs_list = []
+ grad_outputs_list_for_each_local_expert = []
+ if ampipe_degree > 1:
+ PipeExpertUtil.first_a2a_event = args[1]
+ PipeExpertUtil.bw_ag_event = args[2]
+ PipeExpertUtil.second_a2a_event = args[3]
+ output_list_after_expert = args[4]
+ grad_outputs_list = grad_outputs
+ else:
+ output_list_after_expert = list(ctx.saved_tensors)
+
+ if ampipe_degree <= 1:
+ PipeExpertUtil.deal_data(grad_outputs, grad_outputs_list)
+ grad_outputs.storage().resize_(0)
+
+ if not multi_stream and ampipe_degree <= 1:
+ PipeExpertUtil.first_a2a_when_not_multi_stream(grad_outputs_list)
+
+ for i in range(num_local_experts):
+ grad_output_list_for_each_multi_data = []
+ global FLAG_GRAD_REDUCE
+ FLAG_GRAD_REDUCE = False
+ for j in range(multi_data):
+ if sequence_parallel:
+ if not multi_stream:
+ PipeExpertUtil.fw_bw_ag_after_first_a2a_when_not_multi_stream(grad_outputs_list, i, j, False)
+
+ elif ampipe_degree <= 1:
+ PipeExpertUtil.fw_bw_ag_after_first_a2a_when_multi_stream(grad_outputs_list, i, j, False)
+
+ PipeExpertUtil.get_bw_ag_event()[i * multi_data + j].wait()
+ else:
+ PipeExpertUtil.get_first_a2a_event()[i * multi_data + j].wait()
+ ASYNC_BW_ALL_GATHER_COUNT += 1
+ if j == multi_data - 1:
+ FLAG_GRAD_REDUCE = True
+ output_list_after_expert[i * multi_data + (multi_data // ampipe_degree + j) % multi_data].backward(
+ grad_outputs_list[i * multi_data + j])
+ grads_expert_output = ctx.input_list_before_expert[
+ i * multi_data + (multi_data // ampipe_degree + j) % multi_data].grad
+
+ grads_expert_output, handle = async_all_to_all(grads_expert_output)
+ grad_output_list_for_each_multi_data.append(grads_expert_output)
+ PipeExpertUtil.get_second_a2a_event().append(handle)
+ grad_outputs_list_for_each_local_expert.append(grad_output_list_for_each_multi_data)
+
+ if 1 < ampipe_degree <= multi_data:
+ for i in range(num_local_experts):
+ for j in range(multi_data):
+ index = i * multi_data + j
+ if index < len(PipeExpertUtil.get_second_a2a_event()) - 1:
+ PipeExpertUtil.get_second_a2a_event()[index].wait()
+
+ for event in PipeExpertUtil.get_first_a2a_event():
+ event.wait()
+
+ for event in PipeExpertUtil.get_bw_ag_event():
+ event.wait()
+
+ PipeExpertUtil.get_first_a2a_event().clear()
+ PipeExpertUtil.get_bw_ag_event().clear()
+
+ output_list = []
+ for i in range(num_local_experts):
+ exp_out_list = []
+ for j in range(ampipe_degree):
+ ampipe_tokens = grad_outputs_list_for_each_local_expert[i][
+ j * multi_data // ampipe_degree:(j + 1) * multi_data // ampipe_degree]
+ exp_out_list.append(ampipe_tokens)
+ output_list.append(exp_out_list)
+ second_ampipe_grad_input = torch.cat([torch.cat(i[0], dim=1) for i in output_list], dim=1)
+ second_ampipe_grad_input = second_ampipe_grad_input.reshape((num_local_experts * PipeExpertUtil.ep_size), -1,
+ ctx.hidden_size)
+ first_ampipe_grad_inputs = [i[1] for i in output_list]
+
+ return [first_ampipe_grad_inputs, second_ampipe_grad_input]
+ else:
+ for i in range(num_local_experts):
+ for j in range(multi_data):
+ PipeExpertUtil.get_second_a2a_event()[i * multi_data + j].wait()
+
+ for event in PipeExpertUtil.get_first_a2a_event():
+ event.wait()
+
+ for event in PipeExpertUtil.get_bw_ag_event():
+ event.wait()
+
+ PipeExpertUtil.get_second_a2a_event().clear()
+ PipeExpertUtil.get_first_a2a_event().clear()
+ PipeExpertUtil.get_bw_ag_event().clear()
+ grad_output = torch.cat(
+ [torch.cat((grad_outputs_list_for_each_local_expert[i]), dim=1) for i in range(num_local_experts)], dim=0)
+
+ ctx.input_list_before_expert = None
+ return None, grad_output, None, None, None, None, None
diff --git a/model/train/yoco_moe/mindspeed/moe/utils.py b/model/train/yoco_moe/mindspeed/moe/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b343b3d52a32cc3c4d67626b0ba530a05c1fa2
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/moe/utils.py
@@ -0,0 +1,210 @@
+# Copyright (c) Microsoft Corporation.
+#
+# This source code is licensed under the Apache license found in the
+# LICENSE file in the root directory of this source tree.
+
+# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py
+# copied from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py
+# reworked/refactored some parts to make it run.
+from typing import Any
+from typing import Callable, Dict, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor
+from megatron.core import mpu
+
+gumbel_map: Dict[torch.device, Callable] = {}
+USE_EINSUM = False
+ampipe_slices_map = {}
+
+
+def print_rank_0(message):
+ """If distributed is initialized, print only on rank 0."""
+ if torch.distributed.is_initialized():
+ if torch.distributed.get_rank() == 0:
+ print(message, flush=True)
+ else:
+ print(message, flush=True)
+
+
+# Based on https://github.com/pytorch/pytorch/pull/40762
+class _AllToAll(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore
+ ctx.group = group
+ input = input.contiguous()
+ output = torch.empty_like(input)
+ dist.all_to_all_single(output, input, group=group)
+ return output
+
+ @staticmethod
+ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
+ return (None, _AllToAll.apply(ctx.group, *grad_output))
+
+
+def all_gather_along_first_dim(input_, is_use_global_memory_buffer=False):
+ world_size = mpu.get_tensor_model_parallel_world_size()
+ if world_size == 1:
+ return input_
+ dim_size = list(input_.size())
+ dim_size[0] = dim_size[0] * world_size
+ if is_use_global_memory_buffer:
+ ag_out = mpu.get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu")
+ else:
+ ag_out = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
+ torch.distributed._all_gather_base(
+ ag_out, input_.contiguous(), group=mpu.get_tensor_model_parallel_group()
+ )
+ return ag_out
+
+
+def get_reshape_index_select(num_local_experts, ep_size):
+ reshape_index_select = []
+ for i in range(num_local_experts):
+ index = i
+ for j in range(ep_size):
+ reshape_index_select.append(index)
+ index += num_local_experts
+ return reshape_index_select
+
+
+def get_slice_indices_from_order_to_disorder(seq_length, pipe_degree, device):
+ if ampipe_slices_map.get('order_to_disorder') is not None:
+ return ampipe_slices_map.get('order_to_disorder')
+ tp_size = mpu.get_tensor_model_parallel_world_size()
+ slice_size = seq_length // tp_size // pipe_degree
+
+ output = []
+ for out_idx in range(0, seq_length // tp_size, slice_size):
+ for i in range(out_idx, seq_length, pipe_degree * slice_size):
+ for j in range(slice_size):
+ output.append(i + j)
+ output = torch.tensor(output, dtype=torch.int32, device=device)
+ ampipe_slices_map['order_to_disorder'] = output
+ return output
+
+
+def get_slice_indices_from_disorder_to_order(seq_length, pipe_degree, device):
+ if ampipe_slices_map.get('disorder_to_order') is not None:
+ return ampipe_slices_map.get('disorder_to_order')
+ tp_size = mpu.get_tensor_model_parallel_world_size()
+ slice_size = seq_length // tp_size // pipe_degree
+
+ output = []
+ for out_idx in range(0, seq_length // pipe_degree, slice_size):
+ for i in range(out_idx, seq_length, tp_size * slice_size):
+ for j in range(slice_size):
+ output.append(i + j)
+ output = torch.tensor(output, dtype=torch.int32, device=device)
+ ampipe_slices_map['disorder_to_order'] = output
+ return output
+
+
+def _one_hot_to_float(x, num_classes):
+ return F.one_hot(x, num_classes=num_classes).float()
+
+
+def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
+ # gates has shape of S,E
+ num_tokens = gates.shape[0]
+ num_experts = gates.shape[1]
+ max_capacity = num_tokens
+ # to(torch.int64) works around a bug in torch.onnx.export:
+ # it should cast k to int64 when converting torch.topk but it doesn't.
+ capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
+ if capacity < min_capacity:
+ capacity = min_capacity.to(torch.int64)
+ elif capacity > max_capacity:
+ capacity = torch.tensor(max_capacity, dtype=torch.int64)
+ return capacity
+
+
+def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
+ gumbel = gumbel_map.get(device)
+ if gumbel is None:
+ one = torch.tensor(1.0, device=device)
+ zero = torch.tensor(0.0, device=device)
+ gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore
+ gumbel_map[device] = gumbel
+ return gumbel(shape)
+
+
+# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
+# See https://arxiv.org/pdf/2006.16668.pdf for details.
+def einsum(rule, a, b):
+ if USE_EINSUM:
+ return torch.einsum(rule, a, b)
+ elif rule == 's,se->se':
+ return a.reshape(a.shape[0], -1) * b
+ elif rule == 'se,sc->sec':
+ return a.unsqueeze(2) * b.unsqueeze(1)
+ elif rule == 'se,se->s':
+ return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
+ elif rule == 'sec,sm->ecm':
+ s = a.shape[0]
+ e = a.shape[1]
+ c = a.shape[2]
+ m = b.shape[1]
+ return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
+ elif rule == 'sec,ecm->sm':
+ return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
+ elif rule == 'ks,ksm->sm':
+ k = b.shape[0]
+ s = b.shape[1]
+ m = b.shape[2]
+ # [k, s] -> [s, k] -> [s, 1, k]
+ a = a.t().unsqueeze(1)
+ # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
+ b = b.reshape(k, -1).t().reshape(s, m, k)
+ # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
+ return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
+ else:
+ return torch.einsum(rule, a, b)
+
+
+class MoEAuxLossAutoScaler(torch.autograd.Function):
+ """An AutoScaler that compute and scales the grad for auxiliary loss.
+
+ """
+
+ main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)
+
+ @staticmethod
+ def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
+ """Preserve the aux_loss by storing it in the context to avoid garbage collection.
+
+ Args:
+ output (torch.Tensor): The output tensor.
+ aux_loss (torch.Tensor): The auxiliary loss tensor.
+
+ Returns:
+ torch.Tensor: The output tensor.
+ """
+ ctx.save_for_backward(aux_loss)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output: torch.Tensor):
+ """Compute and scale the gradient for auxiliary loss..
+
+ Args:
+ grad_output (torch.Tensor): The gradient of the output.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss gradient.
+ """
+ (aux_loss,) = ctx.saved_tensors
+ aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
+ scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
+ return grad_output, scaled_aux_loss_grad
+
+ @staticmethod
+ def set_loss_scale(scale: torch.Tensor):
+ """set the scale of the aux loss.
+
+ Args:
+ scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
+ """
+ MoEAuxLossAutoScaler.main_loss_backward_scale = scale
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/conv3d/__init__.py b/model/train/yoco_moe/mindspeed/multi_modal/conv3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/conv3d/conv3d_depth_parallel.py b/model/train/yoco_moe/mindspeed/multi_modal/conv3d/conv3d_depth_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..88776419e6b909cc1242ccabd760a0d52d6082e9
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/conv3d/conv3d_depth_parallel.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+
+
+def get_range_list_of_3dshape(dim_size, world_size, kernel_size, stride):
+ def find_last_le_k(arr, k):
+ return max((element for element in arr if element < k), default=arr[-1])
+
+ def find_first_ge_k(arr, k):
+ return next((element for element in arr if element >= k), arr[-1])
+
+ range_list = []
+ stride_index = [i for i in range(0, dim_size, stride)]
+ for rank in range(world_size):
+ depth_per_sp = dim_size // world_size
+ start_idx = find_first_ge_k(stride_index, rank * depth_per_sp)
+ last_idx = find_last_le_k(stride_index, (rank + 1) * depth_per_sp) + 1
+ end_idx = last_idx + kernel_size - 1 if rank < world_size - 1 else dim_size
+
+ range_list.append([start_idx, end_idx])
+ return range_list
+
+
+def _split(input_, pg: dist.ProcessGroup, dim=-1, kernel_size=1, stride=1, depth_range=None):
+ # skip if only one rank involved
+ world_size = dist.get_world_size(pg)
+ rank = dist.get_rank(pg)
+ if world_size == 1:
+ return input_
+
+ if depth_range:
+ start_idx, end_idx = depth_range[rank]
+ output = input_[:, :, start_idx:end_idx, :, :].contiguous()
+ return output, None
+
+ # Split along last dimension.
+ dim_size = input_.size(dim)
+
+ start_end_idx_list = get_range_list_of_3dshape(dim_size, world_size, kernel_size, stride)
+ start_idx, end_idx = start_end_idx_list[rank]
+ output = input_[:, :, start_idx:end_idx, :, :].contiguous()
+
+ return output, start_end_idx_list
+
+
+def _gather(input_, pg: dist.ProcessGroup, total_depth, dim=2, kernel_size=1, stride=1, is_forward=True):
+ input_ = input_.contiguous()
+ world_size = dist.get_world_size(pg)
+ padding = 0 # not support padding currently
+
+ # skip if only one rank involved
+ if world_size == 1:
+ return input_
+
+ tensor_list = []
+ start_end_idx_list = get_range_list_of_3dshape(total_depth, world_size, kernel_size, stride)
+ original_start_end_idx_list = []
+ conv_start_end_idx_list = []
+
+ if is_forward:
+ # forward: build the shapes after conv
+ last_end_idx = 0
+ for start_idx, end_idx in start_end_idx_list:
+ length = end_idx - start_idx
+ # O = (W-K+2P)/S + 1
+ length = (length - kernel_size + 2 * padding) // stride + 1
+ conv_start_end_idx_list.append([last_end_idx, last_end_idx + length])
+ last_end_idx = last_end_idx + length
+ tensor_list.append(torch.empty_like(input_[:, :, 0:1, :, :].expand(-1, -1, length, -1, -1)))
+ output_start_end_idx_list = conv_start_end_idx_list
+ else:
+ # backward: build the original shapes before conv
+ for start_idx, end_idx in start_end_idx_list:
+ # O = (W-K+2P)/S + 1
+ original_start_end_idx_list.append([start_idx, end_idx])
+ tensor_list.append(torch.empty_like(input_[:, :, 0:1, :, :].expand(-1, -1, end_idx - start_idx, -1, -1)))
+ output_start_end_idx_list = original_start_end_idx_list
+
+ dist.all_gather(tensor_list, input_, group=pg)
+ output = torch.cat(tensor_list, dim=dim).contiguous()
+ if not is_forward:
+ real_output = torch.zeros_like(input_[:, :, 0:1, :, :].expand(-1, -1, total_depth, -1, -1))
+ for tensor, idx in zip(tensor_list, output_start_end_idx_list):
+ start_idx, end_idx = idx
+ for i in range(start_idx, end_idx):
+ j = i - start_idx
+ real_output[:, :, i, :, :] = real_output[:, :, i, :, :] + tensor[:, :, j, :, :]
+
+ output = real_output
+ return output, output_start_end_idx_list
+
+
+class _ConvGatherForwardSplitBackward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, process_group, total_depth, dim, kernel_size, stride):
+ ctx.mode = process_group
+ ctx.dim = dim
+ ctx.kernel_size = kernel_size
+ ctx.stride = stride
+ output, depth_range = _gather(input_, process_group, total_depth, dim, kernel_size, stride, True)
+ ctx.depth_range = depth_range
+ return output
+
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ output, _ = _split(grad_output, ctx.mode, ctx.dim, ctx.kernel_size, ctx.stride, ctx.depth_range)
+ return output, None, None, None, None, None, None
+
+
+class _ConvSplitForwardGatherBackward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, process_group, dim, kernel_size, stride):
+ ctx.mode = process_group
+ ctx.dim = dim
+ ctx.kernel_size = kernel_size
+ ctx.stride = stride
+ ctx.total_depth = input_.shape[dim]
+ output, _ = _split(input_, process_group, dim, kernel_size, stride)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ output, _ = _gather(grad_output, ctx.mode, ctx.total_depth, ctx.dim, ctx.kernel_size, ctx.stride, False)
+ return output, None, None, None, None, None, None
+
+
+class AllReduceFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, conv3d_module, param_async, grad_reduce_handles):
+ ctx.grad_reduce_handles = grad_reduce_handles
+ ctx.param_async = param_async
+ ctx.conv3d = conv3d_module
+ return input
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ for param in ctx.conv3d.parameters():
+ if param.grad is not None:
+ if ctx.param_async:
+ handle = torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.SUM, async_op=True)
+ ctx.grad_reduce_handles.append(handle)
+ else:
+ torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.SUM)
+ return grad_output, None, None, None
+
+
+class Conv3DSequenceParallel(nn.Module):
+ def __init__(self,
+ pg: dist.ProcessGroup,
+ in_channels,
+ out_channels,
+ kernel_size=(1, 1, 1),
+ stride=(1, 1, 1),
+ dilation=1,
+ bias=True,
+ param_async=False,
+ dtype=torch.bfloat16,
+ sp_size=1):
+ super(Conv3DSequenceParallel, self).__init__()
+ self.sp_size = sp_size
+ self.depth_kernel_size = kernel_size[0]
+ self.depth_stride = stride[0]
+ self.param_async = param_async
+ self.padding = 0 # not support padding currently
+ self.pg = pg
+ self.world_size = dist.get_world_size(pg)
+ self.grad_reduce_handles = []
+
+ self.conv3d = nn.Conv3d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=self.padding,
+ dilation=dilation,
+ bias=bias
+ ).npu().to(dtype)
+
+ def forward(self, x):
+ depth = x.shape[2] # [batch_size, in_channels, depth, height, width]
+
+ # O = (W-K+2P)/S + 1
+ depth_after_conv = (depth - self.depth_kernel_size + 2 * self.padding) // self.depth_stride + 1
+ if self.sp_size > 1 and (depth_after_conv // self.world_size) > 0:
+ x = AllReduceFunction.apply(x, self.conv3d, self.param_async, self.grad_reduce_handles)
+ x = _ConvSplitForwardGatherBackward.apply(x, self.pg, 2, self.depth_kernel_size, self.depth_stride)
+
+ x = self.conv3d(x)
+
+ if self.sp_size > 1 and (depth_after_conv // self.world_size) > 0:
+ x = _ConvGatherForwardSplitBackward.apply(x, self.pg, depth, 2, self.depth_kernel_size, self.depth_stride)
+
+ return x
+
+ def get_param_grad_reduce_handles(self):
+ return self.grad_reduce_handles
+
+ def wait_param_grad_reduce_handles(self):
+ for handle in self.grad_reduce_handles:
+ handle.wait()
+ self.grad_reduce_handles = []
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/__init__.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..985e3ed26654d77c48a0d904cd5e86efce1d86b3
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/__init__.py
@@ -0,0 +1 @@
+from . import communication, config, inner_data_parallel, pipeline_parallel, parallel_state, training, checkpointing
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/checkpointing.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/checkpointing.py
new file mode 100644
index 0000000000000000000000000000000000000000..3faf3f1d04a8b025ae6dbf1d209aa6f25ddb00b0
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/checkpointing.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import os
+from functools import wraps
+from megatron.core import mpu
+from .config.dist_train_config import get_dist_model_name
+
+
+def get_checkpoint_name_wrapper(get_checkpoint_name):
+ @wraps(get_checkpoint_name)
+ def wrapper(*args, **kwargs):
+ return _get_checkpoint_name(*args, **kwargs)
+ return wrapper
+
+
+def _get_checkpoint_name(checkpoints_path, iteration, release=False, **kwargs):
+ if release:
+ directory = 'release'
+ else:
+ directory = 'iter_{:07d}'.format(iteration)
+ if kwargs.get('return_base_dir', False):
+ common_path = os.path.join(checkpoints_path, directory)
+ return common_path
+
+ pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
+ tensor_rank = mpu.get_tensor_model_parallel_rank()
+ pipeline_rank = mpu.get_pipeline_model_parallel_rank()
+ model_name = get_dist_model_name()
+ if not pipeline_parallel:
+ common_path = os.path.join(checkpoints_path, directory,
+ f'mp_{model_name}_rank_{tensor_rank:02d}')
+ else:
+ common_path = os.path.join(checkpoints_path, directory,
+ f'mp_{model_name}_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
+
+ return os.path.join(common_path, "model_optim_rng.pt")
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/communication/__init__.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/communication/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae82f5bfb5a1719a56b44384443ac964faa78b9a
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/communication/__init__.py
@@ -0,0 +1,2 @@
+from .dist_communication import send_recv_tensor_list, generate_send_recv_mask
+from .dist_ranks_match import get_dst_ranks
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/communication/dist_communication.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/communication/dist_communication.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c4833195448e4394dc727fd86747813792f5aea
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/communication/dist_communication.py
@@ -0,0 +1,230 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+from typing import Optional, Sequence, Tuple, List, Dict
+import torch
+import torch_npu
+from ..parallel_state import _is_pipeline_first_stage, _is_pipeline_last_stage, get_global_pipeline_parallel_rank
+from ..config.dist_train_config import get_dist_model_index, get_rank_number_to_model_index, get_dist_model_config
+
+TENSOR_SYNC_TOOL: "TensorSyncTool" = None
+
+
+def init_tensor_sync_tool():
+ global TENSOR_SYNC_TOOL
+ if TENSOR_SYNC_TOOL is None:
+ TENSOR_SYNC_TOOL = TensorSyncTool()
+ return TENSOR_SYNC_TOOL
+
+
+class TensorSyncTool:
+ def __init__(self):
+ dtypes = []
+ for name in dir(torch):
+ attr = getattr(torch, name)
+ if isinstance(attr, torch.dtype) and attr not in torch_npu.unsupported_dtype:
+ dtypes.append(attr)
+ # Sorting enables different machines to obtain dtypes in the same sequence.
+ dtypes = sorted(set(dtypes), key=lambda x: str(x))
+
+ self.type_to_int = {None: -1}
+ self.type_to_int.update({dtype: i for i, dtype in enumerate(dtypes)})
+ self.int_to_type = {v: k for k, v in self.type_to_int.items()}
+ # fixed_header_len (10) = dtype (1) + req_grads (1) + len_shape (1) + shape (x) + pad(10 - 3 - x)
+ # Thus, the maximum dimension of tensors that can be supported here is 7.
+ self.fixed_header_len = 10
+
+ def encode_tensor_header(self, tensor: torch.Tensor):
+ """
+ | int32 | int32 | int32 | int32 | int32 |
+ | type | req_grads | len(shape) | shape | pad |
+ """
+ header = [0] * self.fixed_header_len
+
+ header[0] = self.type_to_int.get(tensor.dtype, -1) if tensor is not None else -1
+ if header[0] not in self.type_to_int.values():
+ if header[0] == -1: # `-1` matches `None`
+ return header
+ raise RuntimeError(f"The tensor dtype is not supported or recorded on this device: {tensor.dtype}")
+ header[1] = int(tensor.requires_grad)
+ header[2] = len(tensor.shape)
+ if self.fixed_header_len - 3 < len(tensor.shape): # `3` equals the len of [dtype, req_grads, len_shape]
+ raise ValueError('`len(tensor.shape)` is too long to be stored in the remaining space of the header.')
+ header[3:] = tensor.shape
+
+ device = torch.npu.current_device()
+ index = list(range(len(header)))
+ index = torch.tensor(index, dtype=torch.int32, device=device)
+ header = torch.tensor(header, dtype=torch.int32, device=device)
+ header_tensor = torch.zeros(TENSOR_SYNC_TOOL.fixed_header_len, dtype=torch.int32, device=device)
+ header_tensor.scatter_(0, index, header)
+ return header_tensor
+
+ def decode_tensor_header(self, header_tensor: torch.Tensor):
+ dtype = self.int_to_type.get(int(header_tensor[0]), None)
+ if dtype is None:
+ return dtype, None, None
+ requires_grad = bool(header_tensor[1])
+ shape_len = header_tensor[2]
+ shape = header_tensor.tolist()[3:3 + shape_len]
+ return dtype, shape, requires_grad
+
+
+def send_recv(tensor: Optional[torch.Tensor], is_recv: bool, ranks: Sequence) -> Optional[Sequence[torch.Tensor]]:
+ """
+ force_send is used for text_only backward situations.pre_subworld skips backward if recv None tensor.
+ """
+ if isinstance(tensor, Sequence):
+ tensor = tensor[0]
+
+ recv_tensor = None
+ # To prevent deadlocks caused by different pipeline stages receiving tensor simultaneously.
+ if not get_global_pipeline_parallel_rank() % 2:
+ if tensor is not None:
+ _send_tensor(tensor, ranks)
+ if is_recv:
+ recv_tensor = _recv_tensor(ranks)
+ else:
+ if is_recv:
+ recv_tensor = _recv_tensor(ranks)
+ if tensor is not None:
+ _send_tensor(tensor, ranks)
+
+ if is_recv and not isinstance(recv_tensor, list):
+ recv_tensor = [recv_tensor]
+
+ return recv_tensor
+
+
+def send_recv_tensor_list(
+ tensor_list: Optional[Sequence[torch.Tensor]],
+ is_recv: bool,
+ dst_ranks: Sequence[int],
+) -> Optional[Sequence[Sequence[torch.Tensor]]]:
+ if tensor_list is None:
+ if not is_recv:
+ raise ValueError('`tensor_list` can be set to `None` only on the receive side.')
+ elif isinstance(tensor_list, Sequence) and len(tensor_list) > 0 and isinstance(tensor_list[0], Sequence):
+ tensor_list = tensor_list[0]
+ else:
+ if not isinstance(tensor_list, Sequence):
+ raise TypeError(f'`tensor_list` is an unsupported type: {type(tensor_list)}')
+ if not isinstance(tensor_list[0], torch.Tensor):
+ raise TypeError(f'item of `tensor_list` is an unsupported type: {type(tensor_list[0])}')
+
+ tensor_list_ret = None
+ # To prevent deadlocks caused by different pipeline stages receiving tensor simultaneously.
+ if not get_global_pipeline_parallel_rank() % 2:
+ if tensor_list is not None:
+ send_tensor_list(tensor_list, dst_ranks)
+ if is_recv:
+ tensor_list_ret = recv_tensor_list(dst_ranks)
+ else:
+ if is_recv:
+ tensor_list_ret = recv_tensor_list(dst_ranks)
+ if tensor_list is not None:
+ send_tensor_list(tensor_list, dst_ranks)
+
+ return tensor_list_ret
+
+
+def recv_tensor_list(src_ranks: Sequence[int]) -> Optional[Sequence[Sequence[torch.Tensor]]]:
+ tensor_list_len = []
+ recv_tensor = torch.tensor([0], device=torch.npu.current_device())
+ for rank in src_ranks:
+ torch.distributed.recv(recv_tensor, rank)
+ tensor_list_len.append(recv_tensor.item())
+
+ if not all(tensor_list_len[0] == len_ for len_ in tensor_list_len[1:]):
+ raise ValueError(f'Tensor sequences of different lengths cannot be received from different cards.')
+ tensor_list_ret = [_recv_tensor(src_ranks) for _ in range(tensor_list_len[0])]
+
+ return tensor_list_ret
+
+
+def send_tensor_list(tensor_list: Optional[Sequence[torch.Tensor]], dst_ranks: Sequence[int]) -> None:
+ tensor_list_len = len(tensor_list)
+ if tensor_list_len == 0:
+ return
+ send_tensor = torch.tensor([tensor_list_len], device=torch.npu.current_device())
+ for rank in dst_ranks:
+ torch.distributed.send(send_tensor, rank)
+ for i in range(tensor_list_len):
+ _send_tensor(tensor_list[i], dst_ranks)
+
+
+def _send_header(tensor: torch.Tensor, dst: int) -> None:
+ header_tensor = TENSOR_SYNC_TOOL.encode_tensor_header(tensor)
+ torch.distributed.send(header_tensor, dst)
+
+
+def _send_tensor(tensor: torch.tensor, dst_ranks: Sequence) -> None:
+ if tensor is None:
+ return
+ for dst in dst_ranks:
+ _send_header(tensor, dst)
+ torch.distributed.send(tensor=tensor, dst=dst)
+
+
+def _recv_header(src: int) -> Tuple[Optional[torch.dtype], Optional[List[int]], Optional[bool]]:
+ device = torch.npu.current_device()
+ header_tensor = torch.zeros(TENSOR_SYNC_TOOL.fixed_header_len, dtype=torch.int32, device=device)
+ torch.distributed.recv(header_tensor, src)
+ header = TENSOR_SYNC_TOOL.decode_tensor_header(header_tensor)
+ return header
+
+
+def _recv_tensor(dst_ranks: Sequence) -> Optional[Sequence[torch.Tensor]]:
+ """Asynchronously receiving tensors
+
+ first receive the shape and dtype, use these to initialize an empty tensor,
+ then receive the tensor data, and finally return the tensor.
+ """
+ recv_tensors = []
+ for rank in dst_ranks:
+ # recv header
+ dtype, shape, requires_grad = _recv_header(rank)
+ device = torch.npu.current_device()
+ if dtype is None:
+ print('[WARNING] Get dtype=None from received header.')
+ return None
+ # recv tensor
+ tensor_recv_prev = torch.empty(tuple(shape), dtype=dtype, device=device, requires_grad=requires_grad)
+ torch.distributed.recv(tensor=tensor_recv_prev, src=rank)
+
+ recv_tensors.append(tensor_recv_prev)
+ return recv_tensors
+
+
+def generate_send_recv_mask(rank: int = None) -> Dict[str, bool]:
+ model_index = get_dist_model_index(rank)
+ rank_number_to_model_index = get_rank_number_to_model_index()
+ if model_index not in rank_number_to_model_index:
+ raise RuntimeError(f"model_index ({model_index}) not in _RANK_NUMBER_TO_MODEL_INDEX")
+
+ result = {
+ 'send_forward': False,
+ 'send_backward': False,
+ 'recv_forward': False,
+ 'recv_backward': False
+ }
+ if _is_pipeline_first_stage(is_global=False):
+ for i, index in enumerate(rank_number_to_model_index):
+ if index < model_index:
+ result['recv_forward'] = True
+ if (not get_dist_model_config(rank=i).forward_only) \
+ and (not get_dist_model_config(rank=rank).forward_only):
+ result['send_backward'] = True
+ break
+
+ if _is_pipeline_last_stage(is_global=False):
+ for i, index in enumerate(rank_number_to_model_index):
+ if index > model_index:
+ result['send_forward'] = True
+ if (not get_dist_model_config(rank=i).forward_only) \
+ and (not get_dist_model_config(rank=rank).forward_only):
+ result['recv_backward'] = True
+ break
+
+ return result
+
+
+init_tensor_sync_tool()
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/communication/dist_ranks_match.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/communication/dist_ranks_match.py
new file mode 100644
index 0000000000000000000000000000000000000000..a048705b2c4408da607c4a48288abfc74737fc42
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/communication/dist_ranks_match.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+from itertools import accumulate
+import torch.distributed as dist
+from ..config import dist_train_config as config
+
+
+"""
+key:int: cur rank
+value:list: dst ranks
+"""
+_MODEL_COMM_RANKS = {}
+
+
+def generate_model_comm_ranks(pp_ranks_prev: [[]], tp_ranks_prev: [[]], pp_ranks_last: [[]], tp_ranks_last: [[]]):
+ global _MODEL_COMM_RANKS
+ if _MODEL_COMM_RANKS and config.get_all_config_size() != 2:
+ # If the size is 2, this method is expected to be invoked only once.
+ raise RuntimeError(f'Get config size ({config.get_all_config_size()}) is not equal to 2, '
+ f'and _MODEL_COMM_RANKS is initialized.')
+ tp_ranks_prev_ = []
+ tp_ranks_last_ = []
+
+ # Take the ranks of the last stage of 'prev' and the first stage of 'last'.
+ for pp_ranks in pp_ranks_prev:
+ for tp_ranks in tp_ranks_prev:
+ if pp_ranks[-1] in tp_ranks and tp_ranks not in tp_ranks_prev_:
+ tp_ranks_prev_.append(tp_ranks)
+
+ for pp_ranks in pp_ranks_last:
+ for tp_ranks in tp_ranks_last:
+ if pp_ranks[0] in tp_ranks and tp_ranks not in tp_ranks_last_:
+ tp_ranks_last_.append(tp_ranks)
+
+ if not (len(tp_ranks_prev_) and len(tp_ranks_last_)):
+ raise ValueError("tp ranks must not empty")
+
+ # Place the TP units with fewer counts at the front and those with more at the back,
+ # so that when generating the forward correspondence, it traverses through fewer iterations.
+ if len(tp_ranks_prev_) > len(tp_ranks_last_):
+ tp_ranks_prev_, tp_ranks_last_ = tp_ranks_last_, tp_ranks_prev_
+
+ # Generate correspondence.
+ lens_last = get_size_list(len(tp_ranks_last_), len(tp_ranks_prev_), 1)
+ index_for_last = [0] + list(accumulate(lens_last))
+ ranks_dict_prev = {}
+ for i, prev_ranks in enumerate(tp_ranks_prev_):
+ last_ranks = [rank for lst in tp_ranks_last_[index_for_last[i]: index_for_last[i + 1]] for rank in lst]
+ num_take_last = lens_last[i] # The actual number of data sets taken from tp_ranks_last_ in this round.
+ num_unit_last = len(tp_ranks_last_[0])
+
+ # Place the elements with fewer counts at the front and those with more at the back,
+ # to facilitate the execution of the general logic.
+ if len(last_ranks) < len(prev_ranks):
+ prev_ranks, last_ranks = last_ranks, prev_ranks
+ num_take_last = 1 # Only one sublist will be extracted from tp_ranks_Prev_ in each round.
+ num_unit_last = len(tp_ranks_prev_[0])
+
+ # Establish the corresponding relationships.
+ per_ranks = get_size_list(len(last_ranks), len(prev_ranks), num_unit_last)
+ index_for_prev = [0] + list(accumulate(per_ranks))
+ for j, rank_ in enumerate(prev_ranks):
+ ranks_dict_prev[rank_] = last_ranks[index_for_prev[j]: index_for_prev[j + 1]]
+
+ print(f"rank={dist.get_rank()}, num_take_last: {num_take_last}, num_unit_last: {num_unit_last}, "
+ f"prev: {prev_ranks}, last: {last_ranks}")
+
+ # Conversely, establish the corresponding relationships again;
+ # currently, this is only compatible with scenarios where the model is divided into two parts.
+ ranks_dict_last = {last: [prev] for prev in ranks_dict_prev for last in ranks_dict_prev.get(prev, None)}
+ if None in ranks_dict_last.keys():
+ raise KeyError('Found unexpected keys in `ranks_dict_last`')
+
+ # Update data
+ keys = ranks_dict_prev.keys() | ranks_dict_last.keys()
+ for k in keys:
+ _MODEL_COMM_RANKS[k] = _MODEL_COMM_RANKS.get(k, []) + ranks_dict_prev.get(k, []) + ranks_dict_last.get(k, [])
+
+
+def get_dst_ranks(rank=None):
+ global _MODEL_COMM_RANKS
+ if rank is None:
+ rank = dist.get_rank()
+
+ return _MODEL_COMM_RANKS.get(rank, None)
+
+
+def clear_model_comm_ranks():
+ global _MODEL_COMM_RANKS
+ _MODEL_COMM_RANKS = {}
+
+
+def get_size_list(sum_, len_, base_):
+ """
+ sum, len, base:
+ 12, 2, 7 => 12, 2, 6 => [6, 6] base is too large, let the base cycle subtract 1 first
+ 15, 2, 5 => [5, 5] => [10, 5] base is appropriate, try to allocate with multiple of base num
+ 12, 2, 5 => [5, 5] => [6, 6] base is too small, try to allocate as much as possible
+ """
+ if not all(isinstance(num, int) for num in (sum_, len_, base_)):
+ raise ValueError("sum_, base_ and len_ must be integers.")
+ if base_ <= 0 or len_ <= 0:
+ raise ValueError("base_ and len_ cannot be zero.")
+ while sum_ // base_ < len_:
+ base_ -= 1
+ list_base_ = sum_ // len_ // base_ * base_
+ list_ = [list_base_ for _ in range(len_)]
+ rem_ = sum_ - len_ * list_base_
+ base_ = base_ if rem_ % base_ == 0 else 1
+ index_ = 0
+
+ while rem_ > 0:
+ list_[index_ % len_] += base_
+ rem_ -= base_
+ index_ += 1
+
+ return list_
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/config/__init__.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..51c120956432a729cdaaa1637536c787f962d6c0
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/config/__init__.py
@@ -0,0 +1,5 @@
+from .dist_train_config import (
+ get_all_config, get_dist_model_config, get_dist_model_index, get_rank_number_to_model_index, get_all_config_size,
+ get_dist_model_name, get_rank_number_to_model_name, get_dist_global_model_index,
+ merge_dist_train_args, is_forward_only_model
+)
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/config/dist_train_config.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/config/dist_train_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab3b8c075b8d68be188c3dd147ca9123a46b9282
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/config/dist_train_config.py
@@ -0,0 +1,322 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import os
+import json
+import torch.distributed
+
+_ALL_CONFIG = {} # {name: DetachedConfig()}
+# model_idx: 0 1
+# vae rank0 ↘
+# vit rank2,3
+# t5 rank1 ↗
+_RANK_NUMBER_TO_MODEL_INDEX = [] # rank index (list index) -- model index -- [0, 0, 1, 1]
+_RANK_NUMBER_TO_MODEL_NAME = [] # rank index (list index) -- model name -- ['vae', 't5', 'vit', 'vit']
+_NUMBER_OF_MODELS = 0
+_USE_MULTIPARAM_SEND_RECV = False
+_ALL_DIST_MODEL_INDEX = []
+_ALL_DIST_MODEL_NAME = []
+_ALL_DIST_MODEL_CONFIG = []
+_SUPPORT_MODEL_NAME = {"internvl2": ["vit", "gpt"], "opensoraplan1.3": ["vae", "dit"]}
+
+
+class ContextKey:
+ DIST_CONFIG = 'dist_config'
+ # global config keys
+ MODEL_CONFIG = 'model_config'
+ USE_MULTIPARAM_SEND_RECV = 'use_multiparam_send_recv'
+ MODEL_NAME = 'model_name'
+ # model config keys
+ NAME = 'name'
+ MODEL_INDEX = 'model_index'
+ WORLD_SIZE = 'world_size'
+ TENSOR_MODEL_PARALLEL_SIZE = 'tensor_model_parallel_size'
+ PIPELINE_MODEL_PARALLEL_SIZE = 'pipeline_model_parallel_size'
+ CONTEXT_PARALLEL_SIZE = 'context_parallel_size'
+ MAIN_DP = 'main_dp'
+ FORWARD_ONLY = 'forward_only'
+
+
+CK = ContextKey()
+
+
+class ModelConfig:
+ def __init__(self, config_dict: dict, start_rank):
+ self._keys = {CK.NAME, CK.MODEL_INDEX, CK.WORLD_SIZE, CK.TENSOR_MODEL_PARALLEL_SIZE,
+ CK.PIPELINE_MODEL_PARALLEL_SIZE, CK.CONTEXT_PARALLEL_SIZE, CK.FORWARD_ONLY, CK.MAIN_DP}
+ self._base_validate(config_dict)
+
+ setattr(self, CK.NAME, None)
+ setattr(self, CK.MODEL_INDEX, None)
+ setattr(self, CK.WORLD_SIZE, None)
+ setattr(self, CK.TENSOR_MODEL_PARALLEL_SIZE, 1)
+ setattr(self, CK.PIPELINE_MODEL_PARALLEL_SIZE, 1)
+ setattr(self, CK.CONTEXT_PARALLEL_SIZE, 1)
+ setattr(self, CK.FORWARD_ONLY, False)
+ setattr(self, CK.MAIN_DP, False)
+ self._set_single_model_config(config_dict)
+
+ # Additional generated attributes.
+ self.start_rank = start_rank
+ self.ranks = list(range(self.start_rank, self.start_rank + getattr(self, CK.WORLD_SIZE)))
+
+ def __getitem__(self, key):
+ return getattr(self, key)
+
+ def __setitem__(self, key, value):
+ setattr(self, key, value)
+
+ def __delitem__(self, key):
+ delattr(self, key)
+
+ def __repr__(self):
+ repr_str = '('
+ for k in self._keys:
+ repr_str += f'{k}: {getattr(self, k)}, '
+ repr_str = repr_str.rstrip(', ') + ')'
+ return repr_str
+
+ def _set_single_model_config(self, config_dict):
+ for k, v in config_dict.items():
+ setattr(self, k, v)
+ self._keys.add(k)
+
+ def _base_validate(self, ori_cfg):
+ # startswith
+ if any(key.startswith('_') for key in ori_cfg.keys()):
+ raise ValueError('The configuration item field cannot start with an underscore (_) '
+ 'to prevent unexpected overwriting.')
+ # check valid key
+ valid_keys = list(self._keys)
+ invalid_keys = [key for key in ori_cfg if key not in valid_keys]
+ if invalid_keys:
+ raise KeyError(f"The following keys in DistTrain config are not valid: {invalid_keys}")
+ # world_size
+ world_size = ori_cfg.get(CK.WORLD_SIZE)
+ if not (isinstance(world_size, int) and world_size > 0):
+ raise ValueError(f'`{CK.WORLD_SIZE}` ({world_size}) should be greater than or equal to 0')
+ # parallel
+ tp_size = ori_cfg.get(CK.TENSOR_MODEL_PARALLEL_SIZE, 1)
+ pp_size = ori_cfg.get(CK.PIPELINE_MODEL_PARALLEL_SIZE, 1)
+ cp_size = ori_cfg.get(CK.CONTEXT_PARALLEL_SIZE, 1)
+ if not (isinstance(tp_size, int) and tp_size > 0):
+ raise ValueError(f'`{CK.TENSOR_MODEL_PARALLEL_SIZE}` ({tp_size}) should be greater than 0')
+ if not (isinstance(pp_size, int) and pp_size > 0):
+ raise ValueError(f'`{CK.PIPELINE_MODEL_PARALLEL_SIZE}` ({pp_size}) should be greater than 0')
+ if not (isinstance(cp_size, int) and cp_size > 0):
+ raise ValueError(f'`{CK.CONTEXT_PARALLEL_SIZE}` ({cp_size}) should be greater than 0')
+ if world_size % (tp_size * pp_size * cp_size):
+ raise ValueError((f'`{CK.WORLD_SIZE}` ({world_size}) should be divisible by the product of '
+ f'`{CK.TENSOR_MODEL_PARALLEL_SIZE}` ({tp_size}), `{CK.PIPELINE_MODEL_PARALLEL_SIZE}` '
+ f'({pp_size}), and `{CK.CONTEXT_PARALLEL_SIZE}` ({cp_size})'))
+ if CK.FORWARD_ONLY in ori_cfg and not isinstance(ori_cfg.get(CK.FORWARD_ONLY), bool):
+ raise TypeError(f"The `{CK.FORWARD_ONLY}` value type must be bool.")
+
+
+def validate_configs_world_size(args):
+ world_size = 0
+ for cfg in _ALL_CONFIG.values():
+ world_size += cfg[CK.WORLD_SIZE]
+ if world_size != args.world_size:
+ raise ValueError('The sum of `world_size` in config must be equal to the actual `world_size`.')
+
+
+def get_all_config():
+ return _ALL_CONFIG
+
+
+def get_all_config_size():
+ return len(_ALL_CONFIG)
+
+
+def get_rank_number_to_model_index():
+ return _RANK_NUMBER_TO_MODEL_INDEX
+
+
+def get_rank_number_to_model_name():
+ return _RANK_NUMBER_TO_MODEL_NAME
+
+
+def get_dist_model_name(rank: int = None, global_index: int = None) -> str:
+ if global_index is not None:
+ if not (0 - _NUMBER_OF_MODELS <= global_index < _NUMBER_OF_MODELS):
+ raise ValueError(f'`global_index` must between `0 - _NUMBER_OF_MODELS` ({0 - _NUMBER_OF_MODELS}) '
+ f'and `_NUMBER_OF_MODELS` ({_NUMBER_OF_MODELS})')
+ key = list(_ALL_CONFIG.keys())[global_index]
+ index_name = _ALL_CONFIG[key][CK.NAME]
+ if rank is None:
+ return index_name
+ else:
+ if not (0 <= rank < len(_RANK_NUMBER_TO_MODEL_NAME)):
+ raise IndexError(f'{rank=} should between 0 and {len(_RANK_NUMBER_TO_MODEL_NAME)=}, '
+ f'check the config file and launch params')
+ name = _RANK_NUMBER_TO_MODEL_NAME[rank]
+ if index_name != name:
+ raise RuntimeError(f'{rank=}, `{index_name}` should equals `{name}`')
+ return name
+
+ if rank is None:
+ rank = torch.distributed.get_rank()
+ if not (0 <= rank < len(_RANK_NUMBER_TO_MODEL_NAME)):
+ raise IndexError(f'{rank=} should between 0 and {len(_RANK_NUMBER_TO_MODEL_NAME)=}, '
+ f'check the config file and launch params')
+
+ name = _RANK_NUMBER_TO_MODEL_NAME[rank]
+ return name
+
+
+def get_dist_model_config(name: str = None, rank: int = None, global_index: int = None):
+ if global_index is not None:
+ if not (0 - _NUMBER_OF_MODELS <= global_index < _NUMBER_OF_MODELS):
+ raise ValueError(f'`global_index` must between `0 - _NUMBER_OF_MODELS` ({0 - _NUMBER_OF_MODELS}) '
+ f'and `_NUMBER_OF_MODELS` ({_NUMBER_OF_MODELS})')
+ if name is not None:
+ if rank is not None or global_index is not None:
+ if name != get_dist_model_name(rank, global_index):
+ raise RuntimeError(f'{rank=}, `{name}` should equals `{get_dist_model_name(rank, global_index)}`')
+ else:
+ name = get_dist_model_name(rank, global_index)
+ if name not in _ALL_CONFIG.keys():
+ raise KeyError(f'{name=} not in {_ALL_CONFIG.keys()=}')
+ return _ALL_CONFIG[name]
+
+
+def get_dist_model_index(rank: int = None) -> int:
+ if rank is None:
+ rank = torch.distributed.get_rank()
+ if not (0 - len(_RANK_NUMBER_TO_MODEL_INDEX) <= rank < len(_RANK_NUMBER_TO_MODEL_INDEX)):
+ raise IndexError(f'{0 - len(_RANK_NUMBER_TO_MODEL_INDEX)=} <= {rank=} < {len(_RANK_NUMBER_TO_MODEL_INDEX)=}, '
+ f'check the config file and launch params')
+ return _RANK_NUMBER_TO_MODEL_INDEX[rank]
+
+
+def get_dist_global_model_index(rank: int = None) -> int:
+ name = get_dist_model_name(rank)
+ keys = _ALL_CONFIG.keys()
+ return list(keys).index(name)
+
+
+def is_use_multiparam_send_recv():
+ return _USE_MULTIPARAM_SEND_RECV
+
+
+def _read_json(json_path):
+ try:
+ with open(json_path, mode="r") as f:
+ json_file = f.read()
+ configs_list = json.loads(json_file)
+ return configs_list
+ except FileNotFoundError as e:
+ raise FileNotFoundError(f"The file {json_path} does not exist.") from e
+ except json.JSONDecodeError as e:
+ raise ValueError(f"The file {json_path} is not a valid JSON file.") from e
+ except Exception as e:
+ raise RuntimeError(f"An unexpected error occurred: {e}") from e
+
+
+def _check_config(config_dict):
+ if CK.MODEL_CONFIG not in config_dict.keys():
+ raise KeyError(f"The `{CK.MODEL_CONFIG}` key does not exist in DistTrain config.")
+ if CK.USE_MULTIPARAM_SEND_RECV in config_dict.keys() and not isinstance(config_dict[CK.USE_MULTIPARAM_SEND_RECV], bool):
+ raise TypeError(f"The `{CK.USE_MULTIPARAM_SEND_RECV}` value type must be bool.")
+ if CK.MODEL_NAME not in config_dict.keys():
+ raise KeyError(f"The `{CK.MODEL_NAME}` key does not exist in DistTrain config.")
+ if not isinstance(config_dict[CK.MODEL_NAME], str):
+ raise TypeError(f"The `{CK.MODEL_NAME}` value type must be string.")
+ global _SUPPORT_MODEL_NAME
+ if config_dict[CK.MODEL_NAME] not in _SUPPORT_MODEL_NAME:
+ raise ValueError(f"The `{CK.MODEL_NAME}` current not support.")
+ valid_keys = [CK.MODEL_CONFIG, CK.USE_MULTIPARAM_SEND_RECV, CK.MODEL_NAME]
+ invalid_keys = [key for key in config_dict.keys() if key not in valid_keys]
+ if invalid_keys:
+ raise KeyError(f"Get unexpected keywords: {invalid_keys}")
+ if not isinstance(config_dict[CK.MODEL_CONFIG], list):
+ raise TypeError(f"The `{CK.MODEL_CONFIG}` type must be list.")
+ if not config_dict[CK.MODEL_CONFIG]:
+ raise ValueError(f"The `{CK.MODEL_CONFIG}` must not be empty.")
+ global _ALL_DIST_MODEL_INDEX, _ALL_DIST_MODEL_NAME, _ALL_DIST_MODEL_CONFIG
+ _ALL_DIST_MODEL_INDEX = [config.get(CK.MODEL_INDEX) for config in config_dict[CK.MODEL_CONFIG]]
+ _ALL_DIST_MODEL_NAME = [config.get(CK.NAME) for config in config_dict[CK.MODEL_CONFIG]]
+ _ALL_DIST_MODEL_CONFIG = config_dict[CK.MODEL_CONFIG]
+ if not all(key in config.keys() for config in _ALL_DIST_MODEL_CONFIG for key in [CK.NAME, CK.WORLD_SIZE, CK.MODEL_INDEX]):
+ raise ValueError(f"At least three items must be configured: `{CK.NAME}`, `{CK.WORLD_SIZE}`, and `{CK.MODEL_INDEX}`.")
+ if not all(isinstance(name, str) for name in _ALL_DIST_MODEL_NAME):
+ raise TypeError(f"The `{CK.NAME}` value type must be str.")
+ if len(_ALL_DIST_MODEL_NAME) != len(set(_ALL_DIST_MODEL_NAME)):
+ raise ValueError(f"`{CK.NAME}` is duplicate in DistTrain config.")
+ if not all(name.isidentifier() for name in _ALL_DIST_MODEL_NAME):
+ raise ValueError(f"`{CK.NAME}` is not a valid string.")
+ valid_names = _SUPPORT_MODEL_NAME.get(config_dict[CK.MODEL_NAME])
+ if len(_ALL_DIST_MODEL_NAME) != len(valid_names):
+ raise ValueError(f"`{config_dict[CK.MODEL_NAME]}` model current only support {valid_names}.")
+ if not all(isinstance(index, int) for index in _ALL_DIST_MODEL_INDEX):
+ raise TypeError(f"The `{CK.MODEL_INDEX}` value type must be int.")
+ _ALL_DIST_MODEL_INDEX.sort()
+ if not all(_ALL_DIST_MODEL_INDEX[i] - _ALL_DIST_MODEL_INDEX[i - 1] == 1 for i in range(1, len(_ALL_DIST_MODEL_INDEX))):
+ raise ValueError(f"`{CK.MODEL_INDEX}` must be continuous.")
+
+ # 把model_index升序的name保存
+ combined = list(zip(_ALL_DIST_MODEL_INDEX, _ALL_DIST_MODEL_CONFIG))
+ combined.sort(key=lambda x: x[0])
+ _, _ALL_DIST_MODEL_CONFIG = list(zip(*combined))
+ if _ALL_DIST_MODEL_CONFIG[0][CK.MODEL_INDEX] < 0:
+ raise ValueError(f"`{CK.MODEL_INDEX}` must start from 0.")
+ if not all(name == valid for name, valid in zip(_ALL_DIST_MODEL_NAME, valid_names)):
+ raise ValueError(f"`{CK.NAME}` sequence is incorrect, {config_dict[CK.MODEL_NAME]} "
+ f"model name list strictly follow the sequence [{valid_names}].")
+ if not all(
+ isinstance(config.get(CK.MAIN_DP), bool)
+ for config in _ALL_DIST_MODEL_CONFIG
+ if CK.MAIN_DP in config
+ ):
+ raise TypeError(f"The `{CK.MAIN_DP}` value type must be bool.")
+ if sum(1 for config in _ALL_DIST_MODEL_CONFIG if config.get(CK.MAIN_DP, False)) > 1:
+ raise ValueError(f"Only one `{CK.MAIN_DP}` can be true.")
+
+
+def _set_config(config_dict):
+ _check_config(config_dict)
+ global _NUMBER_OF_MODELS, _ALL_DIST_MODEL_CONFIG
+ _NUMBER_OF_MODELS = len(_ALL_DIST_MODEL_CONFIG)
+ config_dict[CK.MODEL_CONFIG] = _ALL_DIST_MODEL_CONFIG
+ # Save the config in ascending order by name.
+ for k, v in config_dict.items():
+ if k == CK.USE_MULTIPARAM_SEND_RECV:
+ global _USE_MULTIPARAM_SEND_RECV
+ _USE_MULTIPARAM_SEND_RECV = v
+ elif k == CK.MODEL_CONFIG:
+ global _ALL_CONFIG, _RANK_NUMBER_TO_MODEL_NAME, _RANK_NUMBER_TO_MODEL_INDEX
+ for model_config in v: # v == [{}, {}, {}, ...]
+ _ALL_CONFIG[model_config.get(CK.NAME)] = ModelConfig(model_config, len(_RANK_NUMBER_TO_MODEL_INDEX))
+ _RANK_NUMBER_TO_MODEL_INDEX.extend([model_config.get(CK.MODEL_INDEX)] * model_config.get(CK.WORLD_SIZE))
+ _RANK_NUMBER_TO_MODEL_NAME.extend([model_config.get(CK.NAME)] * model_config.get(CK.WORLD_SIZE))
+ print(f"{_ALL_CONFIG=}\n{_RANK_NUMBER_TO_MODEL_NAME=}\n{_RANK_NUMBER_TO_MODEL_INDEX=}")
+
+
+def _clear_dist_config():
+ global _ALL_CONFIG, _RANK_NUMBER_TO_MODEL_NAME, _RANK_NUMBER_TO_MODEL_INDEX, _NUMBER_OF_MODELS, \
+ _USE_MULTIPARAM_SEND_RECV, _ALL_DIST_MODEL_INDEX, _ALL_DIST_MODEL_NAME, _ALL_DIST_MODEL_CONFIG
+ _ALL_CONFIG = {}
+ _RANK_NUMBER_TO_MODEL_NAME = []
+ _RANK_NUMBER_TO_MODEL_INDEX = []
+ _NUMBER_OF_MODELS = 0
+ _USE_MULTIPARAM_SEND_RECV = False
+ _ALL_DIST_MODEL_INDEX = []
+ _ALL_DIST_MODEL_NAME = []
+ _ALL_DIST_MODEL_CONFIG = []
+
+
+def merge_dist_train_args(path):
+ real_path = os.path.realpath(path)
+ if real_path.endswith(".json"): # MindSpeed-MM use json config
+ config = _read_json(real_path)
+ if isinstance(config, dict):
+ config = config.get(CK.DIST_CONFIG, {})
+ else:
+ raise ValueError('Unexpected json file, not contain dist_config dict data.')
+ else:
+ raise TypeError("Unexpected file type.")
+ _clear_dist_config()
+ _set_config(config)
+
+
+def is_forward_only_model(name: str = None, rank: int = None, global_index: int = None):
+ return get_dist_model_config(name, rank, global_index)[CK.FORWARD_ONLY]
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/inner_data_parallel/__init__.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/inner_data_parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/inner_data_parallel/inner_data_parallel.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/inner_data_parallel/inner_data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..972f42e92edf59a0919f3396b80e2b7a32852fc3
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/inner_data_parallel/inner_data_parallel.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+from .. import parallel_state as dist_ps
+
+
+@dist_ps.subwrold_decorator
+def get_inner_data_parallel_group():
+ """Get the inner data parallel group the caller rank belongs to."""
+ if dist_ps._INNER_DATA_PARALLEL_GROUP is None:
+ raise RuntimeError('inner data parallel group is not initialized')
+ return dist_ps._INNER_DATA_PARALLEL_GROUP
+
+
+@dist_ps.subwrold_decorator
+def get_inner_data_parallel_world_size():
+ """Return world size for the inner data parallel group."""
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ return torch.distributed.get_world_size(
+ group=get_inner_data_parallel_group()
+ )
+ else:
+ return 0
+
+
+@dist_ps.subwrold_decorator
+def get_inner_data_parallel_rank():
+ """Return my rank for the inner data parallel group."""
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ return torch.distributed.get_rank(
+ group=get_inner_data_parallel_group()
+ )
+ else:
+ return 0
+
+
+def get_inner_data_parallel_src_rank():
+ """Calculate the global rank corresponding to the first local rank in the inner data parallel group."""
+ if dist_ps._CUR_SUB_WORLD is None:
+ return 0
+ global_rank = (torch.distributed.get_rank() - dist_ps._CUR_SUB_WORLD.start_rank)
+ local_world_size = get_inner_data_parallel_world_size()
+ return (global_rank // local_world_size) * local_world_size + dist_ps._CUR_SUB_WORLD.start_rank
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/inner_data_parallel/mappings.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/inner_data_parallel/mappings.py
new file mode 100644
index 0000000000000000000000000000000000000000..717bd77a037ee0b0c383cd2f72c1a65019901dba
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/inner_data_parallel/mappings.py
@@ -0,0 +1,83 @@
+# Copied from Megatron-LM: https://github.com/NVIDIA/Megatron-LM
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import torch
+from mindspeed.core.tensor_parallel.comm_utils import (
+ _split_along_first_dim,
+ sync_gather_along_first_dim,
+ sync_reduce_scatter_along_first_dim
+)
+from mindspeed.core.tensor_parallel.comm_group_api import CollectiveCommIntf
+from .inner_data_parallel import (
+ get_inner_data_parallel_group,
+ get_inner_data_parallel_world_size,
+ get_inner_data_parallel_rank,
+)
+
+
+def gather_from_inner_dp_region(input_, inner_dp_parallel_output_grad=True):
+ return _GatherFromInnerDataParallelRegion.apply(input_, inner_dp_parallel_output_grad)
+
+
+class _GatherFromInnerDataParallelRegion(torch.autograd.Function):
+ """Gather the input from sequence parallel region and concatinate."""
+
+ @staticmethod
+ def symbolic(graph, input_, inner_dp_parallel_output_grad=True):
+ return sync_gather_along_first_dim(input_, InnerDPCollectiveComm)
+
+ @staticmethod
+ def forward(ctx, input_, inner_dp_parallel_output_grad=True):
+ ctx.inner_dp_parallel_output_grad = inner_dp_parallel_output_grad
+ return sync_gather_along_first_dim(input_, InnerDPCollectiveComm)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ inner_dp_parallel_output_grad = ctx.inner_dp_parallel_output_grad
+
+ # If the computation graph after the gather operation is
+ # in the tensor parallel mode, output gradients need to reduce
+ # scattered and whereas if the computation is duplicated,
+ # output gradients need to be scattered.
+ if inner_dp_parallel_output_grad:
+ return sync_reduce_scatter_along_first_dim(grad_output, InnerDPCollectiveComm), None
+ else:
+ return _split_along_first_dim(grad_output, InnerDPCollectiveComm), None
+
+
+class InnerDPCollectiveComm(CollectiveCommIntf):
+ def __init__(self, name='inner-dp'):
+ super().__init__(name)
+
+ @classmethod
+ def get_comm_rank(cls):
+ return get_inner_data_parallel_rank()
+
+ @classmethod
+ def get_comm_group_world_size(cls):
+ return get_inner_data_parallel_world_size()
+
+ @classmethod
+ def get_comm_group(cls):
+ return get_inner_data_parallel_group()
+
+
+def split_data(data: torch.Tensor, padding_val: int = 0):
+ data_num = data.shape[0]
+ dp_size = get_inner_data_parallel_world_size()
+ data_pad_num = (data_num + dp_size - 1) // dp_size * dp_size
+ data_shape = tuple(data.shape[1:])
+
+ # Flatten the data associated with the keys
+ flatten_data = data.contiguous().view(-1, *data_shape).cuda()
+
+ if padding_val != 0:
+ flatten_data_pad = torch.full((data_pad_num, *data_shape), padding_val, device=torch.cuda.current_device())
+ else:
+ flatten_data_pad = torch.zeros((data_pad_num, *data_shape), device=torch.cuda.current_device())
+
+ flatten_data_pad[:data_num, ...] = flatten_data
+ split_data_list = list(torch.chunk(flatten_data_pad, dp_size, dim=0))
+ data = split_data_list[get_inner_data_parallel_rank()]
+
+ return data
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/inner_data_parallel/utils.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/inner_data_parallel/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b924b6a402a1cf01594e40a5d5e7d8a83d7d7afd
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/inner_data_parallel/utils.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+from megatron.core.parallel_state import get_data_parallel_world_size
+from ..config.dist_train_config import get_all_config
+from ..parallel_state import is_in_subworld
+
+
+def need_inner_data_parallel():
+ # only support for vit model
+ if not is_in_subworld("vit"):
+ return False
+ if not any(cfg.main_dp for cfg in get_all_config().values()):
+ return False
+ return get_data_parallel_world_size() // get_global_data_parallel_size() > 1
+
+
+def get_global_data_parallel_size():
+ """Return main data parallel size"""
+ all_config = get_all_config()
+ global_dp_size = 0
+ for config in all_config.values():
+ if config.main_dp:
+ global_dp_size = config.world_size // \
+ (config.tensor_model_parallel_size * config.pipeline_model_parallel_size * config.context_parallel_size)
+ if global_dp_size == 0:
+ raise AssertionError("No Main DP")
+ return global_dp_size
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/parallel_state.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/parallel_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..7457a7a77fcd42e8019ea43753b7152cc4169b9e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/parallel_state.py
@@ -0,0 +1,1260 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+
+"""Model and data parallel groups."""
+import os
+from datetime import timedelta
+from typing import Optional, Sequence
+from functools import wraps
+import torch
+from torch._C._distributed_c10d import ProcessGroup
+from megatron.core.parallel_state import RankGenerator
+from megatron.core.utils import GlobalMemoryBuffer
+from megatron.training import get_args
+from .config.dist_train_config import get_dist_model_config, get_all_config_size, get_all_config
+
+# Current subworld, adapts to the situation when different model shares one rank
+_CUR_SUB_WORLD = None
+ALL_SUB_WORLD = {}
+
+# Intra-layer model parallel group that the current rank belongs to.
+_TENSOR_MODEL_PARALLEL_GROUP = None
+# Inter-layer model parallel group that the current rank belongs to.
+_PIPELINE_MODEL_PARALLEL_GROUP = None
+# Model parallel group (both intra- and pipeline) that the current rank belongs to.
+_MODEL_PARALLEL_GROUP = None
+# Model parallel group (both intra-, pipeline, and expert) that the current rank belongs to.
+_MODEL_AND_EXPERT_PARALLEL_GROUP = None
+# Embedding group.
+_EMBEDDING_GROUP = None
+# Position embedding group.
+_POSITION_EMBEDDING_GROUP = None
+# Data parallel group that the current rank belongs to.
+_DATA_PARALLEL_GROUP = None
+_DATA_PARALLEL_GROUP_GLOO = None
+# tensor model parallel group and data parallel group combined
+# used for fp8 and moe training
+_TENSOR_AND_DATA_PARALLEL_GROUP = None
+# Expert parallel group that the current rank belongs to.
+_EXPERT_MODEL_PARALLEL_GROUP = None
+_TENSOR_AND_EXPERT_PARALLEL_GROUP = None
+_DATA_MODULO_EXPERT_PARALLEL_GROUP = None
+_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = None
+_DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = None
+_DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = None
+
+_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
+_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
+
+# These values enable us to change the mpu sizes on the fly.
+_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
+_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None
+_MPU_TENSOR_MODEL_PARALLEL_RANK = None
+_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
+_MPU_EXPERT_MODEL_PARALLEL_RANK = None
+
+# A list of ranks that have a copy of the embedding.
+_EMBEDDING_GLOBAL_RANKS = None
+
+# A list of ranks that have a copy of the position embedding.
+_POSITION_EMBEDDING_GLOBAL_RANKS = None
+
+# A list of global ranks for each pipeline group to ease calculation of the source
+# rank when broadcasting from the first or last pipeline stage.
+_PIPELINE_GLOBAL_RANKS = None
+
+# A list of global ranks for each data parallel group to ease calculation of the source
+# rank when broadcasting weights from src to all other data parallel ranks
+_DATA_PARALLEL_GLOBAL_RANKS = None
+
+# A list of global ranks for each tensor model parallel group to ease calculation of
+# the first local rank in the tensor model parallel group
+_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None
+
+# Context parallel group that the current rank belongs to
+_CONTEXT_PARALLEL_GROUP = None
+# A list of global ranks for each context parallel group to ease calculation of the
+# destination rank when exchanging KV/dKV between context parallel_ranks
+_CONTEXT_PARALLEL_GLOBAL_RANKS = None
+
+# Data parallel group information with context parallel combined.
+_DATA_PARALLEL_GROUP_WITH_CP = None
+_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
+_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None
+
+# combined parallel group of TP and CP
+_TENSOR_AND_CONTEXT_PARALLEL_GROUP = None
+
+# combined parallel group of TP, DP, and CP used for fp8
+_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
+
+# inner data parallel group
+_INNER_DATA_PARALLEL_GROUP = None
+# Memory buffers to avoid dynamic memory allocation
+_GLOBAL_MEMORY_BUFFER = None
+
+# MOE logging
+_MOE_LAYER_WISE_LOGGING_TRACKER = {}
+
+
+class DetachedSubWorld:
+ def __init__(self, name: str, start_rank, ranks: list):
+ self.name = name
+ self.ranks = ranks
+ self.start_rank = start_rank
+
+ # Intra-layer model parallel group that the current rank belongs to.
+ self.tensor_model_parallel_group = None
+ # Inter-layer model parallel group that the current rank belongs to.
+ self.pipeline_model_parallel_group = None
+ # Model parallel group (both intra- and pipeline) that the current rank belongs to.
+ self.model_parallel_group = None
+ # Model parallel group (both intra-, pipeline, and expert) that the current rank belongs to.
+ self.model_and_expert_parallel_group = None
+ # Embedding group.
+ self.embedding_group = None
+ # Position embedding group.
+ self.position_embedding_group = None
+ # Data parallel group that the current rank belongs to.
+ self.data_parallel_group = None
+ self.data_parallel_group_gloo = None
+ # tensor model parallel group and data parallel group combined
+ # used for fp8 and moe training
+ self.tensor_and_data_parallel_group = None
+ # Expert parallel group that the current rank belongs to.
+ self.expert_model_parallel_group = None
+ self.tensor_and_expert_parallel_group = None
+ self.data_modulo_expert_parallel_group = None
+ self.data_modulo_expert_parallel_group_gloo = None
+ self.data_modulo_expert_parallel_group_with_cp = None
+ self.data_modulo_expert_parallel_group_with_cp_gloo = None
+
+ self.virtual_pipeline_model_parallel_rank = None
+ self.virtual_pipeline_model_parallel_world_size = None
+ self.pipeline_model_parallel_split_rank = None
+
+ # These values enable us to change the mpu sizes on the fly.
+ self.mpu_tensor_model_parallel_world_size = None
+ self.mpu_pipeline_model_parallel_world_size = None
+ self.mpu_expert_model_parallel_world_size = None
+ self.mpu_tensor_model_parallel_rank = None
+ self.mpu_pipeline_model_parallel_rank = None
+ self.mpu_expert_model_parallel_rank = None
+
+ # A list of ranks that have a copy of the embedding.
+ self.embedding_global_ranks = None
+
+ # A list of ranks that have a copy of the position embedding.
+ self.position_embedding_global_ranks = None
+
+ # A list of global ranks for each pipeline group to ease calculation of the source
+ # rank when broadcasting from the first or last pipeline stage.
+ self.pipeline_global_ranks = None
+
+ # A list of global ranks for each data parallel group to ease calculation of the source
+ # rank when broadcasting weights from src to all other data parallel ranks
+ self.data_parallel_global_ranks = None
+
+ # A list of global ranks for each tensor model parallel group to ease calculation of
+ # the first local rank in the tensor model parallel group
+ self.tensor_model_parallel_global_ranks = None
+
+ # Context parallel group that the current rank belongs to
+ self.context_parallel_group = None
+ # A list of global ranks for each context parallel group to ease calculation of the
+ # destination rank when exchanging KV/dKV between context parallel_ranks
+ self.context_parallel_global_ranks = None
+
+ # Data parallel group information with context parallel combined.
+ self.data_parallel_group_with_cp = None
+ self.data_parallel_group_with_cp_gloo = None
+ self.data_parallel_global_ranks_with_cp = None
+
+ # combined parallel group of TP and CP
+ self.tensor_and_context_parallel_group = None
+
+ # combined parallel group of TP, DP, and CP used for fp8
+ self.tensor_and_data_parallel_group_with_cp = None
+
+ # inner data parallel group
+ self.inner_data_parallel_group = None
+ # Memory buffers to avoid dynamic memory allocation
+ self.global_memory_buffer = None
+
+ # MOE logging
+ self.moe_layer_wise_logging_tracker = {}
+
+ def __repr__(self):
+ repr_str = ""
+
+ print_keys = {"name": "model",
+ "pipeline_model_parallel_group": "PP_RANKS",
+ "tensor_model_parallel_group": "TP_RANKS",
+ "data_parallel_group": "DP_RANKS",
+ "context_parallel_group": "CP_RANKS",
+ "tensor_and_data_parallel_group": "TP_DP_RANKS",
+ "tensor_and_expert_parallel_group": "TP_EP_RANKS"}
+
+ for name, value in vars(self).items():
+ if name not in print_keys:
+ continue
+ else:
+ name = print_keys[name]
+
+ repr_str += f"{name}="
+ if isinstance(value, range):
+ repr_str += f"{list(value)},"
+ elif isinstance(value, ProcessGroup):
+ if value is not None:
+ repr_str += f"{torch.distributed.get_process_group_ranks(value)},"
+ else:
+ repr_str += f"{value},"
+ else:
+ repr_str += f"{value},"
+
+ return repr_str
+
+
+def reset_global_group_and_ranks():
+ # create an empty subworld, then use its members' default value to reset global group and ranks
+ empty_subworld = DetachedSubWorld("empty_subworld", 0, [0])
+ set_global_group_and_ranks_by_subworld(empty_subworld)
+
+
+def set_global_group_and_ranks_by_subworld(subworld: DetachedSubWorld):
+ global _TENSOR_MODEL_PARALLEL_GROUP
+ global _PIPELINE_MODEL_PARALLEL_GROUP
+ global _MODEL_PARALLEL_GROUP
+ global _MODEL_AND_EXPERT_PARALLEL_GROUP
+ global _EMBEDDING_GROUP
+ global _POSITION_EMBEDDING_GROUP
+ global _DATA_PARALLEL_GROUP
+ global _DATA_PARALLEL_GROUP_GLOO
+ global _TENSOR_AND_DATA_PARALLEL_GROUP
+ global _EXPERT_MODEL_PARALLEL_GROUP
+ global _TENSOR_AND_EXPERT_PARALLEL_GROUP
+ global _DATA_MODULO_EXPERT_PARALLEL_GROUP
+ global _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
+ global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP
+ global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO
+ global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+ global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+ global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+ global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+ global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+ global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
+ global _MPU_TENSOR_MODEL_PARALLEL_RANK
+ global _MPU_PIPELINE_MODEL_PARALLEL_RANK
+ global _MPU_EXPERT_MODEL_PARALLEL_RANK
+ global _EMBEDDING_GLOBAL_RANKS
+ global _POSITION_EMBEDDING_GLOBAL_RANKS
+ global _PIPELINE_GLOBAL_RANKS
+ global _DATA_PARALLEL_GLOBAL_RANKS
+ global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
+ global _CONTEXT_PARALLEL_GROUP
+ global _CONTEXT_PARALLEL_GLOBAL_RANKS
+ global _DATA_PARALLEL_GROUP_WITH_CP
+ global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
+ global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
+ global _TENSOR_AND_CONTEXT_PARALLEL_GROUP
+ global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
+ global _INNER_DATA_PARALLEL_GROUP
+ global _GLOBAL_MEMORY_BUFFER
+ global _MOE_LAYER_WISE_LOGGING_TRACKER
+
+ # Intra-layer model parallel group that the current rank belongs to.
+ _TENSOR_MODEL_PARALLEL_GROUP = subworld.tensor_model_parallel_group
+ # Inter-layer model parallel group that the current rank belongs to.
+ _PIPELINE_MODEL_PARALLEL_GROUP = subworld.pipeline_model_parallel_group
+ # Model parallel group (both intra- and pipeline) that the current rank belongs to.
+ _MODEL_PARALLEL_GROUP = subworld.model_parallel_group
+ # Model parallel group (both intra-, pipeline, and expert) that the current rank belongs to.
+ _MODEL_AND_EXPERT_PARALLEL_GROUP = subworld.model_and_expert_parallel_group
+ # Embedding group.
+ _EMBEDDING_GROUP = subworld.embedding_group
+ # Position embedding group.
+ _POSITION_EMBEDDING_GROUP = subworld.position_embedding_group
+ # Data parallel group that the current rank belongs to.
+ _DATA_PARALLEL_GROUP = subworld.data_parallel_group
+ _DATA_PARALLEL_GROUP_GLOO = subworld.data_parallel_group_gloo
+ _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = subworld.data_modulo_expert_parallel_group_with_cp
+ _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = subworld.data_modulo_expert_parallel_group_with_cp_gloo
+ # tensor model parallel group and data parallel group combined
+ # used for fp8 and moe training
+ _TENSOR_AND_DATA_PARALLEL_GROUP = subworld.tensor_and_data_parallel_group
+ # Expert parallel group that the current rank belongs to.
+ _EXPERT_MODEL_PARALLEL_GROUP = subworld.expert_model_parallel_group
+ _TENSOR_AND_EXPERT_PARALLEL_GROUP = subworld.tensor_and_expert_parallel_group
+ _DATA_MODULO_EXPERT_PARALLEL_GROUP = subworld.data_modulo_expert_parallel_group
+ _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = subworld.data_modulo_expert_parallel_group_gloo
+
+ _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = subworld.virtual_pipeline_model_parallel_rank
+ _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = subworld.virtual_pipeline_model_parallel_world_size
+ _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = subworld.pipeline_model_parallel_split_rank
+
+ # These values enable us to change the mpu sizes on the fly.
+ _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = subworld.mpu_tensor_model_parallel_world_size
+ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = subworld.mpu_pipeline_model_parallel_world_size
+ _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = subworld.mpu_expert_model_parallel_world_size
+ _MPU_TENSOR_MODEL_PARALLEL_RANK = subworld.mpu_tensor_model_parallel_rank
+ _MPU_PIPELINE_MODEL_PARALLEL_RANK = subworld.mpu_pipeline_model_parallel_rank
+ _MPU_EXPERT_MODEL_PARALLEL_RANK = subworld.mpu_expert_model_parallel_rank
+
+ # A list of ranks that have a copy of the embedding.
+ _EMBEDDING_GLOBAL_RANKS = subworld.embedding_global_ranks
+
+ # A list of ranks that have a copy of the position embedding.
+ _POSITION_EMBEDDING_GLOBAL_RANKS = subworld.position_embedding_global_ranks
+
+ # A list of global ranks for each pipeline group to ease calculation of the source
+ # rank when broadcasting from the first or last pipeline stage.
+ _PIPELINE_GLOBAL_RANKS = subworld.pipeline_global_ranks
+
+ # A list of global ranks for each data parallel group to ease calculation of the source
+ # rank when broadcasting weights from src to all other data parallel ranks
+ _DATA_PARALLEL_GLOBAL_RANKS = subworld.data_parallel_global_ranks
+
+ # A list of global ranks for each tensor model parallel group to ease calculation of
+ # the first local rank in the tensor model parallel group
+ _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = subworld.tensor_model_parallel_global_ranks
+
+ # Context parallel group that the current rank belongs to
+ _CONTEXT_PARALLEL_GROUP = subworld.context_parallel_group
+ # A list of global ranks for each context parallel group to ease calculation of the
+ # destination rank when exchanging KV/dKV between context parallel_ranks
+ _CONTEXT_PARALLEL_GLOBAL_RANKS = subworld.context_parallel_global_ranks
+
+ # Data parallel group information with context parallel combined.
+ _DATA_PARALLEL_GROUP_WITH_CP = subworld.data_parallel_group_with_cp
+ _DATA_PARALLEL_GROUP_WITH_CP_GLOO = subworld.data_parallel_group_with_cp_gloo
+ _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = subworld.data_parallel_global_ranks_with_cp
+
+ # combined parallel group of TP and CP
+ _TENSOR_AND_CONTEXT_PARALLEL_GROUP = subworld.tensor_and_context_parallel_group
+
+ # combined parallel group of TP, DP, and CP used for fp8
+ _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = subworld.tensor_and_data_parallel_group_with_cp
+
+ # inner data parallel group
+ _INNER_DATA_PARALLEL_GROUP = subworld.inner_data_parallel_group
+
+ # Memory buffers to avoid dynamic memory allocation
+ _GLOBAL_MEMORY_BUFFER = subworld.global_memory_buffer
+
+ # MOE logging
+ _MOE_LAYER_WISE_LOGGING_TRACKER = subworld.moe_layer_wise_logging_tracker
+
+
+def get_nccl_options(pg_name, nccl_comm_cfgs):
+ """Set the NCCL process group options.
+
+ Args:
+ pg_name (str): process group name
+ nccl_comm_cfgs (dict): nccl communicator configurations
+
+ When an option (e.g., max_ctas) is not found in the config, use the NCCL default setting.
+ """
+ if pg_name in nccl_comm_cfgs:
+ nccl_options = torch.distributed.ProcessGroupNCCL.Options()
+ nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get('cga_cluster_size', 4)
+ nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get('max_ctas', 32)
+ nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get('min_ctas', 1)
+ return nccl_options
+ else:
+ return None
+
+
+def is_last_rank():
+ global _CUR_SUB_WORLD
+ rank = torch.distributed.get_rank()
+ if _CUR_SUB_WORLD is None:
+ raise RuntimeError('_CUR_SUB_WORLD should not be None')
+ if rank == _CUR_SUB_WORLD.ranks[-1]:
+ return True
+ return False
+
+
+def _initialize_model_parallel(
+ tensor_model_parallel_size: int = 1,
+ pipeline_model_parallel_size: int = 1,
+ virtual_pipeline_model_parallel_size: Optional[int] = None,
+ pipeline_model_parallel_split_rank: Optional[int] = None,
+ use_sharp: bool = False,
+ context_parallel_size: int = 1,
+ expert_model_parallel_size: int = 1,
+ nccl_communicator_config_path: Optional[str] = None,
+ distributed_timeout_minutes: int = 30,
+ order: str = "tp-cp-ep-dp-pp",
+ subworld: DetachedSubWorld = None
+):
+ # Get world size and rank. Ensure some consistencies.
+ tp_ranks = []
+ pp_ranks = []
+ if subworld is None:
+ return pp_ranks, tp_ranks
+
+ if not torch.distributed.is_initialized():
+ raise RuntimeError('Distributed is not initialized.')
+ world_size: int = torch.distributed.get_world_size()
+ sub_world_size = len(subworld.ranks)
+ if sub_world_size > world_size:
+ raise RuntimeError(f"world_size ({world_size}) is less than sub_world_size ({sub_world_size})")
+ world_size = sub_world_size
+ reset_global_group_and_ranks()
+
+ def adjust_rank(ranks_: Sequence):
+ for i_, _ in enumerate(ranks_):
+ ranks_[i_] += subworld.start_rank
+ return ranks_
+
+ if (
+ world_size
+ % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size)
+ != 0
+ ):
+ raise RuntimeError(
+ f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
+ f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size}) "
+ f"x context_parallel_size ({context_parallel_size})"
+ )
+
+ data_parallel_size: int = world_size // (
+ tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
+ )
+
+ if data_parallel_size % expert_model_parallel_size != 0:
+ raise RuntimeError(
+ f"data_parallel_size ({data_parallel_size}) is not divisible by expert_model_parallel_size "
+ )
+
+ if virtual_pipeline_model_parallel_size is not None:
+ if not pipeline_model_parallel_size > 1:
+ raise RuntimeError(
+ "pipeline-model-parallel size should be greater than 1 with interleaved schedule"
+ )
+ subworld.virtual_pipeline_model_parallel_rank = 0
+ subworld.virtual_pipeline_model_parallel_world_size = virtual_pipeline_model_parallel_size
+
+ if pipeline_model_parallel_split_rank is not None:
+ subworld.pipeline_model_parallel_split_rank = pipeline_model_parallel_split_rank
+
+ rank = torch.distributed.get_rank()
+
+ nccl_comm_cfgs = {}
+ if nccl_communicator_config_path is not None:
+ try:
+ import yaml
+ except ImportError:
+ raise RuntimeError(
+ "Cannot import `yaml`. Setting custom nccl communicator configs "
+ "requires the yaml package."
+ )
+
+ with open(nccl_communicator_config_path, "r") as stream:
+ nccl_comm_cfgs = yaml.safe_load(stream)
+
+ rank_generator = RankGenerator(
+ tp=tensor_model_parallel_size,
+ ep=expert_model_parallel_size,
+ dp=data_parallel_size,
+ pp=pipeline_model_parallel_size,
+ cp=context_parallel_size,
+ order=order,
+ )
+ timeout = timedelta(minutes=distributed_timeout_minutes)
+
+ # Build the data-parallel groups.
+ assert subworld.data_parallel_group is None, 'data parallel group is already initialized'
+
+ for ranks in rank_generator.get_ranks('dp'):
+ ranks = adjust_rank(ranks)
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
+ )
+ group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo")
+ if rank in ranks:
+ subworld.data_parallel_group = group
+ subworld.data_parallel_group_gloo = group_gloo
+ subworld.data_parallel_global_ranks = ranks
+ for ranks_with_cp in rank_generator.get_ranks('dp-cp'):
+ ranks_with_cp = adjust_rank(ranks_with_cp)
+ group_with_cp = torch.distributed.new_group(
+ ranks_with_cp, timeout=timeout, pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs)
+ )
+ group_with_cp_gloo = torch.distributed.new_group(
+ ranks_with_cp, timeout=timeout, backend="gloo"
+ )
+ if rank in ranks_with_cp:
+ subworld.data_parallel_group_with_cp = group_with_cp
+ subworld.data_parallel_group_with_cp_gloo = group_with_cp_gloo
+ subworld.data_parallel_global_ranks_with_cp = ranks_with_cp
+
+ # Apply SHARP to DP process groups
+ if use_sharp:
+ if rank == 0:
+ print(
+ "The number of process groups to use SHARP with depends on the type "
+ "of the network switch. Nvidia QM1 switch supports SAHRP up to 8 "
+ "process groups and QM2 supports up to 256 process groups. We apply "
+ "SHARP to the communications of the data-parallel domain. If the "
+ "number of data-parallel process groups is larger than the max "
+ "process groups that the network switch supports, the communication "
+ "will fall back to non-SHARP operators. To enable SHARP, "
+ "`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
+ )
+ torch.distributed.barrier(
+ group=get_data_parallel_group(with_context_parallel=True),
+ device_ids=[torch.cuda.current_device()],
+ )
+ # Set `NCCL_COLLNET_ENABLE=0` to restrict SHARP application to DP process groups
+ os.environ["NCCL_COLLNET_ENABLE"] = "0"
+
+ # Build the context-parallel groups.
+ assert subworld.context_parallel_group is None, 'context parallel group is already initialized'
+ for ranks in rank_generator.get_ranks('cp'):
+ ranks = adjust_rank(ranks)
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('cp', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ subworld.context_parallel_group = group
+ subworld.context_parallel_global_ranks = ranks
+
+ # Build the model-parallel groups.
+ assert subworld.model_parallel_group is None, 'model parallel group is already initialized'
+ for ranks in rank_generator.get_ranks('tp-pp'):
+ ranks = adjust_rank(ranks)
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('mp', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ subworld.model_parallel_group = group
+
+ # Build the model-parallel groups with expert parallel
+ assert subworld.model_and_expert_parallel_group is None, 'model and expert parallel group is already initialized'
+ for ranks in rank_generator.get_ranks('tp-ep-pp', independent_ep=True):
+ ranks = adjust_rank(ranks)
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('mp_exp', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ subworld.model_and_expert_parallel_group = group
+
+ # Build the tensor model-parallel groups.
+ assert subworld.tensor_model_parallel_group is None, 'tensor model parallel group is already initialized'
+ for ranks in rank_generator.get_ranks('tp'):
+ ranks = adjust_rank(ranks)
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ subworld.tensor_model_parallel_group = group
+ subworld.tensor_model_parallel_global_ranks = ranks
+
+ # Build the pipeline model-parallel groups and embedding groups
+ # (first and last rank in each pipeline model-parallel group).
+ assert subworld.pipeline_model_parallel_group is None, 'pipeline model parallel group is already initialized'
+ assert subworld.embedding_group is None, 'embedding group is already initialized'
+ assert subworld.position_embedding_group is None, 'position embedding group is already initialized'
+ for ranks in rank_generator.get_ranks('pp'):
+ ranks = adjust_rank(ranks)
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('pp', nccl_comm_cfgs)
+ )
+ pp_ranks.append(list(ranks))
+ if rank in ranks:
+ subworld.pipeline_model_parallel_group = group
+ subworld.pipeline_global_ranks = ranks
+ # Setup embedding group (to exchange gradients between
+ # first and last stages).
+ if len(ranks) > 1:
+ embedding_ranks = [ranks[0], ranks[-1]]
+ position_embedding_ranks = [ranks[0]]
+ if pipeline_model_parallel_split_rank is not None:
+ if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
+ embedding_ranks = [
+ ranks[0],
+ ranks[pipeline_model_parallel_split_rank],
+ ranks[-1],
+ ]
+ if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
+ position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]
+ else:
+ embedding_ranks = ranks
+ position_embedding_ranks = ranks
+
+ group = torch.distributed.new_group(
+ embedding_ranks, timeout=timeout, pg_options=get_nccl_options('embd', nccl_comm_cfgs)
+ )
+ if rank in embedding_ranks:
+ subworld.embedding_group = group
+ if rank in ranks:
+ subworld.embedding_global_ranks = embedding_ranks
+
+ group = torch.distributed.new_group(
+ position_embedding_ranks,
+ timeout=timeout,
+ pg_options=get_nccl_options('embd', nccl_comm_cfgs),
+ )
+ if rank in position_embedding_ranks:
+ subworld.position_embedding_group = group
+ if rank in ranks:
+ subworld.position_embedding_global_ranks = position_embedding_ranks
+
+ # Build the tensor + data parallel groups.
+ assert subworld.tensor_and_data_parallel_group is None, 'Tensor + data parallel group is already initialized'
+ for ranks in rank_generator.get_ranks('tp-dp-cp'):
+ ranks = adjust_rank(ranks)
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('tp_dp_cp', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ subworld.tensor_and_data_parallel_group_with_cp = group
+ for ranks in rank_generator.get_ranks('tp-dp'):
+ ranks = adjust_rank(ranks)
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('tp_dp', nccl_comm_cfgs)
+ )
+ tp_ranks.append(list(ranks))
+ if rank in ranks:
+ subworld.tensor_and_data_parallel_group = group
+
+ assert subworld.tensor_and_context_parallel_group is None, 'Tensor + context parallel group is already initialized'
+ for ranks in rank_generator.get_ranks('tp-cp'):
+ ranks = adjust_rank(ranks)
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('tp_cp', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ subworld.tensor_and_context_parallel_group = group
+
+ # Build the tensor + expert parallel groups
+ assert subworld.expert_model_parallel_group is None, 'Expert parallel group is already initialized'
+ assert subworld.tensor_and_expert_parallel_group is None, 'Tensor + expert parallel group is already initialized'
+ assert subworld.data_modulo_expert_parallel_group is None, 'Data modulo expert group is already initialized'
+ assert (
+ subworld.data_modulo_expert_parallel_group_with_cp is None
+ ), 'Data modulo expert group with context parallel is already initialized'
+
+ for ranks in rank_generator.get_ranks('tp-ep', independent_ep=True):
+ ranks = adjust_rank(ranks)
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('tp_exp', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ subworld.tensor_and_expert_parallel_group = group
+
+ for ranks in rank_generator.get_ranks('ep', independent_ep=True):
+ ranks = adjust_rank(ranks)
+ group = torch.distributed.new_group(
+ ranks, pg_options=get_nccl_options('exp', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ subworld.expert_model_parallel_group = group
+
+ for ranks in rank_generator.get_ranks('dp', independent_ep=True):
+ ranks = adjust_rank(ranks)
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('dp_modulo_exp', nccl_comm_cfgs)
+ )
+ group_gloo = torch.distributed.new_group(ranks, backend="gloo")
+ if rank in ranks:
+ subworld.data_modulo_expert_parallel_group = group
+ subworld.data_modulo_expert_parallel_group_gloo = group_gloo
+
+ for ranks in rank_generator.get_ranks('dp-cp', independent_ep=True):
+ # Lazy initialization of the group
+ ranks = adjust_rank(ranks)
+ cp_world_size = torch.distributed.get_world_size(subworld.context_parallel_group)
+ if cp_world_size > 1:
+ group = torch.distributed.new_group(
+ ranks,
+ timeout=timeout,
+ pg_options=get_nccl_options('dp_modulo_exp_cp', nccl_comm_cfgs),
+ )
+ group_gloo = torch.distributed.new_group(ranks, backend="gloo")
+ else:
+ group = subworld.data_modulo_expert_parallel_group
+ group_gloo = subworld.data_modulo_expert_parallel_group_gloo
+ if rank in ranks:
+ subworld.data_modulo_expert_parallel_group_with_cp = group
+ subworld.data_modulo_expert_parallel_group_with_cp_gloo = group_gloo
+
+ if any(cfg.main_dp for cfg in get_all_config().values()):
+ from .inner_data_parallel.utils import get_global_data_parallel_size
+ if subworld.inner_data_parallel_group is not None:
+ raise RuntimeError('inner dp model parallel group is already initialized')
+ if get_global_data_parallel_size() > data_parallel_size:
+ raise RuntimeError(f'global dp size ({get_global_data_parallel_size()}) should smaller than or equals to '
+ f'subworld dp size ({data_parallel_size})')
+ inner_dp_size = data_parallel_size // get_global_data_parallel_size()
+ for i in range(world_size // inner_dp_size):
+ start_rank = i * inner_dp_size
+ end_rank = (i + 1) * inner_dp_size
+ ranks = adjust_rank(list(range(start_rank, end_rank)))
+ group = torch.distributed.new_group(
+ ranks, timeout=timeout, pg_options=get_nccl_options('inner_dp', nccl_comm_cfgs)
+ )
+ if rank in ranks:
+ subworld.inner_data_parallel_group = group
+ # Initialize global memory buffer
+ # This isn't really "parallel state" but there isn't another good place to
+ # put this. If we end up with a more generic initialization of megatron-core
+ # we could stick it there
+ _set_global_memory_buffer(subworld=subworld)
+
+ # append to all sub world list
+ global ALL_SUB_WORLD
+ if rank in subworld.ranks:
+ reset_global_group_and_ranks()
+ set_global_group_and_ranks_by_subworld(subworld=subworld)
+ ALL_SUB_WORLD[subworld.name] = subworld
+ print(f"rank={rank},{subworld}")
+ return pp_ranks, tp_ranks
+
+
+def initialize_model_parallel(*args, **kwargs) -> None:
+ global _CUR_SUB_WORLD, ALL_SUB_WORLD
+ _CUR_SUB_WORLD = None
+ ALL_SUB_WORLD = {}
+ world_size: int = torch.distributed.get_world_size()
+ all_cfg = []
+ all_pp_and_tp_ranks = {}
+
+ # 初始化并行组
+ dist_all_world_size = 0
+ for i in range(get_all_config_size()):
+ cfg = get_dist_model_config(global_index=i)
+ dist_all_world_size += cfg.world_size
+ subworld = DetachedSubWorld(cfg.name, cfg.start_rank,
+ list(range(cfg.start_rank, cfg.start_rank + cfg.world_size)))
+ pp_ranks, tp_ranks = _initialize_model_parallel(cfg.tensor_model_parallel_size, cfg.pipeline_model_parallel_size,
+ context_parallel_size=cfg.context_parallel_size,
+ subworld=subworld)
+ all_cfg.append(cfg)
+ all_pp_and_tp_ranks[cfg.model_index] = all_pp_and_tp_ranks.get(cfg.model_index, []) + [[pp_ranks, tp_ranks]]
+ if world_size != dist_all_world_size:
+ raise RuntimeError(f"{world_size=} should equals to {dist_all_world_size=}")
+
+ # 生成映射关系
+ from .communication.dist_ranks_match import generate_model_comm_ranks, get_dst_ranks
+ for i in range(len(all_pp_and_tp_ranks) - 1):
+ for ranks_prev in all_pp_and_tp_ranks.get(i, []):
+ for ranks_post in all_pp_and_tp_ranks.get(i + 1, []):
+ comm_args = ranks_prev + ranks_post
+ generate_model_comm_ranks(*comm_args)
+ dst_ranks = get_dst_ranks()
+ if dst_ranks is not None:
+ print(f"rank={torch.distributed.get_rank()} "
+ f"--> {dst_ranks}, prev: {list(comm_args[1])}, last: {list(comm_args[3])}")
+
+
+def _set_global_memory_buffer(subworld: DetachedSubWorld):
+ # Initialize subworld buffer
+ if subworld.global_memory_buffer is not None:
+ raise RuntimeError('subworld memory buffer is already initialized')
+ subworld.global_memory_buffer = GlobalMemoryBuffer()
+
+
+def _get_subworld_by_name(name=""):
+ if ALL_SUB_WORLD is None:
+ raise RuntimeError('all subworld is not initialized')
+ return ALL_SUB_WORLD.get(name, None)
+
+
+def set_subworld_by_name(name=""):
+ global _CUR_SUB_WORLD
+ if is_in_subworld(name):
+ _CUR_SUB_WORLD = _get_subworld_by_name(name)
+
+
+def is_in_subworld(name=""):
+ subworld = _get_subworld_by_name(name)
+ if subworld is None:
+ return False
+ rank = torch.distributed.get_rank()
+ return rank in subworld.ranks
+
+
+def is_not_use_dist_train_or_in_subworld(name=""):
+ args = get_args()
+ if getattr(args, "dist_train", False):
+ return is_in_subworld(name)
+ return True
+
+
+def is_use_dist_train_and_in_subworld(name=""):
+ args = get_args()
+ if getattr(args, "dist_train", False):
+ return is_in_subworld(name)
+ return False
+
+
+def get_is_pipeline_first_stage_wrapper(is_pipeline_first_stage):
+ @wraps(is_pipeline_first_stage)
+ def wrapper(*args, **kwargs):
+ return _is_pipeline_first_stage(*args, **kwargs)
+ return wrapper
+
+
+def _is_pipeline_first_stage(ignore_virtual=False, is_global=True):
+ """Return True if in the first pipeline model-parallel stage, False otherwise."""
+ if is_global:
+ from .config.dist_train_config import get_dist_model_name
+ if _get_subworld_by_name(get_dist_model_name()) is None:
+ return False
+
+ if not ignore_virtual:
+ if (
+ get_virtual_pipeline_model_parallel_world_size() is not None
+ and get_virtual_pipeline_model_parallel_rank() != 0
+ ):
+ return False
+ return get_pipeline_model_parallel_rank() == 0
+
+
+def get_is_pipeline_last_stage_wrapper(is_pipeline_last_stage):
+ @wraps(is_pipeline_last_stage)
+ def wrapper(*args, **kwargs):
+ return _is_pipeline_last_stage(*args, **kwargs)
+ return wrapper
+
+
+def _is_pipeline_last_stage(ignore_virtual=False, is_global=True):
+ """Return True if in the last pipeline model-parallel stage, False otherwise."""
+ if is_global:
+ from .config import dist_train_config
+ name = dist_train_config._RANK_NUMBER_TO_MODEL_NAME[-1]
+ if _get_subworld_by_name(name) is None:
+ return False
+
+ if not ignore_virtual:
+ virtual_pipeline_model_parallel_world_size = (
+ get_virtual_pipeline_model_parallel_world_size()
+ )
+ if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != (
+ virtual_pipeline_model_parallel_world_size - 1
+ ):
+ return False
+ return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
+
+
+def subwrold_decorator(wrap_func):
+ @wraps(wrap_func)
+ def wrap_the_function(*args, **kwargs):
+ global _CUR_SUB_WORLD
+ reset_global_group_and_ranks()
+ if _CUR_SUB_WORLD is None:
+ from .config.dist_train_config import get_dist_model_name
+ name = get_dist_model_name()
+ set_subworld_by_name(name)
+ if _CUR_SUB_WORLD is not None:
+ set_global_group_and_ranks_by_subworld(subworld=_CUR_SUB_WORLD)
+ ret = wrap_func(*args, **kwargs)
+ return ret
+ return wrap_the_function
+
+
+def get_tensor_model_parallel_src_rank_wrapper(get_tensor_model_parallel_src_rank):
+ @wraps(get_tensor_model_parallel_src_rank)
+ def wrapper():
+ return _get_tensor_model_parallel_src_rank()
+ return wrapper
+
+
+@subwrold_decorator
+def _get_tensor_model_parallel_src_rank():
+ """Calculate the global rank corresponding to the first local rank in the tensor model parallel group."""
+ if _CUR_SUB_WORLD is None:
+ return 0
+ global_rank = (torch.distributed.get_rank() - _CUR_SUB_WORLD.start_rank)
+ local_world_size = get_tensor_model_parallel_world_size()
+ return (global_rank // local_world_size) * local_world_size + _CUR_SUB_WORLD.start_rank
+
+
+@subwrold_decorator
+def is_initialized():
+ """Useful for code segments that may be accessed with or without mpu initialization"""
+ return _DATA_PARALLEL_GROUP is not None
+
+
+@subwrold_decorator
+def model_parallel_is_initialized():
+ """Check if model and data parallel groups are initialized."""
+ if (
+ _TENSOR_MODEL_PARALLEL_GROUP is None
+ or _PIPELINE_MODEL_PARALLEL_GROUP is None
+ or _DATA_PARALLEL_GROUP is None
+ ):
+ return False
+ return True
+
+
+@subwrold_decorator
+def get_model_parallel_group(with_expert_parallel=False):
+ """Get the model parallel group the caller rank belongs to."""
+ if with_expert_parallel:
+ assert (
+ _MODEL_AND_EXPERT_PARALLEL_GROUP is not None
+ ), 'model parallel group is not initialized'
+ return _MODEL_AND_EXPERT_PARALLEL_GROUP
+ assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized'
+ return _MODEL_PARALLEL_GROUP
+
+
+@subwrold_decorator
+def get_tensor_model_parallel_group(check_initialized=True):
+ """Get the tensor model parallel group the caller rank belongs to."""
+ if check_initialized:
+ assert (
+ _TENSOR_MODEL_PARALLEL_GROUP is not None
+ ), 'tensor model parallel group is not initialized'
+ return _TENSOR_MODEL_PARALLEL_GROUP
+
+
+@subwrold_decorator
+def get_pipeline_model_parallel_group():
+ """Get the pipeline model parallel group the caller rank belongs to."""
+ assert (
+ _PIPELINE_MODEL_PARALLEL_GROUP is not None
+ ), 'pipeline_model parallel group is not initialized'
+ return _PIPELINE_MODEL_PARALLEL_GROUP
+
+
+@subwrold_decorator
+def get_data_parallel_group(with_context_parallel=False):
+ """Get the data parallel group the caller rank belongs to."""
+ if with_context_parallel:
+ assert (
+ _DATA_PARALLEL_GROUP_WITH_CP is not None
+ ), 'data parallel group with context parallel combined is not initialized'
+ return _DATA_PARALLEL_GROUP_WITH_CP
+ else:
+ assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
+ return _DATA_PARALLEL_GROUP
+
+
+@subwrold_decorator
+def get_data_parallel_group_gloo(with_context_parallel=False):
+ """Get the data parallel group-gloo the caller rank belongs to."""
+ if with_context_parallel:
+ assert (
+ _DATA_PARALLEL_GROUP_WITH_CP_GLOO is not None
+ ), 'data parallel group-gloo with context parallel combined is not initialized'
+ return _DATA_PARALLEL_GROUP_WITH_CP_GLOO
+ else:
+ assert _DATA_PARALLEL_GROUP_GLOO is not None, 'data parallel group-gloo is not initialized'
+ return _DATA_PARALLEL_GROUP_GLOO
+
+
+@subwrold_decorator
+def get_context_parallel_group(check_initialized=True):
+ """Get the context parallel group the caller rank belongs to."""
+ if check_initialized:
+ assert _CONTEXT_PARALLEL_GROUP is not None, 'context parallel group is not initialized'
+ return _CONTEXT_PARALLEL_GROUP
+
+
+@subwrold_decorator
+def get_context_parallel_global_ranks(check_initialized=True):
+ """Get all global ranks of the context parallel group that the caller rank belongs to."""
+ if check_initialized:
+ assert _CONTEXT_PARALLEL_GLOBAL_RANKS is not None, 'context parallel group is not initialized'
+ return _CONTEXT_PARALLEL_GLOBAL_RANKS
+
+
+@subwrold_decorator
+def get_embedding_group():
+ """Get the embedding group the caller rank belongs to."""
+ assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized'
+ return _EMBEDDING_GROUP
+
+
+@subwrold_decorator
+def get_position_embedding_group():
+ """Get the position embedding group the caller rank belongs to."""
+ assert _POSITION_EMBEDDING_GROUP is not None, 'position embedding group is not initialized'
+ return _POSITION_EMBEDDING_GROUP
+
+
+@subwrold_decorator
+def get_position_embedding_group():
+ """Get the position embedding group the caller rank belongs to."""
+ if _POSITION_EMBEDDING_GROUP is None:
+ raise RuntimeError('position embedding group is not initialized')
+ return _POSITION_EMBEDDING_GROUP
+
+
+@subwrold_decorator
+def get_amax_reduction_group(with_context_parallel=False):
+ """Get the FP8 amax reduction group the caller rank belongs to."""
+ if with_context_parallel:
+ assert (
+ _TENSOR_AND_CONTEXT_PARALLEL_GROUP is not None
+ ), 'FP8 amax reduction group is not initialized'
+ return _TENSOR_AND_CONTEXT_PARALLEL_GROUP
+ else:
+ assert (
+ _TENSOR_MODEL_PARALLEL_GROUP is not None
+ ), 'FP8 amax reduction group is not initialized'
+ return _TENSOR_MODEL_PARALLEL_GROUP
+
+
+@subwrold_decorator
+def get_tensor_and_data_parallel_group(with_context_parallel=False):
+ """Get the tensor and data parallel group the caller rank belongs to."""
+ if with_context_parallel:
+ assert (
+ _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None
+ ), 'tensor and data parallel group is not initialized'
+ return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
+ else:
+ assert (
+ _TENSOR_AND_DATA_PARALLEL_GROUP is not None
+ ), 'tensor and data parallel group is not initialized'
+ return _TENSOR_AND_DATA_PARALLEL_GROUP
+
+
+@subwrold_decorator
+def get_tensor_and_context_parallel_group():
+ """Get the tensor and context parallel group the caller rank belongs to."""
+ assert (
+ _TENSOR_AND_CONTEXT_PARALLEL_GROUP is not None
+ ), 'tensor and context parallel group is not initialized'
+ return _TENSOR_AND_CONTEXT_PARALLEL_GROUP
+
+
+@subwrold_decorator
+def get_expert_model_parallel_group():
+ assert (
+ _EXPERT_MODEL_PARALLEL_GROUP is not None
+ ), 'expert model parallel group is not initialized'
+ return _EXPERT_MODEL_PARALLEL_GROUP
+
+
+@subwrold_decorator
+def get_tensor_and_expert_parallel_group():
+ assert (
+ _TENSOR_AND_EXPERT_PARALLEL_GROUP is not None
+ ), 'tensor and expert parallel group is not initialized'
+ return _TENSOR_AND_EXPERT_PARALLEL_GROUP
+
+
+@subwrold_decorator
+def get_data_modulo_expert_parallel_group(with_context_parallel=False):
+ if with_context_parallel:
+ assert (
+ _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP is not None
+ ), 'data modulo expert parallel group with context parallel is not initialized'
+ return _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP
+ else:
+ assert (
+ _DATA_MODULO_EXPERT_PARALLEL_GROUP is not None
+ ), 'data modulo expert parallel group is not initialized'
+ return _DATA_MODULO_EXPERT_PARALLEL_GROUP
+
+
+@subwrold_decorator
+def get_data_modulo_expert_parallel_group_gloo(with_context_parallel=False):
+ if with_context_parallel:
+ assert (
+ _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO is not None
+ ), 'data modulo expert parallel group-gloo with context parallel is not initialized'
+ return _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO
+ else:
+ assert (
+ _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO is not None
+ ), 'data modulo expert parallel group-gloo is not initialized'
+ return _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
+
+
+@subwrold_decorator
+def get_tensor_model_parallel_world_size():
+ """Return world size for the tensor model parallel group."""
+ global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+ if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
+ return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+ return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
+
+
+@subwrold_decorator
+def get_pipeline_model_parallel_world_size():
+ """Return world size for the pipeline model parallel group."""
+ global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+ if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
+ return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+ return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
+
+
+@subwrold_decorator
+def get_tensor_model_parallel_rank():
+ """Return my rank for the tensor model parallel group."""
+ global _MPU_TENSOR_MODEL_PARALLEL_RANK
+ if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
+ return _MPU_TENSOR_MODEL_PARALLEL_RANK
+ return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
+
+
+@subwrold_decorator
+def get_pipeline_model_parallel_rank(is_global=False):
+ """Return my rank for the pipeline model parallel group."""
+ global _MPU_PIPELINE_MODEL_PARALLEL_RANK
+ if is_global:
+ return get_global_pipeline_parallel_rank()
+ else:
+ if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
+ return _MPU_PIPELINE_MODEL_PARALLEL_RANK
+ return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
+
+
+@subwrold_decorator
+def get_pipeline_model_parallel_split_rank():
+ """Return pipeline model parallel split rank."""
+ global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+ return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+
+
+def is_rank_in_embedding_group(ignore_virtual=False):
+ """Return true if current rank is in embedding group, False otherwise."""
+ rank = torch.distributed.get_rank()
+ if ignore_virtual:
+ return rank in _EMBEDDING_GLOBAL_RANKS
+ if rank in _EMBEDDING_GLOBAL_RANKS:
+ if get_args().multimodal:
+ if rank == _EMBEDDING_GLOBAL_RANKS[-1]:
+ return _is_pipeline_last_stage()
+ else:
+ return True
+ else:
+ if rank == _EMBEDDING_GLOBAL_RANKS[0]:
+ return _is_pipeline_first_stage()
+ elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
+ return _is_pipeline_last_stage()
+ else:
+ return True
+ return False
+
+
+@subwrold_decorator
+def is_rank_in_position_embedding_group():
+ """Return true if current rank is in position embedding group, False otherwise."""
+ rank = torch.distributed.get_rank()
+ global _POSITION_EMBEDDING_GLOBAL_RANKS
+ return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
+
+
+@subwrold_decorator
+def get_virtual_pipeline_model_parallel_rank():
+ """Return the virtual pipeline-parallel rank."""
+ global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+ return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+
+
+@subwrold_decorator
+def get_virtual_pipeline_model_parallel_world_size():
+ """Return the virtual pipeline-parallel world size."""
+ global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+ return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+
+
+@subwrold_decorator
+def get_data_parallel_src_rank(with_context_parallel=False):
+ """Calculate the global rank corresponding to the first local rank in the data parallel group."""
+ if with_context_parallel:
+ assert (
+ _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
+ ), "Data parallel group with context parallel combined is not initialized"
+ return _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP[0]
+ else:
+ assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized"
+ return _DATA_PARALLEL_GLOBAL_RANKS[0]
+
+
+@subwrold_decorator
+def get_pipeline_model_parallel_first_rank():
+ """Return the global rank of the first process in the pipeline for the current tensor parallel group"""
+ assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+ return _PIPELINE_GLOBAL_RANKS[0]
+
+
+@subwrold_decorator
+def get_pipeline_model_parallel_last_rank():
+ """Return the global rank of the last process in the pipeline for the current tensor parallel group"""
+ assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+ last_rank_local = get_pipeline_model_parallel_world_size() - 1
+ return _PIPELINE_GLOBAL_RANKS[last_rank_local]
+
+
+@subwrold_decorator
+def get_pipeline_model_parallel_next_rank():
+ """Return the global rank that follows the caller in the pipeline"""
+ assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+ rank_in_pipeline = get_pipeline_model_parallel_rank()
+ world_size = get_pipeline_model_parallel_world_size()
+ return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
+
+
+@subwrold_decorator
+def get_pipeline_model_parallel_prev_rank():
+ """Return the global rank that preceeds the caller in the pipeline"""
+ assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
+ rank_in_pipeline = get_pipeline_model_parallel_rank()
+ world_size = get_pipeline_model_parallel_world_size()
+ return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
+
+
+@subwrold_decorator
+def get_expert_model_parallel_world_size():
+ """Return world size for the expert model parallel group"""
+ if _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE:
+ return _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ tensor_and_expert_parallel_world_size = torch.distributed.get_world_size(
+ group=get_tensor_and_expert_parallel_group()
+ )
+ return tensor_and_expert_parallel_world_size // get_tensor_model_parallel_world_size()
+ else:
+ return 0
+
+
+@subwrold_decorator
+def get_expert_model_parallel_rank():
+ """Return my rank for the expert parallel group"""
+ if _MPU_EXPERT_MODEL_PARALLEL_RANK:
+ return _MPU_EXPERT_MODEL_PARALLEL_RANK
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ tensor_and_expert_parallel_rank = torch.distributed.get_rank(
+ group=get_tensor_and_expert_parallel_group()
+ )
+ return tensor_and_expert_parallel_rank // get_tensor_model_parallel_world_size()
+ else:
+ return 0
+
+
+@subwrold_decorator
+def get_global_memory_buffer():
+ """Return the global GlobalMemoryBuffer object"""
+ if _GLOBAL_MEMORY_BUFFER is None:
+ raise RuntimeError('global memory buffer is not initialized')
+ return _GLOBAL_MEMORY_BUFFER
+
+
+@subwrold_decorator
+def get_moe_layer_wise_logging_tracker():
+ """Return the moe layer wise tracker."""
+ global _MOE_LAYER_WISE_LOGGING_TRACKER
+ return _MOE_LAYER_WISE_LOGGING_TRACKER
+
+
+def get_global_pipeline_parallel_rank():
+ """Return main data parallel size"""
+ all_config = get_all_config()
+ global_pp_rank = 0
+ for config in all_config.values():
+ if config.model_index < get_dist_model_config().model_index:
+ global_pp_rank += config.pipeline_model_parallel_size
+ else:
+ global_pp_rank += get_pipeline_model_parallel_rank()
+ break
+ return global_pp_rank
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/pipeline_parallel/__init__.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/pipeline_parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..68da32c88bbb0d4dbfe4762b888fd8e43d0045cb
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/pipeline_parallel/__init__.py
@@ -0,0 +1 @@
+from . import dist_schedules
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/pipeline_parallel/dist_schedules.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/pipeline_parallel/dist_schedules.py
new file mode 100644
index 0000000000000000000000000000000000000000..82cf9fddd0e5091e62509c60213ea164b2aa653c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/pipeline_parallel/dist_schedules.py
@@ -0,0 +1,524 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+import contextlib
+from typing import Iterator, List, Union, Optional
+from functools import wraps
+import torch
+from megatron.training import get_args
+from megatron.core.utils import get_model_config, get_model_type
+from megatron.core.enums import ModelType
+import megatron.core.pipeline_parallel.schedules as schedules
+from megatron.core.parallel_state import (
+ get_tensor_model_parallel_world_size,
+ get_pipeline_model_parallel_rank,
+ get_context_parallel_world_size,
+ is_pipeline_stage_before_split,
+ is_pipeline_stage_after_split,
+ get_pipeline_model_parallel_world_size,
+ get_pipeline_model_parallel_next_rank,
+ get_pipeline_model_parallel_prev_rank
+)
+from ..communication.dist_ranks_match import get_dst_ranks
+from ..communication.dist_communication import generate_send_recv_mask, send_recv_tensor_list, send_recv
+from ..config.dist_train_config import (
+ get_dist_model_config,
+ get_all_config_size,
+ is_forward_only_model,
+ is_use_multiparam_send_recv
+)
+
+
+def get_forward_backward_func_wrapper(get_forward_backward_func):
+ @wraps(get_forward_backward_func)
+ def wrapper(*args, **kwargs):
+ if get_args().dist_train:
+ return forward_backward_pipelining_without_interleaving
+ return get_forward_backward_func(*args, **kwargs)
+
+ return wrapper
+
+
+def p2p_ops_wrapper(p2p_ops):
+ @wraps(p2p_ops)
+ def wrapper(*args, **kwargs):
+ arguments = get_args()
+ if arguments.dist_train:
+ return _p2p_ops(*args, **kwargs)
+ return p2p_ops(*args, **kwargs)
+ return wrapper
+
+
+def _p2p_ops(
+ *,
+ tensor_send_prev: Optional[torch.Tensor],
+ tensor_recv_prev: Optional[torch.Tensor],
+ tensor_send_next: Optional[torch.Tensor],
+ tensor_recv_next: Optional[torch.Tensor],
+ group: torch.distributed.ProcessGroup
+):
+ reqs = []
+ # To prevent deadlocks caused by different pipeline stages receiving tensor simultaneously.
+ if get_pipeline_model_parallel_rank(is_global=True) % 2 == 0:
+ if tensor_send_next is not None:
+ send_next_req = torch.distributed.isend(
+ tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
+ )
+ reqs.append(send_next_req)
+
+ if tensor_recv_prev is not None:
+ recv_prev_req = torch.distributed.irecv(
+ tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ reqs.append(recv_prev_req)
+
+ if tensor_send_prev is not None:
+ send_prev_req = torch.distributed.isend(
+ tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ reqs.append(send_prev_req)
+
+ if tensor_recv_next is not None:
+ recv_next_req = torch.distributed.irecv(
+ tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
+ )
+ reqs.append(recv_next_req)
+
+ else:
+ if tensor_recv_prev is not None:
+ recv_prev_req = torch.distributed.irecv(
+ tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ reqs.append(recv_prev_req)
+
+ if tensor_send_next is not None:
+ send_next_req = torch.distributed.isend(
+ tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
+ )
+ reqs.append(send_next_req)
+
+ if tensor_recv_next is not None:
+ recv_next_req = torch.distributed.irecv(
+ tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
+ )
+ reqs.append(recv_next_req)
+
+ if tensor_send_prev is not None:
+ send_prev_req = torch.distributed.isend(
+ tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
+ )
+ reqs.append(send_prev_req)
+ return reqs
+
+
+def get_tensor_shapes(
+ *,
+ rank: int,
+ model_type: ModelType,
+ seq_length: int,
+ micro_batch_size: int,
+ decoder_seq_length: int,
+ config,
+):
+ # Determine right tensor sizes (based on position of rank with respect to split
+ # rank) and model size.
+ # Send two tensors if model is T5 and rank is in decoder stage:
+ # first tensor is decoder (pre-transpose),
+ # second tensor is encoder (post-transpose).
+ # If model is T5 and rank is at the boundary:
+ # send one tensor (post-transpose from encoder).
+ # Otherwise, send one tensor (pre-transpose).
+ tensor_shapes = []
+
+ seq_length = seq_length // get_context_parallel_world_size()
+ if model_type == ModelType.encoder_and_decoder:
+ decoder_seq_length = decoder_seq_length // get_context_parallel_world_size()
+
+ if config.sequence_parallel:
+ seq_length = seq_length // get_tensor_model_parallel_world_size()
+ if model_type == ModelType.encoder_and_decoder:
+ decoder_seq_length = (
+ decoder_seq_length // get_tensor_model_parallel_world_size()
+ )
+
+ if model_type == ModelType.encoder_and_decoder:
+ if is_pipeline_stage_before_split(rank):
+ if is_use_multiparam_send_recv():
+ tensor_shapes = [
+ {'shape': (seq_length, micro_batch_size, config.hidden_size), 'dtype': config.params_dtype},
+ ]
+ else:
+ tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
+ else:
+ if is_use_multiparam_send_recv():
+ tensor_shapes = [
+ {'shape': ((decoder_seq_length, micro_batch_size, config.hidden_size)), 'dtype': config.params_dtype},
+ {'shape': ((seq_length, micro_batch_size, config.hidden_size)), 'dtype': config.params_dtype}
+ ]
+ else:
+ tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size))
+ tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
+ else:
+ if is_use_multiparam_send_recv():
+ tensor_shapes = [
+ {'shape': ((seq_length, micro_batch_size, config.hidden_size)), 'dtype': config.params_dtype},
+ ]
+ else:
+ tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
+
+ return tensor_shapes
+
+
+def forward_backward_pipelining_without_interleaving(
+ *,
+ forward_step_func,
+ data_iterator: Union[Iterator, List[Iterator]],
+ model: Union[torch.nn.Module, List[torch.nn.Module]],
+ num_microbatches: int,
+ seq_length: int,
+ micro_batch_size: int,
+ decoder_seq_length: int = None,
+ forward_only: bool = False,
+ collect_non_loss_data: bool = False,
+ first_val_step: bool = None,
+):
+ """
+ Run non-interleaved 1F1B schedule, with communication between pipeline stages.
+ Returns dictionary with losses if the last stage, empty dict otherwise.
+ """
+ model_config = get_dist_model_config()
+ if hasattr(model_config, 'forward_only'):
+ forward_only = model_config.forward_only
+ if isinstance(model, list):
+ if len(model) != 1:
+ raise ValueError(
+ "non-interleaved pipeline parallelism does not support model chunking"
+ )
+ model = model[0]
+ if isinstance(data_iterator, list):
+ if len(data_iterator) != 1:
+ raise ValueError(
+ "non-pipeline-parallel schedule does not support model chunking"
+ )
+ data_iterator = data_iterator[0]
+
+ config = get_model_config(model)
+ config.deallocate_pipeline_outputs = False
+ if config.overlap_p2p_comm:
+ raise ValueError(
+ "Non-interleaved pipeline parallelism does not support overlapping p2p communication"
+ )
+
+ # Needed only when gradients are finalized in M-Core
+ if config.finalize_model_grads_func is not None and not forward_only:
+ embedding_module = schedules.clear_embedding_activation_buffer(config, model)
+
+ if config.timers is not None:
+ config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
+
+ # Disable async grad reductions
+ no_sync_func = config.no_sync_func
+ if no_sync_func is None:
+ no_sync_func = contextlib.nullcontext
+ no_sync_context = None
+
+ def disable_grad_sync():
+ """Disable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is None:
+ no_sync_context = no_sync_func()
+ no_sync_context.__enter__()
+
+ def enable_grad_sync():
+ """Enable asynchronous grad reductions"""
+ nonlocal no_sync_context
+ if no_sync_context is not None:
+ no_sync_context.__exit__(None, None, None)
+ no_sync_context = None
+
+ disable_grad_sync()
+
+ # Compute number of warmup microbatches.
+ rank = get_pipeline_model_parallel_rank()
+ model_config = get_dist_model_config(rank=torch.distributed.get_rank())
+ num_warmup_microbatches = 0
+ for index in range(model_config.model_index, get_all_config_size()):
+ num_warmup_microbatches += get_dist_model_config(global_index=index).pipeline_model_parallel_size
+ num_warmup_microbatches = num_warmup_microbatches - rank - 1
+ num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
+ num_microbatches_remaining = num_microbatches - num_warmup_microbatches
+
+ max_outstanding_backprops = None
+ if config.num_microbatches_with_partial_activation_checkpoints is not None:
+ max_outstanding_backprops = num_warmup_microbatches + 1
+
+ model_type = get_model_type(model)
+
+ get_shape_func = schedules.get_tensor_shapes if not is_forward_only_model() else get_tensor_shapes
+
+ recv_tensor_shapes = get_shape_func(
+ rank=rank - 1,
+ model_type=model_type,
+ seq_length=seq_length,
+ micro_batch_size=micro_batch_size,
+ decoder_seq_length=decoder_seq_length,
+ config=config,
+ )
+ send_tensor_shapes = get_shape_func(
+ rank=rank,
+ model_type=model_type,
+ seq_length=seq_length,
+ micro_batch_size=micro_batch_size,
+ decoder_seq_length=decoder_seq_length,
+ config=config,
+ )
+
+ send_recv_ops = generate_send_recv_mask(torch.distributed.get_rank())
+
+ # Input, output tensors only need to be saved when doing backward passes
+ input_tensors = None
+ output_tensors = None
+ total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
+
+ if not forward_only:
+ input_tensors = []
+ output_tensors = []
+ forward_data_store = []
+
+ # Run warmup forward passes.
+ for i in range(num_warmup_microbatches):
+ # Decide to checkpoint all layers' activations of the current micro-batch
+ if max_outstanding_backprops is not None:
+ checkpoint_activations_microbatch = (
+ i % max_outstanding_backprops
+ >= config.num_microbatches_with_partial_activation_checkpoints
+ )
+ else:
+ checkpoint_activations_microbatch = None
+
+ input_tensor = recv_forward(recv_tensor_shapes, config, send_recv_ops)
+ output_tensor, num_tokens = schedules.forward_step(
+ forward_step_func,
+ data_iterator,
+ model,
+ num_microbatches,
+ input_tensor,
+ forward_data_store,
+ config,
+ collect_non_loss_data,
+ checkpoint_activations_microbatch,
+ schedules.check_first_val_step(first_val_step, forward_only, i == 0),
+ current_microbatch=i,
+ )
+ send_forward(output_tensor, send_tensor_shapes, config, send_recv_ops)
+ total_num_tokens += num_tokens.item()
+
+ if not forward_only:
+ input_tensors.append(input_tensor)
+ output_tensors.append(output_tensor)
+ schedules.deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
+
+ # Before running 1F1B, need to receive first forward tensor.
+ # If all microbatches are run in warmup / cooldown phase, then no need to
+ # receive this tensor here.
+ if num_microbatches_remaining > 0:
+ input_tensor = recv_forward(recv_tensor_shapes, config, send_recv_ops)
+
+ # Run 1F1B in steady state.
+ for i in range(num_microbatches_remaining):
+ last_iteration = i == (num_microbatches_remaining - 1)
+
+ # Decide to checkpoint all layers' activations of the current micro-batch
+ if max_outstanding_backprops is not None:
+ checkpoint_activations_microbatch = (
+ (i + num_warmup_microbatches) % max_outstanding_backprops
+ ) >= config.num_microbatches_with_partial_activation_checkpoints
+ else:
+ checkpoint_activations_microbatch = None
+
+ output_tensor, num_tokens = schedules.forward_step(
+ forward_step_func,
+ data_iterator,
+ model,
+ num_microbatches,
+ input_tensor,
+ forward_data_store,
+ config,
+ collect_non_loss_data,
+ checkpoint_activations_microbatch,
+ schedules.check_first_val_step(
+ first_val_step, forward_only, (i == 0) and (num_warmup_microbatches == 0)
+ ),
+ current_microbatch=i + num_warmup_microbatches,
+ )
+ total_num_tokens += num_tokens.item()
+
+ if forward_only:
+ send_forward(output_tensor, send_tensor_shapes, config, send_recv_ops)
+
+ if not last_iteration:
+ input_tensor = recv_forward(recv_tensor_shapes, config, send_recv_ops)
+
+ else:
+ output_tensor_grad = send_forward_recv_backward(
+ output_tensor, send_tensor_shapes, config, send_recv_ops
+ )
+
+ # Add input_tensor and output_tensor to end of list.
+ input_tensors.append(input_tensor)
+ output_tensors.append(output_tensor)
+ schedules.deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
+
+ # Pop input_tensor and output_tensor from the start of the list for
+ # the backward pass.
+ input_tensor = input_tensors.pop(0)
+ output_tensor = output_tensors.pop(0)
+
+ # Enable grad sync for the last microbatch in the batch if the full
+ # backward pass completes in the 1F1B stage.
+ if num_warmup_microbatches == 0 and last_iteration:
+ if config.grad_sync_func is None or rank == 0:
+ enable_grad_sync()
+
+ input_tensor_grad = _backward_step(
+ input_tensor, output_tensor, output_tensor_grad, model_type, config
+ )
+
+ if last_iteration:
+ input_tensor = None
+ send_backward(input_tensor_grad, recv_tensor_shapes, config, send_recv_ops)
+ else:
+ input_tensor = send_backward_recv_forward(
+ input_tensor_grad, recv_tensor_shapes, config, send_recv_ops
+ )
+
+ # Run cooldown backward passes.
+ if not forward_only:
+ for i in range(num_warmup_microbatches):
+
+ # Enable async grad reduction in the last backward pass
+ # Note: If grad sync function is provided, only enable
+ # async grad reduction in first pipeline stage. Other
+ # pipeline stages do grad reduction during pipeline
+ # bubble.
+ if i == num_warmup_microbatches - 1:
+ if config.grad_sync_func is None or rank == 0:
+ enable_grad_sync()
+
+ input_tensor = input_tensors.pop(0)
+ output_tensor = output_tensors.pop(0)
+
+ output_tensor_grad = recv_backward(send_tensor_shapes, config, send_recv_ops)
+
+ input_tensor_grad = _backward_step(
+ input_tensor, output_tensor, output_tensor_grad, model_type, config
+ )
+
+ send_backward(input_tensor_grad, recv_tensor_shapes, config, send_recv_ops)
+
+ # Launch any remaining grad reductions.
+ if no_sync_context is not None:
+ enable_grad_sync()
+ if config.grad_sync_func is not None:
+ config.grad_sync_func(model.parameters())
+
+ if config.finalize_model_grads_func is not None and not forward_only:
+
+ # If defer_embedding_wgrad_compute is enabled we need to do the
+ # weight gradient GEMM's here.
+ schedules.finish_embedding_wgrad_compute(config, embedding_module)
+
+ # Finalize model grads (perform full grad all-reduce / reduce-scatter for
+ # data parallelism, layernorm all-reduce for sequence parallelism, and
+ # embedding all-reduce for pipeline parallelism).
+ config.finalize_model_grads_func(
+ [model], total_num_tokens if config.calculate_per_token_loss else None
+ )
+
+ if config.timers is not None:
+ config.timers('forward-backward').stop()
+
+ return forward_data_store
+
+
+def _backward_step(*args, **kwargs):
+ if is_use_multiparam_send_recv():
+ from mindspeed.core.pipeline_parallel.multiparameter_schedules import backward_step
+ return backward_step(*args, **kwargs)
+
+ return schedules.backward_step(*args, **kwargs)
+
+
+def get_send_recv_fun():
+ if is_use_multiparam_send_recv():
+ return send_recv_tensor_list
+ else:
+ return send_recv
+
+
+def post_process_for_recving(recv_tensors: List):
+ if is_use_multiparam_send_recv():
+ return [tensors[0] for tensors in recv_tensors]
+ else:
+ return [recv_tensors[0]]
+
+
+def send_forward(output_tensors, tensor_shapes, config, send_recv_ops):
+ if send_recv_ops['send_forward']:
+ send_recv_func = get_send_recv_fun()
+ send_recv_func(output_tensors, False, get_dst_ranks())
+ else:
+ schedules.send_forward(output_tensors, tensor_shapes, config)
+
+
+def recv_forward(tensor_shapes, config, send_recv_ops):
+ if send_recv_ops['recv_forward']:
+ send_recv_func = get_send_recv_fun()
+ recv_tensors = send_recv_func(None, True, get_dst_ranks())
+ input_tensor = post_process_for_recving(recv_tensors)
+ else:
+ input_tensor = schedules.recv_forward(tensor_shapes, config)
+ return input_tensor
+
+
+def send_backward(input_tensor_grad, tensor_shapes, config, send_recv_ops):
+ if send_recv_ops['send_backward']:
+ send_recv_func = get_send_recv_fun()
+ send_recv_func(input_tensor_grad, False, get_dst_ranks())
+ else:
+ schedules.send_backward(input_tensor_grad, tensor_shapes, config)
+
+
+def send_forward_recv_backward(output_tensor, tensor_shape, config, send_recv_ops):
+ if send_recv_ops['send_forward'] and send_recv_ops['recv_backward']:
+ send_recv_func = get_send_recv_fun()
+ recv_tensors = send_recv_func(output_tensor, True, get_dst_ranks())
+ output_tensor_grad = post_process_for_recving(recv_tensors)
+ else:
+ output_tensor_grad = schedules.send_forward_recv_backward(
+ output_tensor, tensor_shape, config
+ )
+ return output_tensor_grad
+
+
+def send_backward_recv_forward(input_tensor_grad, tensor_shapes, config, send_recv_ops):
+ if send_recv_ops['recv_forward']:
+ if not send_recv_ops['send_backward']: # forward_only model adaption
+ input_tensor_grad = [[]]
+ send_recv_func = get_send_recv_fun()
+ recv_tensors = send_recv_func(input_tensor_grad, True, get_dst_ranks())
+ input_tensor = post_process_for_recving(recv_tensors)
+ else:
+ input_tensor = schedules.send_backward_recv_forward(
+ input_tensor_grad, tensor_shapes, config
+ )
+ return input_tensor
+
+
+def recv_backward(tensor_shapes, config, send_recv_ops):
+ if send_recv_ops['recv_backward']:
+ send_recv_func = get_send_recv_fun()
+ recv_tensors = send_recv_func(None, True, get_dst_ranks())
+ output_tensor_grad = post_process_for_recving(recv_tensors)
+ else:
+ output_tensor_grad = schedules.recv_backward(tensor_shapes, config)
+ return output_tensor_grad
diff --git a/model/train/yoco_moe/mindspeed/multi_modal/dist_train/training.py b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..eeb4d50820ad49d1a4876915ec6ac833f4593282
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/multi_modal/dist_train/training.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+from functools import wraps
+from megatron.training import get_args
+from .config.dist_train_config import get_all_config
+
+
+def initialize_distributed_wrapper(_initialize_distributed):
+ @wraps(_initialize_distributed)
+ def wrapper():
+ args = get_args()
+ temp_world_size = args.world_size
+ true_world_size = 0
+ for cfg in get_all_config().values():
+ true_world_size += cfg.world_size
+ args.world_size = true_world_size
+
+ _initialize_distributed()
+
+ args.world_size = temp_world_size
+
+ return wrapper
diff --git a/model/train/yoco_moe/mindspeed/op_builder/__init__.py b/model/train/yoco_moe/mindspeed/op_builder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..34803beda9d12e9289c8d6f6d91c6f3f438bb54c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/__init__.py
@@ -0,0 +1,27 @@
+from .fusion_attention_v2_builder import FusionAttentionV2OpBuilder
+from .algorithm_builder import AlgorithmOpBuilder
+from .memory_fragmentation_builder import MemoryFragmentationBuilder
+from .adaptive_recomputing_builder import AdaptiveRecomputingPluggableAllocatorBuilder
+from .npu_dropout_add_layer_norm_builder import NpuDropoutAddLayerNormOpBuilder
+from .atb_builder import AtbOpBuilder
+from .swiglu_builder import SwigluOpBuilder
+from .lcal_builder import LcalOpBuilder
+from .rms_norm_builder import RmsNormOpBuilder
+from .npu_grouped_mat_mul_all_reduce_builder import GroupedMatMulAllReduceOpBuilder
+from .gmm_builder import GMMOpBuilder, GMMV2OpBuilder
+from .quant_gmm_builder import QuantGMMOpBuilder
+from .weight_quant_gmm_builder import WeightQuantGMMOpBuilder
+from .ffn_builder import FFNOpBuilder
+from .npu_mm_all_reduce_add_rms_norm_builder import MatmulAllReduceAddRmsNormOpBuilder
+from .npu_inplace_mm_all_reduce_add_rms_norm_builder import InplaceMatmulAllReduceAddRmsNormOpBuilder
+from .npu_rotary_position_embedding_builder import RotaryPositionEmbeddingOpBuilder
+from .npu_moe_token_permute_builder import MoeTokenPermuteOpBuilder
+from .npu_moe_token_unpermute_builder import MoeTokenUnpermuteOpBuilder
+from .npu_ring_attention_update_builder import RingAttentionUpdateOpBuilder
+from .npu_bmm_reduce_scatter_all_to_all_builder import BatchMatMulReduceScatterAlltoAllOpBuilder
+from .npu_all_to_all_all_gather_bmm_builder import AllToAllAllGatherBatchMatMulOpBuilder
+from .adaptive_cp_builder import AdaptiveCpOpBuilder
+from .matmul_add_builder import MatmulAddOpBuilder
+from .groupmatmul_add_builder import GroupMatmulAddOpBuilder
+from .fused_ema_adamw_builder import FusedEmaAdamWOpBuilder
+from .smart_swap_builder import SmartSwapBuilder
diff --git a/model/train/yoco_moe/mindspeed/op_builder/adaptive_cp_builder.py b/model/train/yoco_moe/mindspeed/op_builder/adaptive_cp_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..335d34ad71ea04cee05bade8ae86a32bfdafb7f2
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/adaptive_cp_builder.py
@@ -0,0 +1,11 @@
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class AdaptiveCpOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "adaptive_cp"
+
+ def __init__(self):
+ super(AdaptiveCpOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/algorithm/adaptive_cp/adaptive_cp.cpp']
diff --git a/model/train/yoco_moe/mindspeed/op_builder/adaptive_recomputing_builder.py b/model/train/yoco_moe/mindspeed/op_builder/adaptive_recomputing_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f736af849d3e3aa5f7fe0a7d9cc67a4bb82fb9f
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/adaptive_recomputing_builder.py
@@ -0,0 +1,37 @@
+import os
+
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class AdaptiveRecomputingPluggableAllocatorBuilder(MindSpeedOpBuilder):
+ NAME = "AdaptiveRecomputing"
+ _torch_path = None
+
+ def __init__(self):
+ import torch
+ self._torch_path = os.path.dirname(os.path.abspath(torch.__file__))
+ super(AdaptiveRecomputingPluggableAllocatorBuilder, self).__init__(self.NAME)
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += [
+ os.path.join(self._torch_path, 'include'),
+ os.path.join(self._torch_path, 'include/torch/csrc/api/include'),
+ os.path.join(self._torch_npu_path, 'include/third_party/acl/inc/acl/'),
+ ]
+ return paths
+
+ def sources(self):
+ return ['ops/csrc/pluggable_allocator/adaptive_recomputing/NpuCachingCustomAllocator.cpp']
+
+ def cxx_args(self):
+ args = ['-fstack-protector-all', '-Wl,-z,relro,-z,now,-z,noexecstack', '-fPIC', '-pie',
+ '-s', '-D_FORTIFY_SOURCE=2', '-O2',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"]
+ return args
+
+ def extra_ldflags(self):
+ flags = [
+ '-L' + os.path.join(self._torch_npu_path, 'lib'), '-ltorch_npu'
+ ]
+ return flags
diff --git a/model/train/yoco_moe/mindspeed/op_builder/algorithm_builder.py b/model/train/yoco_moe/mindspeed/op_builder/algorithm_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e6a54374dd1ec9e309f9465678039c67228af0e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/algorithm_builder.py
@@ -0,0 +1,11 @@
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class AlgorithmOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "algorithm"
+
+ def __init__(self):
+ super(AlgorithmOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/algorithm/algorithm.cpp']
diff --git a/model/train/yoco_moe/mindspeed/op_builder/atb_builder.py b/model/train/yoco_moe/mindspeed/op_builder/atb_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f0750bfc2f8264cc92e084c1185325aacd0190e
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/atb_builder.py
@@ -0,0 +1,49 @@
+import os
+import site
+from torch.utils.cpp_extension import load
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class AtbOpBuilder(MindSpeedOpBuilder):
+ def __init__(self, name):
+ super(AtbOpBuilder, self).__init__(name)
+
+ def get_atb_path(self):
+ if os.getenv('ATB_HOME_PATH') is not None:
+ return os.environ['ATB_HOME_PATH'] # 'ATB_HOME_PATH' refers to home path of Ascend Transformer Boost lib, which is used for ops loading.
+ else:
+ raise RuntimeError("please install cann-nnal package first")
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths.append(os.path.join(self._torch_npu_path, 'include', 'third_party', 'acl', 'inc'))
+ paths.append('ops/csrc/atb/extensions/inc')
+ paths.append(os.path.join(self.get_atb_path(), 'include'))
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'",
+ '-D ENABLE_ATB',
+ '-fstack-protector-all',
+ '-Wl,-z,relro,-z,now,-z,noexecstack',
+ '-fPIC', '-pie', '-s'
+ ]
+ return args
+
+ def extra_ldflags(self):
+ flags = ['-L' + os.path.join(self._torch_npu_path, 'lib'), '-ltorch_npu',
+ '-L' + os.path.join(self.get_atb_path(), 'lib'), '-lasdops',
+ '-L' + os.path.join(self.get_atb_path(), 'lib'), '-llcal',
+ '-L' + os.path.join(self.get_atb_path(), 'lib'), '-latb']
+ return flags
+
+ def load(self, verbose=True):
+ if os.getenv('ASDOPS_HOME_PATH') is None or os.getenv('ASDOPS_OPS_PATH') is None:
+ os.environ['ASDOPS_HOME_PATH'] = self.get_atb_path() # 'ASDOPS_HOME_PATH' refers to home path of Ascend ops, which is used for ops loading.
+ os.environ['ASDOPS_OPS_PATH'] = os.path.join(self.get_atb_path(), 'ops') # 'ASDOPS_OPS_PATH' refers to ops path of Ascend ops, which is used for ops loading.
+ return super().load()
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/op_builder/builder.py b/model/train/yoco_moe/mindspeed/op_builder/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cacac2a355eef229fe8e24f14ddbe91753ac77d
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/builder.py
@@ -0,0 +1,77 @@
+import re
+import os
+from abc import ABC, abstractmethod
+from typing import List, Union
+from torch.utils.cpp_extension import load
+from torch.library import Library
+import torch_npu
+import mindspeed
+
+ASCEND_HOME_PATH = "ASCEND_HOME_PATH"
+AS_LIBRARY = Library("mindspeed", "DEF")
+
+
+class MindSpeedOpBuilder(ABC):
+ _cann_path = None
+ _torch_npu_path = None
+ _cann_version = None
+ _loaded_ops = {}
+
+ def __init__(self, name):
+ self.name = name
+ self._cann_path = self.get_cann_path()
+ self._torch_npu_path = os.path.dirname(os.path.abspath(torch_npu.__file__))
+
+ def get_cann_path(self):
+ if ASCEND_HOME_PATH in os.environ and os.path.exists(os.environ[ASCEND_HOME_PATH]):
+ return os.environ[ASCEND_HOME_PATH]
+ return None
+
+ def get_absolute_paths(self, paths):
+ mindspeed_path = os.path.abspath(os.path.dirname(mindspeed.__file__))
+ return [os.path.join(mindspeed_path, path) for path in paths]
+
+ def register_op_proto(self, op_proto: Union[str, List[str]]):
+ if isinstance(op_proto, str):
+ op_proto = [op_proto]
+ for proto in op_proto:
+ AS_LIBRARY.define(proto)
+
+ @abstractmethod
+ def sources(self):
+ ...
+
+ def include_paths(self):
+ paths = [
+ os.path.join(self._torch_npu_path, 'include'),
+ os.path.join(self._torch_npu_path, 'include/third_party/hccl/inc'),
+ os.path.join(self._torch_npu_path, 'include/third_party/acl/inc'),
+ os.path.join(self._cann_path, 'include'),
+ ]
+ return paths
+
+ def cxx_args(self):
+ args = ['-fstack-protector-all', '-Wl,-z,relro,-z,now,-z,noexecstack', '-fPIC', '-pie',
+ '-s', '-fvisibility=hidden', '-D_FORTIFY_SOURCE=2', '-O2']
+ return args
+
+ def extra_ldflags(self):
+ flags = [
+ '-L' + os.path.join(self._cann_path, 'lib64'), '-lascendcl',
+ '-L' + os.path.join(self._torch_npu_path, 'lib'), '-ltorch_npu'
+ ]
+ return flags
+
+ def load(self, verbose=True):
+ if self.name in __class__._loaded_ops:
+ return __class__._loaded_ops[self.name]
+
+ op_module = load(name=self.name,
+ sources=self.get_absolute_paths(self.sources()),
+ extra_include_paths=self.get_absolute_paths(self.include_paths()),
+ extra_cflags=self.cxx_args(),
+ extra_ldflags=self.extra_ldflags(),
+ verbose=verbose)
+ __class__._loaded_ops[self.name] = op_module
+
+ return op_module
diff --git a/model/train/yoco_moe/mindspeed/op_builder/ffn_builder.py b/model/train/yoco_moe/mindspeed/op_builder/ffn_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..27630818390b816ba6af288c07d93c37bad7e0bc
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/ffn_builder.py
@@ -0,0 +1,373 @@
+# Copyright (c) 2024, Huawei Technologies.
+
+
+from typing import (
+ Any,
+ Callable,
+ ContextManager,
+ Iterator,
+ List,
+ Literal,
+ NamedTuple,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+ overload,
+)
+
+from mindspeed.op_builder.builder import MindSpeedOpBuilder, AS_LIBRARY
+
+import torch
+import torchair
+from torch.library import Library, impl
+
+torch_npu_api_version = None
+try:
+ from torchair import ge
+ from torchair.ge import Tensor, TensorSpec, DataType
+ from torchair import register_fx_node_ge_converter
+except ImportError:
+ torch_npu_api_version = 1
+ from torchair.ge_concrete_graph import ge_apis as ge
+ from torchair.ge_concrete_graph.fx2ge_converter import register_fx_node_ge_converter
+ from torchair.ge_concrete_graph.ge_graph import Tensor, TensorSpec, DataType
+ from torchair.ge_concrete_graph.ge_graph import get_default_ge_graph, next_unique_name
+ from torchair.ge_concrete_graph.ge_graph import compat_as_bytes
+ from torchair.ge_concrete_graph.ge_graph import get_invalid_desc
+else:
+ torch_npu_api_version = 2
+
+
+class FFNOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "npu_ffn"
+ OP_PROTO = "npu_ffn(Tensor x, Tensor weight1, Tensor weight2, str activation, *, Tensor? expert_tokens=None, \
+ Tensor? expert_tokens_index=None, Tensor? bias1=None, Tensor? bias2=None, Tensor? scale=None, \
+ Tensor? offset=None, Tensor? deq_scale1=None, Tensor? deq_scale2=None, Tensor? antiquant_scale1=None, \
+ Tensor? antiquant_scale2=None, Tensor? antiquant_offset1=None, Tensor? antiquant_offset2=None, \
+ int? inner_precise=None, ScalarType? output_dtype=None) -> Tensor"
+
+ def __init__(self):
+ super(FFNOpBuilder, self).__init__(self.OP_NAME)
+ self.register_op_proto(self.OP_PROTO)
+ self.register_op_ir()
+
+ def sources(self):
+ return ['ops/csrc/cann/ffn.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc']
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
+ ]
+ return args
+
+ def register_op_ir(self):
+ @impl(AS_LIBRARY, "npu_ffn", "Meta")
+ def npu_ffn_forward(x, weight1, weight2, activation, *, expert_tokens=None, expert_tokens_index=None,
+ bias1=None, bias2=None, scale=None, offset=None, deq_scale1=None, deq_scale2=None,
+ antiquant_scale1=None, antiquant_scale2=None, antiquant_offset1=None,
+ antiquant_offset2=None, inner_precise=0, output_dtype=None):
+ dim_list = []
+ for i in range(0, x.dim() - 1):
+ dim_list.append(x.size(i))
+ dim_list.append(weight2.size(weight2.dim() - 1))
+ if x.dtype == torch.int8:
+ if output_dtype is not None and output_dtype == torch.bfloat16:
+ return x.new_empty(tuple(dim_list), dtype=torch.bfloat16)
+ else:
+ return x.new_empty(tuple(dim_list), dtype=torch.float16)
+ else:
+ return x.new_empty(tuple(dim_list))
+
+ @register_fx_node_ge_converter(torch.ops.mindspeed.npu_ffn.default)
+ def convert_npu_ffn(
+ x: Tensor,
+ weight1: Tensor,
+ weight2: Tensor,
+ activation: str,
+ *,
+ expert_tokens: Optional[Tensor] = None,
+ expert_tokens_index: Optional[Tensor] = None,
+ bias1: Optional[Tensor] = None,
+ bias2: Optional[Tensor] = None,
+ scale: Optional[Tensor] = None,
+ offset: Optional[Tensor] = None,
+ deq_scale1: Optional[Tensor] = None,
+ deq_scale2: Optional[Tensor] = None,
+ antiquant_scale1: Optional[Tensor] = None,
+ antiquant_scale2: Optional[Tensor] = None,
+ antiquant_offset1: Optional[Tensor] = None,
+ antiquant_offset2: Optional[Tensor] = None,
+ inner_precise: Optional[int] = 0,
+ output_dtype: Optional[int] = None,
+ meta_outputs: TensorSpec = None
+ ):
+ '''"npu::npu_ffn(Tensor x, Tensor weight1, Tensor weight2, str activation, *, Tensor? expert_tokens=None,
+ Tensor? expert_tokens_index=None, Tensor? bias1=None, Tensor? bias2=None, Tensor? scale=None,
+ Tensor? offset=None, Tensor? deq_scale1=None, Tensor? deq_scale2=None,
+ Tensor? antiquant_scale1=None, Tensor? antiquant_scale2=None, Tensor? antiquant_offset1=None,
+ Tensor? antiquant_offset2=None, int? inner_precise=None, ScalarType? output_dtype=None)
+ -> Tensor
+ "'''
+ tokens_index_flag = False
+ if expert_tokens is not None and expert_tokens_index is not None:
+ raise ValueError("Cannot assign the value to expert_tokens and expert_tokens_index simultaneously!")
+ elif expert_tokens_index is not None:
+ tokens_index_flag = True
+ expert_tokens = expert_tokens_index
+
+ y_dtype = -1
+ if x.dtype == DataType.DT_INT8 and output_dtype is not None:
+ if output_dtype == torch.float16:
+ y_dtype = 0
+ elif output_dtype == torch.bfloat16:
+ y_dtype = 1
+ else:
+ raise NotImplementedError("In the quant scenario, output_dtype should be float16 or bfloat16,"
+ "otherwise it should be None!")
+
+ return FFN(x, weight1, weight2, expert_tokens=expert_tokens, bias1=bias1, bias2=bias2, scale=scale,
+ offset=offset, deq_scale1=deq_scale1, deq_scale2=deq_scale2, antiquant_scale1=antiquant_scale1,
+ antiquant_scale2=antiquant_scale2, antiquant_offset1=antiquant_offset1,
+ antiquant_offset2=antiquant_offset2, activation=activation, inner_precise=inner_precise,
+ output_dtype=y_dtype, tokens_index_flag=tokens_index_flag)
+
+
+FFN = None
+if torch_npu_api_version == 2:
+ def FFNV2(x: Tensor,
+ weight1: Tensor,
+ weight2: Tensor,
+ expert_tokens: Optional[Tensor],
+ bias1: Optional[Tensor],
+ bias2: Optional[Tensor],
+ scale: Optional[Tensor],
+ offset: Optional[Tensor],
+ deq_scale1: Optional[Tensor],
+ deq_scale2: Optional[Tensor],
+ antiquant_scale1: Optional[Tensor],
+ antiquant_scale2: Optional[Tensor],
+ antiquant_offset1: Optional[Tensor],
+ antiquant_offset2: Optional[Tensor],
+ *,
+ activation: str,
+ inner_precise: int = 0,
+ output_dtype: int = -1,
+ tokens_index_flag: bool = False):
+ """REG_OP(FFN)\n
+ .INPUT(x, TensorType({DT_INT8, DT_FLOAT16, DT_BF16}))\n
+ .INPUT(weight1, TensorType({DT_INT8, DT_FLOAT16, DT_BF16, DT_INT4}))\n
+ .INPUT(weight2, TensorType({DT_INT8, DT_FLOAT16, DT_BF16, DT_INT4}))\n
+ .OPTIONAL_INPUT(expert_tokens, TensorType({DT_INT64}))\n
+ .OPTIONAL_INPUT(bias1, TensorType({DT_INT32, DT_FLOAT16, DT_FLOAT}))\n
+ .OPTIONAL_INPUT(bias2, TensorType({DT_INT32, DT_FLOAT16, DT_FLOAT}))\n
+ .OPTIONAL_INPUT(scale, TensorType({DT_FLOAT}))\n
+ .OPTIONAL_INPUT(offset, TensorType({DT_FLOAT}))\n
+ .OPTIONAL_INPUT(deq_scale1, TensorType({DT_UINT64, DT_BF16}))\n
+ .OPTIONAL_INPUT(deq_scale2, TensorType({DT_UINT64, DT_BF16}))\n
+ .OPTIONAL_INPUT(antiquant_scale1, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .OPTIONAL_INPUT(antiquant_scale2, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .OPTIONAL_INPUT(antiquant_offset1, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .OPTIONAL_INPUT(antiquant_offset2, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .OUTPUT(y, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .REQUIRED_ATTR(activation, String)\n
+ .ATTR(inner_precise, Int, 0)\n
+ .ATTR(output_dtype, Int, -1)\n
+ .ATTR(tokens_index_flag, Bool, false)\n
+ """
+
+ y = torchair.ge.custom_op("FFN",
+ inputs={
+ "x": x,
+ "weight1": weight1,
+ "weight2": weight2,
+ "expert_tokens": expert_tokens,
+ "bias1": bias1,
+ "bias2": bias2,
+ "scale": scale,
+ "offset": offset,
+ "deq_scale1": deq_scale1,
+ "deq_scale2": deq_scale2,
+ "antiquant_scale1": antiquant_scale1,
+ "antiquant_scale2": antiquant_scale2,
+ "antiquant_offset1": antiquant_offset1,
+ "antiquant_offset2": antiquant_offset2
+ },
+ attrs={
+ "activation": ge.attr.Str(activation),
+ "inner_precise": ge.attr.Int(inner_precise),
+ "output_dtype": ge.attr.Int(output_dtype),
+ "tokens_index_flag": ge.attr.Bool(tokens_index_flag)
+ },
+ outputs=[
+ "y"
+ ])
+
+ return y
+ FFN = FFNV2
+elif torch_npu_api_version == 1:
+ def FFNV1(x: Tensor,
+ weight1: Tensor,
+ weight2: Tensor,
+ expert_tokens: Optional[Tensor],
+ bias1: Optional[Tensor],
+ bias2: Optional[Tensor],
+ scale: Optional[Tensor],
+ offset: Optional[Tensor],
+ deq_scale1: Optional[Tensor],
+ deq_scale2: Optional[Tensor],
+ antiquant_scale1: Optional[Tensor],
+ antiquant_scale2: Optional[Tensor],
+ antiquant_offset1: Optional[Tensor],
+ antiquant_offset2: Optional[Tensor],
+ *,
+ activation: str,
+ inner_precise: int = 0,
+ output_dtype: int = -1,
+ tokens_index_flag: bool = False,
+ dependencies=[],
+ node_name=None):
+ """REG_OP(FFN)\n
+ .INPUT(x, TensorType({DT_INT8, DT_FLOAT16, DT_BF16}))\n
+ .INPUT(weight1, TensorType({DT_INT8, DT_FLOAT16, DT_BF16, DT_INT4}))\n
+ .INPUT(weight2, TensorType({DT_INT8, DT_FLOAT16, DT_BF16, DT_INT4}))\n
+ .OPTIONAL_INPUT(expert_tokens, TensorType({DT_INT64}))\n
+ .OPTIONAL_INPUT(bias1, TensorType({DT_INT32, DT_FLOAT16, DT_FLOAT}))\n
+ .OPTIONAL_INPUT(bias2, TensorType({DT_INT32, DT_FLOAT16, DT_FLOAT}))\n
+ .OPTIONAL_INPUT(scale, TensorType({DT_FLOAT}))\n
+ .OPTIONAL_INPUT(offset, TensorType({DT_FLOAT}))\n
+ .OPTIONAL_INPUT(deq_scale1, TensorType({DT_UINT64, DT_BF16}))\n
+ .OPTIONAL_INPUT(deq_scale2, TensorType({DT_UINT64, DT_BF16}))\n
+ .OPTIONAL_INPUT(antiquant_scale1, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .OPTIONAL_INPUT(antiquant_scale2, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .OPTIONAL_INPUT(antiquant_offset1, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .OPTIONAL_INPUT(antiquant_offset2, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .OUTPUT(y, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .REQUIRED_ATTR(activation, String)\n
+ .ATTR(inner_precise, Int, 0)\n
+ .ATTR(output_dtype, Int, -1)\n
+ .ATTR(tokens_index_flag, Bool, false)\n
+ """
+
+ op = get_default_ge_graph().op.add()
+ op.type = "FFN"
+ op.name = next_unique_name(node_name, "FFN")
+
+ # process dependices
+ for dependency in dependencies:
+ op.input.append(dependency.controller)
+
+ # process inputs
+ op.input.append(x.tensor)
+ op.input_desc.add().CopyFrom(x.desc)
+ op.input_desc[-1].name = "x"
+ op.input.append(weight1.tensor)
+ op.input_desc.add().CopyFrom(weight1.desc)
+ op.input_desc[-1].name = "weight1"
+ op.input.append(weight2.tensor)
+ op.input_desc.add().CopyFrom(weight2.desc)
+ op.input_desc[-1].name = "weight2"
+ if expert_tokens is not None:
+ op.input.append(expert_tokens.tensor)
+ op.input_desc.add().CopyFrom(expert_tokens.desc)
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "expert_tokens"
+ if bias1 is not None:
+ op.input.append(bias1.tensor)
+ op.input_desc.add().CopyFrom(bias1.desc)
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "bias1"
+ if bias2 is not None:
+ op.input.append(bias2.tensor)
+ op.input_desc.add().CopyFrom(bias2.desc)
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "bias2"
+ if scale is not None:
+ op.input.append(scale.tensor)
+ op.input_desc.add().CopyFrom(scale.desc)
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "scale"
+ if offset is not None:
+ op.input.append(offset.tensor)
+ op.input_desc.add().CopyFrom(offset.desc)
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "offset"
+ if deq_scale1 is not None:
+ op.input.append(deq_scale1.tensor)
+ op.input_desc.add().CopyFrom(deq_scale1.desc)
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "deq_scale1"
+ if deq_scale2 is not None:
+ op.input.append(deq_scale2.tensor)
+ op.input_desc.add().CopyFrom(deq_scale2.desc)
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "deq_scale2"
+ if antiquant_scale1 is not None:
+ op.input.append(antiquant_scale1.tensor)
+ op.input_desc.add().CopyFrom(antiquant_scale1.desc)
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "antiquant_scale1"
+ if antiquant_scale2 is not None:
+ op.input.append(antiquant_scale2.tensor)
+ op.input_desc.add().CopyFrom(antiquant_scale2.desc)
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "antiquant_scale2"
+ if antiquant_offset1 is not None:
+ op.input.append(antiquant_offset1.tensor)
+ op.input_desc.add().CopyFrom(antiquant_offset1.desc)
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "antiquant_offset1"
+ if antiquant_offset2 is not None:
+ op.input.append(antiquant_offset2.tensor)
+ op.input_desc.add().CopyFrom(antiquant_offset2.desc)
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "antiquant_offset2"
+
+ # process attrs
+ op.attr["activation"].s = compat_as_bytes(activation)
+ op.attr["inner_precise"].i = inner_precise
+ op.attr["output_dtype"].i = output_dtype
+ op.attr["tokens_index_flag"].b = tokens_index_flag
+
+ # process outputs
+ output_index = 0
+ op.output_desc.add().name = "y"
+ y = Tensor(op, output_index)
+ output_index += 1
+
+ return y
+ FFN = FFNV1
+else:
+ raise ValueError("touch_npu_api_version unsupport!")
diff --git a/model/train/yoco_moe/mindspeed/op_builder/fused_ema_adamw_builder.py b/model/train/yoco_moe/mindspeed/op_builder/fused_ema_adamw_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..25ac3b892cb5b71d999d99ad34a9901856b7742d
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/fused_ema_adamw_builder.py
@@ -0,0 +1,31 @@
+import os
+import torch
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class FusedEmaAdamWOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "npu_apply_fused_ema_adamw"
+ _torch_path = None
+
+ def __init__(self):
+ from sysconfig import get_paths
+ self._torch_path = os.path.dirname(os.path.abspath(torch.__file__))
+ super(FusedEmaAdamWOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/cann/npu_apply_fused_ema_adamw.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc',
+ os.path.join(self._torch_path, 'include'),
+ os.path.join(self._torch_path, 'include/torch/csrc/api/include'),
+ os.path.join(self._torch_npu_path, 'include/torch_npu/csrc/framework/utils'),
+ os.path.join(self._torch_npu_path, 'include/torch_npu/csrc/aten'),
+ ]
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += ['-Wno-narrowing']
+ return args
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/op_builder/fusion_attention_v2_builder.py b/model/train/yoco_moe/mindspeed/op_builder/fusion_attention_v2_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6cf99c72a0b1f219a8dc4cce6fbde425c7dc786
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/fusion_attention_v2_builder.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2024, Huawei Technologies.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class FusionAttentionV2OpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "fusion_attention_v2"
+
+ def __init__(self):
+ super(FusionAttentionV2OpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/cann/fusion_attention_v2.cpp', 'ops/csrc/flop_counter/flop_counter.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc']
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
+ ]
+ return args
diff --git a/model/train/yoco_moe/mindspeed/op_builder/gmm_builder.py b/model/train/yoco_moe/mindspeed/op_builder/gmm_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b53b9e7d7470ad511e7cca8efd89c0a62b72cef7
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/gmm_builder.py
@@ -0,0 +1,352 @@
+from typing import List, Optional
+from collections import namedtuple
+import torch
+import torchair
+from torch.library import impl
+
+from mindspeed.op_builder.builder import MindSpeedOpBuilder, AS_LIBRARY
+
+torch_npu_api_version = None
+try:
+ from torchair import ge
+ from torchair.ge import Tensor, TensorSpec, DataType
+ from torchair import register_fx_node_ge_converter
+except ImportError:
+ torch_npu_api_version = 1
+ from torchair.ge_concrete_graph import ge_apis as ge
+ from torchair.ge_concrete_graph.fx2ge_converter import register_fx_node_ge_converter
+ from torchair.ge_concrete_graph.ge_graph import Tensor, TensorSpec, DataType
+ from torchair.ge_concrete_graph.ge_graph import get_default_ge_graph, next_unique_name
+ from torchair.ge_concrete_graph.ge_graph import compat_as_bytes
+ from torchair.ge_concrete_graph.ge_graph import get_invalid_desc
+else:
+ torch_npu_api_version = 2
+
+if torch_npu_api_version == 2:
+ def fill_empty_tensor(dtype):
+ return Fill(ge.Const(0), ge.Cast(0., dst_type=dtype))
+else:
+ def fill_empty_tensor(dtype):
+ return ge.Fill([0], ge.Cast(0., dst_type=dtype))
+
+
+gmm_param = namedtuple('gmm_param', ['bias', 'scale', 'offset', 'antiquant_scale', 'antiquant_offset'])
+
+
+def conveter_npu_gmm_param(
+ x: Tensor,
+ bias: Tensor,
+ group_type: int
+):
+ if group_type == 2:
+ raise ValueError(f"GMM: graph mode does not support group_type 2!")
+ x_dtype = x.dtype
+ if bias is None:
+ if x_dtype == DataType.DT_BF16:
+ bias = fill_empty_tensor(DataType.DT_FLOAT)
+ elif x_dtype == DataType.DT_UINT8:
+ bias = fill_empty_tensor(DataType.DT_INT32)
+ else:
+ bias = fill_empty_tensor(x_dtype)
+ scale = [fill_empty_tensor(DataType.DT_UINT64)]
+ offset = [fill_empty_tensor(DataType.DT_FLOAT)]
+ antiquant_scale = [fill_empty_tensor(DataType.DT_FLOAT16)]
+ antiquant_offset = [fill_empty_tensor(DataType.DT_FLOAT16)]
+ if x_dtype == DataType.DT_BF16:
+ antiquant_scale = [fill_empty_tensor(DataType.DT_BF16)]
+ antiquant_offset = [fill_empty_tensor(DataType.DT_BF16)]
+ return gmm_param(bias, scale, offset, antiquant_scale, antiquant_offset)
+
+
+class GMMOpBuilderPublic(MindSpeedOpBuilder):
+ TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])
+
+ def sources(self):
+ return ['ops/csrc/cann/gmm.cpp', 'ops/csrc/flop_counter/flop_counter.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc']
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
+ ]
+ if self.TORCH_MAJOR >= 2 and self.TORCH_MINOR >= 1:
+ cpp_std = " -std=c++17"
+ else:
+ cpp_std = " -std=c++14"
+ args.append(cpp_std)
+ return args
+
+
+class GMMOpBuilder(GMMOpBuilderPublic):
+ OP_NAME = "grouped_matmul"
+ OP_PROTO = (
+ "npu_gmm.Tensor(Tensor original_weight, Tensor x, Tensor weight, *, Tensor? bias=None, Tensor? group_list=None, int? group_type=0, bool? gemm_fusion=False) -> Tensor",
+ "npu_gmm.List(Tensor original_weight, Tensor x, Tensor weight, *, Tensor? bias=None, int[]? group_list=None, int? group_type=0, bool? gemm_fusion=False) -> Tensor"
+ )
+
+ def __init__(self):
+ super(GMMOpBuilder, self).__init__(self.OP_NAME)
+ self.register_op_proto(self.OP_PROTO)
+ self.register_op_ir()
+
+ def register_op_ir(self):
+ @impl(AS_LIBRARY, "npu_gmm.Tensor", "Meta")
+ def npu_gmm_forward(original_weight, x, weight, *, bias=None, group_list=None, group_type=0, gemm_fusion=False):
+ BM = x.shape[0]
+ N = weight.shape[-1]
+ y = x.new_empty((BM, N), dtype=x.dtype)
+ return y
+
+ @register_fx_node_ge_converter(torch.ops.mindspeed.npu_gmm.Tensor)
+ def conveter_npu_gmm(
+ original_weight: Tensor,
+ x: Tensor,
+ weight: Tensor,
+ *,
+ bias: Optional[Tensor] = None,
+ group_list: Optional[Tensor] = None,
+ group_type: Optional[int] = 0,
+ gemm_fusion: Optional[bool] = False,
+ meta_outputs: TensorSpec = None,
+ ):
+ """npu_gmm(Tensor x, Tensor weight, *, Tensor? bias=None, Tensor? group_list=None, int? group_type=0) -> Tensor
+ """
+ result = conveter_npu_gmm_param(x, bias, group_type)
+
+ return GroupedMatmul([x], [weight], [result.bias], result.scale, result.offset, result.antiquant_scale,
+ result.antiquant_offset, group_list, split_item=3, group_type=group_type,
+ dtype=-1, transpose_weight=False, group_list_type=0)[0]
+
+
+class GMMV2OpBuilder(GMMOpBuilderPublic):
+ OP_NAME = "grouped_matmul_v2"
+ OP_PROTO = (
+ "npu_gmm_v2.Tensor(Tensor original_weight, Tensor x, Tensor weight, *, Tensor? bias=None, Tensor? group_list=None, int? group_type=0, bool? gemm_fusion=False) -> Tensor"
+ )
+
+ def __init__(self):
+ super(GMMV2OpBuilder, self).__init__(self.OP_NAME)
+ self.register_op_proto(self.OP_PROTO)
+ self.register_op_ir()
+
+ def register_op_ir(self):
+ @impl(AS_LIBRARY, "npu_gmm_v2.Tensor", "Meta")
+ def npu_gmm_v2_forward(original_weight, x, weight, *, bias=None, group_list=None, group_type=0, gemm_fusion=False):
+ BM = x.shape[0]
+ N = weight.shape[-1]
+ y = x.new_empty((BM, N), dtype=x.dtype)
+ return y
+
+ @register_fx_node_ge_converter(torch.ops.mindspeed.npu_gmm_v2.Tensor)
+ def conveter_npu_gmm_v2(
+ original_weight: Tensor,
+ x: Tensor,
+ weight: Tensor,
+ *,
+ bias: Optional[Tensor] = None,
+ group_list: Optional[Tensor] = None,
+ group_type: Optional[int] = 0,
+ gemm_fusion: Optional[bool] = False,
+ meta_outputs: TensorSpec = None,
+ ):
+ """npu_gmm_v2(Tensor x, Tensor weight, *, Tensor? bias=None, Tensor? group_list=None, int? group_type=0) -> Tensor
+ """
+ result = conveter_npu_gmm_param(x, bias, group_type)
+
+ return GroupedMatmul([x], [weight], [result.bias], result.scale, result.offset, result.antiquant_scale,
+ result.antiquant_offset, group_list, split_item=3, group_type=group_type,
+ dtype=-1, transpose_weight=False, group_list_type=1)[0]
+
+if torch_npu_api_version == 2:
+ def Fill(dims: Tensor, value: Tensor):
+ """REG_OP(Fill)\n
+ .INPUT(dims, TensorType::IndexNumberType())\n
+ .INPUT(value, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, DT_INT8, DT_COMPLEX64, DT_INT64, DT_BOOL, DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16, DT_UINT16, DT_COMPLEX128, DT_FLOAT16, DT_BF16, DT_UINT32, DT_UINT64, DT_STRING}))\n
+ .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, DT_INT8, DT_COMPLEX64, DT_INT64, DT_BOOL, DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16, DT_UINT16, DT_COMPLEX128, DT_FLOAT16, DT_BF16, DT_UINT32, DT_UINT64, DT_STRING}))\n
+ """
+
+ y = torchair.ge.custom_op("Fill",
+ inputs={
+ "dims":dims,
+ "value":value
+ },
+ outputs=["y"]
+ )
+
+ return y
+
+GroupedMatmul = None
+if torch_npu_api_version == 2:
+ def GroupedMatmulV2(x: List[Tensor], weight: List[Tensor], bias: List[Tensor], scale: List[Tensor],
+ offset: List[Tensor], antiquant_scale: List[Tensor], antiquant_offset: List[Tensor],
+ group_list: Optional[Tensor] = None, per_token_scale: Optional[Tensor] = None, *,
+ split_item: int = 0, dtype: int = 0, transpose_weight: bool = False, transpose_x: bool = False,
+ group_type: int = -1, group_list_type: int = 0, act_type: int = 0):
+ """REG_OP(GroupedMatmul)\n
+ .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_BF16, DT_INT8}))\n
+ .DYNAMIC_INPUT(weight, TensorType({DT_FLOAT16, DT_BF16, DT_INT8}))\n
+ .DYNAMIC_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))\n
+ .DYNAMIC_INPUT(scale, TensorType({DT_UINT64}))\n
+ .DYNAMIC_INPUT(offset, TensorType({DT_FLOAT32}))\n
+ .DYNAMIC_INPUT(antiquant_scale, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .DYNAMIC_INPUT(antiquant_offset, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .OPTIONAL_INPUT(group_list, TensorType({DT_INT64}))\n
+ .OPTIONAL_INPUT(per_token_scale, TensorType({DT_FLOAT}))\n
+ .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT}))\n
+ .ATTR(split_item, Int, 0)\n
+ .ATTR(dtype, Int, 0)\n
+ .ATTR(transpose_weight, Bool, false)\n
+ .ATTR(transpose_x, Bool, false)\n
+ .ATTR(group_type, Int, -1)\n
+ .ATTR(group_list_type, Int, 0)\n
+ .ATTR(act_type, Int, 0)\n
+ """
+
+ y = torchair.ge.custom_op("GroupedMatmul",
+ inputs={
+ "x":x,
+ "weight":weight,
+ "bias":bias,
+ "scale":scale,
+ "offset":offset,
+ "antiquant_scale":antiquant_scale,
+ "antiquant_offset":antiquant_offset,
+ "group_list":group_list,
+ "per_token_scale": per_token_scale
+ },
+ attrs={
+ "split_item":ge.attr.Int(split_item),
+ "dtype":ge.attr.Int(dtype),
+ "transpose_weight":ge.attr.Bool(transpose_weight),
+ "transpose_x":ge.attr.Bool(transpose_x),
+ "group_type":ge.attr.Int(group_type),
+ "group_list_type":ge.attr.Int(group_list_type),
+ "act_type":ge.attr.Int(act_type)
+ },
+ outputs=[("y", 1)]
+ )
+
+ return y
+ GroupedMatmul = GroupedMatmulV2
+elif torch_npu_api_version == 1:
+ def GroupedMatmulV1(x: List[Tensor], weight: List[Tensor], bias: List[Tensor], scale: List[Tensor],
+ offset: List[Tensor], antiquant_scale: List[Tensor], antiquant_offset: List[Tensor],
+ group_list: Optional[Tensor] = None, per_token_scale: Optional[Tensor] = None, *,
+ split_item: int = 0, dtype: int = 0, transpose_weight: bool = False, transpose_x: bool = False,
+ group_type: int = -1, group_list_type: int = 0, act_type: int = 0,
+ dependencies=None, node_name=None):
+ """REG_OP(GroupedMatmul)\n
+ .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_BF16, DT_INT8}))\n
+ .DYNAMIC_INPUT(weight, TensorType({DT_FLOAT16, DT_BF16, DT_INT8}))\n
+ .DYNAMIC_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))\n
+ .DYNAMIC_INPUT(scale, TensorType({DT_UINT64}))\n
+ .DYNAMIC_INPUT(offset, TensorType({DT_FLOAT32}))\n
+ .DYNAMIC_INPUT(antiquant_scale, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .DYNAMIC_INPUT(antiquant_offset, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .OPTIONAL_INPUT(group_list, TensorType({DT_INT64}))\n
+ .OPTIONAL_INPUT(per_token_scale, TensorType({DT_FLOAT}))\n
+ .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT}))\n
+ .ATTR(split_item, Int, 0)\n
+ .ATTR(dtype, Int, 0)\n
+ .ATTR(transpose_weight, Bool, false)\n
+ .ATTR(transpose_x, Bool, false)\n
+ .ATTR(group_type, Int, -1)\n
+ .ATTR(group_list_type, Int, 0)\n
+ .ATTR(act_type, Int, 0)\n
+ """
+
+ op = get_default_ge_graph().op.add()
+ op.type = "GroupedMatmul"
+ op.name = next_unique_name(node_name, "GroupedMatmul")
+
+ # process dependices
+ if dependencies is not None:
+ for dependency in dependencies:
+ op.input.append(dependency.controller)
+
+ # process inputs
+ if not isinstance(x, (tuple, list)):
+ raise AssertionError
+ for i, v in enumerate(x):
+ op.input.append(v.tensor)
+ op.input_desc.add().CopyFrom(v.desc)
+ op.input_desc[-1].name = "x" + str(i)
+ if not isinstance(weight, (tuple, list)):
+ raise AssertionError("weight must be a tuple or a list.")
+ for i, v in enumerate(weight):
+ op.input.append(v.tensor)
+ op.input_desc.add().CopyFrom(v.desc)
+ op.input_desc[-1].name = "weight" + str(i)
+ if not isinstance(bias, (tuple, list)):
+ raise AssertionError("bias must be a tuple or a list.")
+ for i, v in enumerate(bias):
+ op.input.append(v.tensor)
+ op.input_desc.add().CopyFrom(v.desc)
+ op.input_desc[-1].name = "bias" + str(i)
+ if not isinstance(scale, (tuple, list)):
+ raise AssertionError("scale must be a tuple or a list.")
+ for i, v in enumerate(scale):
+ op.input.append(v.tensor)
+ op.input_desc.add().CopyFrom(v.desc)
+ op.input_desc[-1].name = "scale" + str(i)
+ if not isinstance(offset, (tuple, list)):
+ raise AssertionError("offset must be a tuple or a list.")
+ for i, v in enumerate(offset):
+ op.input.append(v.tensor)
+ op.input_desc.add().CopyFrom(v.desc)
+ op.input_desc[-1].name = "offset" + str(i)
+ if not isinstance(antiquant_scale, (tuple, list)):
+ raise AssertionError("antiquant_scale must be a tuple or a list.")
+ for i, v in enumerate(antiquant_scale):
+ op.input.append(v.tensor)
+ op.input_desc.add().CopyFrom(v.desc)
+ op.input_desc[-1].name = "antiquant_scale" + str(i)
+ if not isinstance(antiquant_offset, (tuple, list)):
+ raise AssertionError("antiquant_offset must be a tuple or a list.")
+ for i, v in enumerate(antiquant_offset):
+ op.input.append(v.tensor)
+ op.input_desc.add().CopyFrom(v.desc)
+ op.input_desc[-1].name = "antiquant_offset" + str(i)
+ if group_list is not None:
+ op.input.append(group_list.tensor)
+ op.input_desc.add().CopyFrom(group_list.desc)
+ op.input_desc[-1].name = "group_list"
+ else:
+ op.input.append("")
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "group_list"
+ if per_token_scale is not None:
+ op.input.append(per_token_scale.tensor)
+ op.input_desc.add().CopyFrom(per_token_scale.desc)
+ op.input_desc[-1].name = "per_token_scale"
+ else:
+ op.input.append("")
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "per_token_scale"
+
+ # process attrs
+ op.attr["split_item"].i = split_item
+ op.attr["dtype"].i = dtype
+ op.attr["transpose_weight"].b = transpose_weight
+ op.attr["transpose_x"].b = transpose_x
+ op.attr["group_type"].i = group_type
+ op.attr["group_list_type"].i = group_list_type
+ op.attr["act_type"].i = act_type
+
+ # process outputs
+ output_index = 0
+ op.output_desc.add().name = "y" + str(output_index)
+ y = [Tensor(op, output_index)]
+
+ return y
+ GroupedMatmul = GroupedMatmulV1
+else:
+ raise ValueError("touch_npu_api_version unsupport!")
diff --git a/model/train/yoco_moe/mindspeed/op_builder/groupmatmul_add_builder.py b/model/train/yoco_moe/mindspeed/op_builder/groupmatmul_add_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..24dba7565738059167d78558463ec2855fd4c1e0
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/groupmatmul_add_builder.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+from mindspeed.op_builder.atb_builder import AtbOpBuilder
+
+
+class GroupMatmulAddOpBuilder(AtbOpBuilder):
+ OP_NAME = "npu_groupmatmul_add_fp32"
+
+ def __init__(self):
+ super(GroupMatmulAddOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/atb/groupmatmul_add.cpp',
+ 'ops/csrc/atb/utils/atb_adapter.cpp',
+ 'ops/csrc/flop_counter/flop_counter.cpp']
+
diff --git a/model/train/yoco_moe/mindspeed/op_builder/lcal_builder.py b/model/train/yoco_moe/mindspeed/op_builder/lcal_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..10acf0e53b6902d1e979685b9517986d9989107a
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/lcal_builder.py
@@ -0,0 +1,18 @@
+from mindspeed.op_builder.atb_builder import AtbOpBuilder
+
+
+class LcalOpBuilder(AtbOpBuilder):
+ OP_NAME = "lcal"
+
+ def __init__(self):
+ super(LcalOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/atb/lcal_coc.cpp',
+ 'ops/csrc/atb/utils/atb_adapter.cpp',
+ 'ops/csrc/flop_counter/flop_counter.cpp']
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args.append(" -std=c++17")
+ return args
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/op_builder/matmul_add_builder.py b/model/train/yoco_moe/mindspeed/op_builder/matmul_add_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1df7522980a8057a1950f47992b7d55a996d53cd
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/matmul_add_builder.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+from mindspeed.op_builder.atb_builder import AtbOpBuilder
+
+
+class MatmulAddOpBuilder(AtbOpBuilder):
+ OP_NAME = "npu_matmul_add_fp32"
+
+ def __init__(self):
+ super(MatmulAddOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/atb/matmul_add.cpp',
+ 'ops/csrc/atb/utils/atb_adapter.cpp',
+ 'ops/csrc/flop_counter/flop_counter.cpp']
+
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/op_builder/memory_fragmentation_builder.py b/model/train/yoco_moe/mindspeed/op_builder/memory_fragmentation_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..17955efb13afe616d9fa925296231c33708fe6c8
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/memory_fragmentation_builder.py
@@ -0,0 +1,47 @@
+import torch
+import os
+
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class MemoryFragmentationBuilder(MindSpeedOpBuilder):
+ NAME = "MemoryFragmentation"
+ _torch_path = None
+ _python_path = None
+
+ def __init__(self):
+ from sysconfig import get_paths
+ self._torch_path = os.path.dirname(os.path.abspath(torch.__file__))
+ self._python_path = os.path.dirname(os.path.abspath(get_paths().get('include')))
+ super(MemoryFragmentationBuilder, self).__init__(self.NAME)
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += [
+ os.path.join(self._torch_path, 'include'),
+ os.path.join(self._torch_path, 'include/torch/csrc/api/include'),
+ os.path.join(self._torch_npu_path, 'include/third_party/acl/inc/acl/'),
+ os.path.join(self._python_path),
+ ]
+ return paths
+
+ def sources(self):
+ return ['ops/csrc/pluggable_allocator/memory_fragmentation/EventPool.cpp',
+ 'ops/csrc/pluggable_allocator/memory_fragmentation/CachingAllocatorConfig.cpp',
+ 'ops/csrc/pluggable_allocator/memory_fragmentation/DeviceCachingAllocator.cpp',
+ 'ops/csrc/pluggable_allocator/memory_fragmentation/PluggableAllocator.cpp',
+ 'ops/csrc/pluggable_allocator/memory_fragmentation/PluggableAllocatorFunctions.cpp',
+ 'ops/csrc/pluggable_allocator/memory_fragmentation/Decorator.cpp',
+ 'ops/csrc/pluggable_allocator/memory_fragmentation/Recorder.cpp',
+ 'ops/csrc/pluggable_allocator/memory_fragmentation/common.cpp']
+
+ def cxx_args(self):
+ args = ['-fstack-protector-all', '-Wl,-z,relro,-z,now,-z,noexecstack', '-fPIC', '-pie',
+ '-s', '-D_FORTIFY_SOURCE=2', '-O2', "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"]
+ return args
+
+ def extra_ldflags(self):
+ flags = [
+ '-L' + os.path.join(self._torch_npu_path, 'lib'), '-ltorch_npu'
+ ]
+ return flags
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/op_builder/npu_all_to_all_all_gather_bmm_builder.py b/model/train/yoco_moe/mindspeed/op_builder/npu_all_to_all_all_gather_bmm_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4ac84f26c3b8eedd51c26ba43afca8b1006a180
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/npu_all_to_all_all_gather_bmm_builder.py
@@ -0,0 +1,244 @@
+# Copyright (c) 2024, Huawei Technologies.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+import torch
+import torchair
+from torch.library import Library, impl
+from mindspeed.op_builder.builder import MindSpeedOpBuilder, AS_LIBRARY
+torch_npu_api_version = None
+try:
+ from torchair import ge
+ from torchair import register_fx_node_ge_converter
+ from torchair.ge import Tensor, TensorSpec, DataType
+except ImportError:
+ ge, Tensor, TensorSpec, DataType = None, None, None, None
+ from torchair.ge_concrete_graph.fx2ge_converter import register_fx_node_ge_converter
+ torch_npu_api_version = 1
+else:
+ torch_npu_api_version = 2
+
+
+class AllToAllAllGatherBatchMatMulOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "npu_alltoall_allgather_bmm"
+ OP_PROTO = "npu_alltoall_allgather_bmm(Tensor x, Tensor weight, \
+ str group_ep, int group_ep_worldsize, \
+ str group_tp, int group_tp_worldsize, \
+ *, Tensor? bias=None, int shard_type=0, int act_type=0, \
+ bool need_allgather_out=False, \
+ bool need_activation_feature=False) -> (Tensor, Tensor, Tensor)"
+
+ def __init__(self):
+ super(AllToAllAllGatherBatchMatMulOpBuilder, self).__init__(self.OP_NAME)
+ self.register_op_proto(self.OP_PROTO)
+ self.register_op_ir()
+
+ def sources(self):
+ return ['ops/csrc/cann/npu_all_to_all_all_gather_bmm.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc']
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
+ ]
+ return args
+
+ def register_op_ir(self):
+ @impl(AS_LIBRARY, "npu_alltoall_allgather_bmm", "Meta")
+ def npu_alltoall_allgather_bmm_forward(x, weight,
+ group_ep, group_ep_worldsize, group_tp, group_tp_worldsize,
+ *, bias=None, shard_type=0, act_type=0,
+ need_allgather_out=False, need_activation_feature=False):
+ batch = weight.size(0)
+ m = x.size(1) * group_ep_worldsize
+ if shard_type == 1:
+ m *= group_tp_worldsize
+ n = weight.size(2)
+ k = weight.size(1)
+
+ if x.size(0) == 0:
+ raise AssertionError('The first dim of x can not be 0.')
+ if x.size(1) == 0:
+ raise AssertionError('The second dim of x can not be 0.')
+ if x.size(2) == 0:
+ raise AssertionError('The last dim of x can not be 0.')
+ if weight.size(0) == 0:
+ raise AssertionError('The first dim of weight can not be 0.')
+ if weight.size(1) == 0:
+ raise AssertionError('The second dim of weight can not be 0.')
+ if weight.size(2) == 0:
+ raise AssertionError('The last dim of weight can not be 0.')
+
+ empty_tensor = x.new_empty((0))
+ return (x.new_empty((batch, m, n)),
+ x.new_empty((batch, m, k)) if need_allgather_out else empty_tensor,
+ x.new_empty((batch, m, n)) if need_activation_feature else empty_tensor)
+
+ @register_fx_node_ge_converter(torch.ops.mindspeed.npu_alltoall_allgather_bmm.default)
+ def convert_npu_alltoall_allgather_bmm(
+ x: Tensor,
+ weight: Tensor,
+ group_ep: str,
+ group_ep_worldsize: int,
+ group_tp: str,
+ group_tp_worldsize: int,
+ *,
+ bias: Optional[Tensor] = None,
+ shard_type: Optional[int] = 0,
+ act_type: Optional[int] = 0,
+ need_allgather_out: Optional[bool] = False,
+ need_activation_feature: Optional[bool] = False,
+ meta_outputs: List[TensorSpec] = None):
+ '''"npu_alltoall_allgather_bmm(Tensor x, Tensor weight, str group_ep, str group_tp,
+ int ep_world_size, int tp_world_size, *, Tensor? bias=None, int x_shard_type=0, int act_type=0,
+ bool need_allgather_out=False, bool need_activation_feature=False) -> (Tensor, Tensor, Tensor)"'''
+ if torch_npu_api_version != 2:
+ raise ValueError(f"torch_npu_api_version {torch_npu_api_version} unsupport")
+ CheckDtype(x, weight, bias)
+ return AllToAllAllGatherBatchMatmul(x,
+ weight,
+ group_ep,
+ group_ep_worldsize,
+ group_tp,
+ group_tp_worldsize,
+ bias=bias,
+ shard_type=shard_type,
+ act_type=act_type,
+ need_allgather_out=need_allgather_out,
+ need_activation_feature=need_activation_feature)
+
+
+def CheckDtype(x: Tensor, weight: Tensor, bias: Optional[Tensor]):
+ if x.dtype != DataType.DT_BF16 and x.dtype != DataType.DT_FLOAT16:
+ raise AssertionError(f'type of x must be DT_FLOAT16/DT_BF16, but got {GeDtypeToStr(x.dtype)}.')
+ if weight.dtype != DataType.DT_BF16 and weight.dtype != DataType.DT_FLOAT16:
+ raise AssertionError(f'type of weight must be DT_FLOAT16/DT_BF16, but got {GeDtypeToStr(weight.dtype)}.')
+ if x.dtype != weight.dtype:
+ raise AssertionError(f'type of x and weight must be same, but got x {GeDtypeToStr(x.dtype)} '\
+ f'weight {GeDtypeToStr(weight.dtype)}.')
+ if bias is not None:
+ if bias.dtype != DataType.DT_FLOAT16 and bias.dtype != DataType.DT_FLOAT:
+ raise AssertionError(f'type of bias must DT_FLOAT16/DT_FLOAT32, but got {GeDtypeToStr(bias.dtype)}.')
+ if x.dtype == DataType.DT_FLOAT16 and bias.dtype != DataType.DT_FLOAT16:
+ raise AssertionError(f'type of bias must DT_FLOAT16 when x is DT_FLOAT16, '\
+ f'but got {GeDtypeToStr(bias.dtype)}.')
+ if x.dtype == DataType.DT_BF16 and bias.dtype != DataType.DT_FLOAT:
+ raise AssertionError(f'type of bias must DT_FLOAT32 when x is DT_BF16, '\
+ f'but got {GeDtypeToStr(bias.dtype)}.')
+
+
+def GeDtypeToStr(ge_dtype: DataType):
+ ge_datatype = {
+ DataType.DT_FLOAT: 'DT_FLOAT32',
+ DataType.DT_FLOAT16: 'DT_FLOAT16',
+ DataType.DT_INT8: 'DT_INT8',
+ DataType.DT_INT16: 'DT_INT16',
+ DataType.DT_UINT16: 'DT_UINT16',
+ DataType.DT_UINT8: 'DT_UINT8',
+ DataType.DT_INT32: 'DT_INT32',
+ DataType.DT_INT64: 'DT_INT64',
+ DataType.DT_UINT32: 'DT_UINT32',
+ DataType.DT_UINT64: 'DT_UINT64',
+ DataType.DT_BOOL: 'DT_BOOL',
+ DataType.DT_DOUBLE: 'DT_DOUBLE',
+ DataType.DT_STRING: 'DT_STRING',
+ DataType.DT_DUAL_SUB_INT8: 'DT_DUAL_SUB_INT8',
+ DataType.DT_DUAL_SUB_UINT8: 'DT_DUAL_SUB_UINT8',
+ DataType.DT_COMPLEX64: 'DT_COMPLEX64',
+ DataType.DT_COMPLEX128: 'DT_COMPLEX128',
+ DataType.DT_QINT8: 'DT_QINT8',
+ DataType.DT_QINT16: 'DT_QINT16',
+ DataType.DT_QINT32: 'DT_QINT32',
+ DataType.DT_QUINT8: 'DT_QUINT8',
+ DataType.DT_QUINT16: 'DT_QUINT16',
+ DataType.DT_RESOURCE: 'DT_RESOURCE',
+ DataType.DT_STRING_REF: 'DT_STRING_REF',
+ DataType.DT_DUAL: 'DT_DUAL',
+ DataType.DT_VARIANT: 'DT_VARIANT',
+ DataType.DT_BF16: 'DT_BF16',
+ DataType.DT_UNDEFINED: 'DT_UNDEFINED',
+ DataType.DT_INT4: 'DT_INT4',
+ DataType.DT_UINT1: 'DT_UINT1',
+ DataType.DT_INT2: 'DT_INT2',
+ DataType.DT_UINT2: 'DT_UINT2',
+ DataType.DT_COMPLEX32: 'DT_COMPLEX32',
+ DataType.DT_MAX: 'DT_MAX',
+ }
+ if ge_dtype in ge_datatype:
+ return ge_datatype[ge_dtype]
+ else:
+ return 'unknown'
+
+
+def AllToAllAllGatherBatchMatmul(
+ x: Tensor,
+ weight: Tensor,
+ group_ep: str,
+ group_ep_worldsize: int,
+ group_tp: str,
+ group_tp_worldsize: int,
+ *,
+ bias: Optional[Tensor] = None,
+ shard_type: Optional[int] = 0,
+ act_type: Optional[int] = 0,
+ need_allgather_out: Optional[bool] = False,
+ need_activation_feature: Optional[bool] = False):
+ """REG_OP(AlltoAllAllGatherBatchMatMul)\n
+ .INPUT(x, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .INPUT(weight, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_BF16, DT_FLOAT32}))\n
+ .OUTPUT(y1, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .OUTPUT(y2, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .OUTPUT(y3, TensorType({DT_FLOAT16, DT_BF16}))\n
+ .REQUIRED_ATTR(group_ep, String)\n
+ .REQUIRED_ATTR(group_tp, String)\n
+ .REQUIRED_ATTR(ep_world_size, int)\n
+ .REQUIRED_ATTR(tp_world_size, int)\n
+ .ATTR(x_shard_type, Int, 1)\n
+ .ATTR(act_type, Int, 0)\n
+ .ATTR(need_allgather_out, Bool, False)\n
+ .ATTR(need_activation_feature, Bool, False)\n
+ .OP_END_FACTORY_REG(AlltoAllAllGatherBatchMatMul)
+ use to construct Opdesc
+ """
+ transpose_weight = False
+ return torchair.ge.custom_op(
+ "AlltoAllAllGatherBatchMatMul",
+ inputs={
+ "x": x,
+ "weight": weight,
+ "bias": bias
+ },
+ attrs={
+ "group_ep": ge.attr.Str(group_ep),
+ "group_tp": ge.attr.Str(group_tp),
+ "ep_world_size": ge.attr.Int(group_ep_worldsize),
+ "tp_world_size": ge.attr.Int(group_tp_worldsize),
+ "x_shard_type": ge.attr.Int(shard_type),
+ "act_type": ge.attr.Int(act_type),
+ "transpose_weight": ge.attr.Bool(transpose_weight),
+ "output_y2_flag": ge.attr.Bool(need_allgather_out),
+ "output_y3_flag": ge.attr.Bool(need_activation_feature)
+ },
+ outputs=["y1", "y2", "y3"]
+ )
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/op_builder/npu_bmm_reduce_scatter_all_to_all_builder.py b/model/train/yoco_moe/mindspeed/op_builder/npu_bmm_reduce_scatter_all_to_all_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1a983301b1902f03fc2368d4574844d9a938f4b
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/npu_bmm_reduce_scatter_all_to_all_builder.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2024, Huawei Technologies.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+import torch
+import torchair
+from torch.library import Library, impl
+from mindspeed.op_builder.builder import MindSpeedOpBuilder, AS_LIBRARY
+from mindspeed.op_builder.npu_all_to_all_all_gather_bmm_builder import CheckDtype
+torch_npu_api_version = None
+try:
+ from torchair import ge
+ from torchair import register_fx_node_ge_converter
+ from torchair.ge import Tensor, TensorSpec, DataType
+except ImportError:
+ ge, Tensor, TensorSpec, DataType = None, None, None, None
+ from torchair.ge_concrete_graph.fx2ge_converter import register_fx_node_ge_converter
+ torch_npu_api_version = 1
+else:
+ torch_npu_api_version = 2
+
+
+class BatchMatMulReduceScatterAlltoAllOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "npu_bmm_reducescatter_alltoall"
+ OP_PROTO = "npu_bmm_reducescatter_alltoall(Tensor x, Tensor weight, str group_ep, int group_ep_worldsize, \
+ str group_tp, int group_tp_worldsize, *, Tensor? bias=None, int shard_type=0) -> Tensor"
+
+ def __init__(self):
+ super(BatchMatMulReduceScatterAlltoAllOpBuilder, self).__init__(self.OP_NAME)
+ self.register_op_proto(self.OP_PROTO)
+ self.register_op_ir()
+
+ def sources(self):
+ return ['ops/csrc/cann/npu_bmm_reduce_scatter_all_to_all.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc']
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
+ ]
+ return args
+
+ def register_op_ir(self):
+ @impl(AS_LIBRARY, "npu_bmm_reducescatter_alltoall", "Meta")
+ def npu_bmm_reducescatter_alltoall_forward(x, weight, group_ep, group_ep_worldsize,
+ group_tp, group_tp_worldsize, *, bias=None, shard_type=0):
+ if group_ep_worldsize == 0:
+ raise AssertionError('group_ep_worldsize can not be 0.')
+ if group_tp_worldsize == 0:
+ raise AssertionError('group_tp_worldsize can not be 0.')
+ e = x.size(0) * group_ep_worldsize
+ c = x.size(1) // group_ep_worldsize
+ h = weight.size(2)
+
+ if x.size(0) == 0:
+ raise AssertionError('The first dim of x can not be 0.')
+ if x.size(1) == 0:
+ raise AssertionError('The second dim of x can not be 0.')
+ if x.size(2) == 0:
+ raise AssertionError('The last dim of x can not be 0.')
+ if weight.size(0) == 0:
+ raise AssertionError('The first dim of weight can not be 0.')
+ if weight.size(1) == 0:
+ raise AssertionError('The second dim of weight can not be 0.')
+ if weight.size(2) == 0:
+ raise AssertionError('The last dim of weight can not be 0.')
+
+ if shard_type == 0:
+ # shard in h dimensions
+ h = h // group_tp_worldsize
+ else:
+ # shard in c dimensions
+ c = c // group_tp_worldsize
+
+ return x.new_empty((e, c, h))
+
+ @register_fx_node_ge_converter(torch.ops.mindspeed.npu_bmm_reducescatter_alltoall.default)
+ def convert_npu_bmm_reducescatter_alltoall(x: Tensor,
+ weight: Tensor,
+ group_ep: str,
+ group_ep_worldsize: int,
+ group_tp: str,
+ group_tp_worldsize: int,
+ *,
+ bias: Optional[Tensor] = None,
+ shard_type: Optional[int] = 0,
+ meta_outputs: TensorSpec = None):
+ if torch_npu_api_version != 2:
+ raise ValueError(f"torch_npu_api_version {torch_npu_api_version} unsupport")
+ CheckDtype(x, weight, bias)
+ return BatchMatmulReduceScatterAlltoAll(x,
+ weight,
+ group_ep,
+ group_ep_worldsize,
+ group_tp,
+ group_tp_worldsize,
+ bias=bias,
+ shard_type=shard_type)
+
+
+def BatchMatmulReduceScatterAlltoAll(x: Tensor,
+ weight: Tensor,
+ group_ep: str,
+ group_ep_worldsize: int,
+ group_tp: str,
+ group_tp_worldsize: int,
+ *,
+ bias: Tensor = None,
+ shard_type: int = 0):
+ transpose_weight = False
+ return torchair.ge.custom_op(
+ "BatchMatMulReduceScatterAlltoAll",
+ inputs={
+ "x": x,
+ "weight": weight,
+ "bias": bias
+ },
+ attrs={
+ "group_ep": ge.attr.Str(group_ep),
+ "group_tp": ge.attr.Str(group_tp),
+ "ep_world_size": ge.attr.Int(group_ep_worldsize),
+ "tp_world_size": ge.attr.Int(group_tp_worldsize),
+ "y_shard_type": ge.attr.Int(shard_type),
+ "transpose_weight": ge.attr.Bool(transpose_weight)
+ },
+ outputs=["y"]
+ )
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/op_builder/npu_dropout_add_layer_norm_builder.py b/model/train/yoco_moe/mindspeed/op_builder/npu_dropout_add_layer_norm_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dd19a124189bb384d8d115dfa3c0337c031e42c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/npu_dropout_add_layer_norm_builder.py
@@ -0,0 +1,26 @@
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class NpuDropoutAddLayerNormOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "npu_dropout_add_layer_norm"
+
+ def __init__(self):
+ super(NpuDropoutAddLayerNormOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/cann/npu_dropout_add_layer_norm.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc']
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
+ ]
+ return args
diff --git a/model/train/yoco_moe/mindspeed/op_builder/npu_grouped_mat_mul_all_reduce_builder.py b/model/train/yoco_moe/mindspeed/op_builder/npu_grouped_mat_mul_all_reduce_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bad183d545b3dd9cfec7e206d22395336ea4217
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/npu_grouped_mat_mul_all_reduce_builder.py
@@ -0,0 +1,37 @@
+import torch
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class GroupedMatMulAllReduceOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "npu_grouped_mat_mul_all_reduce"
+ TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])
+
+ def __init__(self):
+ super(GroupedMatMulAllReduceOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/cann/npu_grouped_mat_mul_all_reduce.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc']
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
+ ]
+ if self.TORCH_MAJOR >= 2 and self.TORCH_MINOR >= 1:
+ cpp_std = " -std=c++17"
+ compile_maroc = " -D__TORCH_2__"
+ else:
+ cpp_std = " -std=c++14"
+ compile_maroc = " -D__TORCH_1__"
+ args.append(cpp_std)
+ args.append(compile_maroc)
+
+ return args
diff --git a/model/train/yoco_moe/mindspeed/op_builder/npu_inplace_mm_all_reduce_add_rms_norm_builder.py b/model/train/yoco_moe/mindspeed/op_builder/npu_inplace_mm_all_reduce_add_rms_norm_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..117e4d897addeeb4b12c6b7a541d3e08dac48a79
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/npu_inplace_mm_all_reduce_add_rms_norm_builder.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2024, Huawei Technologies.
+
+
+from typing import List, Optional
+
+from mindspeed.op_builder.builder import MindSpeedOpBuilder, AS_LIBRARY
+from mindspeed.op_builder.npu_mm_all_reduce_add_rms_norm_builder import MatmulAllReduceAddRmsNorm, CheckDtype
+
+import torch
+import torchair
+from torch.library import Library, impl
+
+torch_npu_api_version = None
+try:
+ from torchair import ge
+ from torchair.ge import Tensor, TensorSpec, DataType
+ from torchair import register_fx_node_ge_converter
+except ImportError:
+ torch_npu_api_version = 1
+ from torchair.ge_concrete_graph import ge_apis as ge
+ from torchair.ge_concrete_graph.fx2ge_converter import declare_supported, register_fx_node_ge_converter
+ from torchair.ge_concrete_graph.ge_graph import Tensor, TensorSpec
+ from torchair.ge_concrete_graph.ge_graph import get_default_ge_graph, next_unique_name
+ from torchair.ge_concrete_graph.ge_graph import compat_as_bytes
+ from torchair.ge_concrete_graph.ge_graph import get_invalid_desc
+else:
+ torch_npu_api_version = 2
+
+
+class InplaceMatmulAllReduceAddRmsNormOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "npu_mm_all_reduce_add_rms_norm_"
+ OP_PROTO = "npu_mm_all_reduce_add_rms_norm_(Tensor x1, Tensor x2, Tensor residual, Tensor gamma, \
+ str hcom, *, str reduce_op='sum', float epsilon=1e-06, Tensor? bias=None, Tensor? antiquant_scale=None, \
+ Tensor? antiquant_offset=None, Tensor? dequant_scale=None, int antiquant_group_size=0, int comm_turn=0) \
+ -> (Tensor, Tensor)"
+
+ def __init__(self):
+ super(InplaceMatmulAllReduceAddRmsNormOpBuilder, self).__init__(self.OP_NAME)
+ self.register_op_proto(self.OP_PROTO)
+ self.register_op_ir()
+
+ def sources(self):
+ return ['ops/csrc/cann/npu_mm_all_reduce_add_rms_norm_.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc']
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
+ ]
+ return args
+
+ def register_op_ir(self):
+ @impl(AS_LIBRARY, "npu_mm_all_reduce_add_rms_norm_", "Meta")
+ def npu_inplace_mm_all_reduce_add_rms_norm_forward(
+ x1, x2, residual, gamma, hcom, reduce_op='sum', epsilon=1e-6,
+ bias=None, antiquant_scale=None, antiquant_offset=None,
+ dequant_scale=None, antiquant_group_size=0, comm_turn=0):
+ return (torch.empty_like(residual, dtype=residual.dtype),
+ torch.empty_like(residual, dtype=residual.dtype))
+
+ @register_fx_node_ge_converter(torch.ops.mindspeed.npu_mm_all_reduce_add_rms_norm_.default)
+ def convert_npu_mm_all_reduce_add_rms_norm_(
+ x1: Tensor,
+ x2: Tensor,
+ residual: Tensor,
+ gamma: Tensor,
+ hcom: str,
+ *,
+ reduce_op: str = 'sum',
+ epsilon: float = 1e-6,
+ bias: Optional[Tensor] = None,
+ antiquant_scale: Optional[Tensor] = None,
+ antiquant_offset: Optional[Tensor] = None,
+ dequant_scale: Optional[Tensor] = None,
+ antiquant_group_size: int = 0,
+ comm_turn: int = 0,
+ meta_outputs: List[TensorSpec] = None
+ ):
+ # transpose_x1 is set to False by default
+ transpose_x1 = False
+ transpose_x2 = False
+ '''npu_mm_all_reduce_add_rms_norm_(Tensor x1, Tensor x2, Tensor(a!) residual, Tensor gamma,
+ str hcom, *, str reduce_op='sum', float epsilon=1e-06, Tensor? bias=None,
+ Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? dequant_scale=None,
+ int antiquant_group_size=0, int comm_turn=0) -> (Tensor(a!), Tensor)'''
+ CheckDtype(x1, x2, bias=bias, residual=residual, gamma=gamma, antiquant_scale=antiquant_scale,
+ antiquant_offset=antiquant_offset, dequant_scale=dequant_scale)
+ return MatmulAllReduceAddRmsNorm(x1,
+ x2,
+ bias=bias,
+ residual=residual,
+ gamma=gamma,
+ antiquant_scale=antiquant_scale,
+ antiquant_offset=antiquant_offset,
+ dequant_scale=dequant_scale,
+ group=hcom,
+ reduce_op=reduce_op,
+ is_trans_a=transpose_x1,
+ is_trans_b=transpose_x2,
+ comm_turn=comm_turn,
+ antiquant_group_size=antiquant_group_size,
+ epsilon=epsilon)
diff --git a/model/train/yoco_moe/mindspeed/op_builder/npu_mm_all_reduce_add_rms_norm_builder.py b/model/train/yoco_moe/mindspeed/op_builder/npu_mm_all_reduce_add_rms_norm_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc8e1d4b0ea1ac91bd6492bdf3edc31802f94e8b
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/npu_mm_all_reduce_add_rms_norm_builder.py
@@ -0,0 +1,324 @@
+# Copyright (c) 2024, Huawei Technologies.
+
+
+from typing import List, Optional
+
+from mindspeed.op_builder.builder import MindSpeedOpBuilder, AS_LIBRARY
+
+import torch
+import torchair
+from torch.library import Library, impl
+
+torch_npu_api_version = None
+try:
+ from torchair import ge
+ from torchair.ge import Tensor, TensorSpec, DataType
+ from torchair import register_fx_node_ge_converter
+except ImportError:
+ torch_npu_api_version = 1
+ from torchair.ge_concrete_graph import ge_apis as ge
+ from torchair.ge_concrete_graph.fx2ge_converter import declare_supported, register_fx_node_ge_converter
+ from torchair.ge_concrete_graph.ge_graph import Tensor, TensorSpec
+ from torchair.ge_concrete_graph.ge_graph import get_default_ge_graph, next_unique_name
+ from torchair.ge_concrete_graph.ge_graph import compat_as_bytes
+ from torchair.ge_concrete_graph.ge_graph import get_invalid_desc
+else:
+ torch_npu_api_version = 2
+
+DataType = dict(
+ DT_FLOAT16=1,
+ DT_INT8=2,
+ DT_INT32=3,
+ DT_INT64=9,
+ DT_UINT64=10,
+ DT_BF16=27,
+)
+
+
+class MatmulAllReduceAddRmsNormOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "npu_mm_all_reduce_add_rms_norm"
+ OP_PROTO = "npu_mm_all_reduce_add_rms_norm(Tensor x1, Tensor x2, Tensor residual, Tensor gamma, str hcom, *, \
+ str reduce_op='sum', float epsilon=1e-06, Tensor? bias=None, Tensor? antiquant_scale=None, Tensor? \
+ antiquant_offset=None, Tensor? dequant_scale=None, int antiquant_group_size=0, int comm_turn=0) \
+ -> (Tensor, Tensor)"
+
+ def __init__(self):
+ super(MatmulAllReduceAddRmsNormOpBuilder, self).__init__(self.OP_NAME)
+ self.register_op_proto(self.OP_PROTO)
+ self.register_op_ir()
+
+ def sources(self):
+ return ['ops/csrc/cann/npu_mm_all_reduce_add_rms_norm.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc']
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
+ ]
+ return args
+
+ def register_op_ir(self):
+ @impl(AS_LIBRARY, "npu_mm_all_reduce_add_rms_norm", "Meta")
+ def npu_mm_all_reduce_add_rms_norm_forward(x1, x2, residual, gamma, hcom, reduce_op='sum', epsilon=1e-6,
+ bias=None, antiquant_scale=None, antiquant_offset=None,
+ dequant_scale=None, antiquant_group_size=0, comm_turn=0):
+ return (torch.empty_like(residual, dtype=residual.dtype),
+ torch.empty_like(residual, dtype=residual.dtype))
+
+ @register_fx_node_ge_converter(torch.ops.mindspeed.npu_mm_all_reduce_add_rms_norm.default)
+ def convert_npu_mm_all_reduce_add_rms_norm(
+ x1: Tensor,
+ x2: Tensor,
+ residual: Tensor,
+ gamma: Tensor,
+ hcom: str,
+ *,
+ reduce_op: str = 'sum',
+ epsilon: float = 1e-6,
+ bias: Optional[Tensor] = None,
+ antiquant_scale: Optional[Tensor] = None,
+ antiquant_offset: Optional[Tensor] = None,
+ dequant_scale: Optional[Tensor] = None,
+ antiquant_group_size: int = 0,
+ comm_turn: int = 0,
+ meta_outputs: List[TensorSpec] = None
+ ):
+ # transpose_x1 is set to False by default
+ transpose_x1 = False
+ transpose_x2 = False
+ '''"npu_mm_all_reduce_add_rms_norm(Tensor x1, Tensor x2, Tensor residual, Tensor gamma, str hcom,
+ *, str reduce_op='sum', float epsilon=1e-06, Tensor? bias=None, Tensor? antiquant_scale=None,
+ Tensor? antiquant_offset=None, Tensor? dequant_scale=None, int antiquant_group_size=0,
+ int comm_turn=0) -> (Tensor, Tensor)"'''
+ CheckDtype(x1, x2, bias=bias, residual=residual, gamma=gamma, antiquant_scale=antiquant_scale,
+ antiquant_offset=antiquant_offset, dequant_scale=dequant_scale)
+ return MatmulAllReduceAddRmsNorm(x1,
+ x2,
+ bias=bias,
+ residual=residual,
+ gamma=gamma,
+ antiquant_scale=antiquant_scale,
+ antiquant_offset=antiquant_offset,
+ dequant_scale=dequant_scale,
+ group=hcom,
+ reduce_op=reduce_op,
+ is_trans_a=transpose_x1,
+ is_trans_b=transpose_x2,
+ comm_turn=comm_turn,
+ antiquant_group_size=antiquant_group_size,
+ epsilon=epsilon)
+
+
+def CheckDtype(x1: Tensor, x2: Tensor, bias: Optional[Tensor], residual: Tensor, gamma: Tensor,
+ antiquant_scale: Optional[Tensor], antiquant_offset: Optional[Tensor],
+ dequant_scale: Optional[Tensor]):
+ if residual.dtype != gamma.dtype:
+ raise AssertionError('type of residual and gamma must be same.')
+ if x1.dtype in (DataType["DT_FLOAT16"], DataType["DT_BF16"]) and \
+ x2.dtype in (DataType["DT_FLOAT16"], DataType["DT_BF16"]):
+ if x2.dtype != x1.dtype:
+ raise AssertionError('type of x1 and x2 must be same.')
+ if bias is not None and bias.dtype != x1.dtype:
+ raise AssertionError('type of x1 and bias must be same.')
+ if residual.dtype != x1.dtype:
+ raise AssertionError('type of x1 and residual must be same.')
+ elif x1.dtype is DataType["DT_INT8"] and x2.dtype is DataType["DT_INT8"]:
+ if bias is not None and bias.dtype != DataType["DT_INT32"]:
+ raise AssertionError('type of bias must be int32.')
+ if dequant_scale is None:
+ raise AssertionError('dequant_scale must not be None.')
+ if dequant_scale.dtype in (DataType["DT_INT64"], DataType["DT_UINT64"]):
+ if residual.dtype != DataType["DT_FLOAT16"]:
+ raise AssertionError('when dequant_scale is int64(uint64), residual type must be fp16.')
+ elif dequant_scale.dtype is DataType["DT_BF16"]:
+ if residual.dtype != DataType["DT_BF16"]:
+ raise AssertionError('type of dequant_scale and residual should be bf16.')
+ else:
+ raise AssertionError('dequant_scale type must be int64, uint64 or bf16')
+ elif x1.dtype in (DataType["DT_FLOAT16"], DataType["DT_BF16"]) and \
+ x2.dtype is DataType["DT_INT8"]:
+ if bias is not None and bias.dtype != x1.dtype:
+ raise AssertionError('type of x1 and bias must be same.')
+ if antiquant_scale is None:
+ raise AssertionError('antiquant_scale must not be None.')
+ if antiquant_scale.dtype != x1.dtype:
+ raise AssertionError('type of x1 and antiquant_scale must be same.')
+ if antiquant_offset is not None and antiquant_offset.dtype != antiquant_scale.dtype:
+ raise AssertionError('type of antiquant_scale and antiquant_offset must be same.')
+ if residual.dtype != x1.dtype:
+ raise AssertionError('type of x1 and residual must be same.')
+ else:
+ raise AssertionError("the type of x1 and x2 should be suit the not quant scenario, "\
+ "dequant scenario, antiquant scenario.")
+
+MatmulAllReduceAddRmsNorm = None
+if torch_npu_api_version == 2:
+ def MatmulAllReduceAddRmsNormV2(x1: Tensor,
+ x2: Tensor,
+ bias: Optional[Tensor],
+ residual: Tensor,
+ gamma: Tensor,
+ antiquant_scale: Optional[Tensor],
+ antiquant_offset: Optional[Tensor],
+ dequant_scale: Optional[Tensor],
+ *,
+ group: str,
+ reduce_op: str = "sum",
+ is_trans_a: bool = False,
+ is_trans_b: bool = False,
+ comm_turn: int = 0,
+ antiquant_group_size: int = 0,
+ epsilon: float = 0.000001):
+ """REG_OP(MatmulAllReduceAddRmsNorm)\n
+ .INPUT(x1, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_BF16}))\n
+ .INPUT(x2, TensorType({DT_FLOAT16, DT_BF16, DT_INT8, DT_INT8, DT_INT8, DT_INT4, DT_INT4}))\n
+ .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_BF16, DT_INT32, DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_BF16}))\n
+ .INPUT(residual, TensorType({DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_BF16}))\n
+ .INPUT(gamma, TensorType({DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_BF16}))\n
+ .OPTIONAL_INPUT(antiquant_scale, TensorType({DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_BF16}))\n
+ .OPTIONAL_INPUT(antiquant_offset, TensorType({DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_BF16}))\n
+ .OPTIONAL_INPUT(dequant_scale, TensorType({DT_FLOAT16, DT_BF16, DT_UINT64, DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_BF16}))\n
+ .OUTPUT(y, TensorType({DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_BF16}))\n
+ .OUTPUT(norm_out, TensorType({DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_FLOAT16, DT_BF16, DT_FLOAT16, DT_BF16}))\n
+ .REQUIRED_ATTR(group, String)\n
+ .ATTR(reduce_op, String, "sum")\n
+ .ATTR(is_trans_a, Bool, false)\n
+ .ATTR(is_trans_b, Bool, false)\n
+ .ATTR(comm_turn, Int, 0)\n
+ .ATTR(antiquant_group_size, Int, 0)\n
+ .ATTR(epsilon, Float, 1e-6)\n
+ .OP_END_FACTORY_REG(MatmulAllReduceAddRmsNorm)
+ """
+
+ y, norm_out = torchair.ge.custom_op(
+ "MatmulAllReduceAddRmsNorm",
+ inputs={
+ "x1" : x1,
+ "x2" : x2,
+ "bias" : bias,
+ "residual" : residual,
+ "gamma" : gamma,
+ "antiquant_scale" : antiquant_scale,
+ "antiquant_offset" : antiquant_offset,
+ "dequant_scale" : dequant_scale,
+ },
+ attrs={
+ "group" : ge.attr.Str(group),
+ "reduce_op" : ge.attr.Str(reduce_op),
+ "is_trans_a" : ge.attr.Bool(is_trans_a),
+ "is_trans_b" : ge.attr.Bool(is_trans_b),
+ "comm_turn" : ge.attr.Int(comm_turn),
+ "antiquant_group_size" : ge.attr.Int(antiquant_group_size),
+ "epsilon" : ge.attr.Float(epsilon),
+ },
+ outputs=[
+ "y",
+ "norm_out"
+ ]
+ )
+ return y, norm_out
+ MatmulAllReduceAddRmsNorm = MatmulAllReduceAddRmsNormV2
+elif torch_npu_api_version == 1:
+ def MatmulAllReduceAddRmsNormV1(x1: Tensor,
+ x2: Tensor,
+ bias: Optional[Tensor],
+ residual: Tensor,
+ gamma: Tensor,
+ antiquant_scale: Optional[Tensor],
+ antiquant_offset: Optional[Tensor],
+ dequant_scale: Optional[Tensor],
+ *,
+ group: str,
+ reduce_op: str = "sum",
+ is_trans_a: bool = False,
+ is_trans_b: bool = False,
+ comm_turn: int = 0,
+ antiquant_group_size: int = 0,
+ epsilon: float = 0.000001,
+ dependencies=None,
+ node_name=None):
+ op = get_default_ge_graph().op.add()
+ op.type = "MatmulAllReduceAddRmsNorm"
+ op.name = next_unique_name(node_name, "MatmulAllReduceAddRmsNorm")
+
+ # process dependices
+ if dependencies is not None:
+ for dependency in dependencies:
+ op.input.append(dependency.controller)
+
+ # process inputs
+ op.input.append(x1.tensor)
+ op.input_desc.add().CopyFrom(x1.desc)
+ op.input_desc[-1].name = "x1"
+ op.input.append(x2.tensor)
+ op.input_desc.add().CopyFrom(x2.desc)
+ op.input_desc[-1].name = "x2"
+ if bias is not None:
+ op.input.append(bias.tensor)
+ op.input_desc.add().CopyFrom(bias.desc)
+ op.input_desc[-1].name = "bias"
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "bias"
+ op.input.append(residual.tensor)
+ op.input_desc.add().CopyFrom(residual.desc)
+ op.input_desc[-1].name = "residual"
+ op.input.append(gamma.tensor)
+ op.input_desc.add().CopyFrom(gamma.desc)
+ op.input_desc[-1].name = "gamma"
+ if antiquant_scale is not None:
+ op.input.append(antiquant_scale.tensor)
+ op.input_desc.add().CopyFrom(antiquant_scale.desc)
+ op.input_desc[-1].name = "antiquant_scale"
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "antiquant_scale"
+ if antiquant_offset is not None:
+ op.input.append(antiquant_offset.tensor)
+ op.input_desc.add().CopyFrom(antiquant_offset.desc)
+ op.input_desc[-1].name = "antiquant_offset"
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "antiquant_offset"
+ if dequant_scale is not None:
+ op.input.append(dequant_scale.tensor)
+ op.input_desc.add().CopyFrom(dequant_scale.desc)
+ op.input_desc[-1].name = "dequant_scale"
+ else:
+ op.input.append('')
+ op.input_desc.add().CopyFrom(get_invalid_desc())
+ op.input_desc[-1].name = "dequant_scale"
+
+ # process attrs
+ op.attr["group"].s = compat_as_bytes(group)
+ op.attr["reduce_op"].s = compat_as_bytes(reduce_op)
+ op.attr["is_trans_a"].b = is_trans_a
+ op.attr["is_trans_b"].b = is_trans_b
+ op.attr["comm_turn"].i = comm_turn
+ op.attr["antiquant_group_size"].i = antiquant_group_size
+ op.attr["epsilon"].f = epsilon
+
+ # process outputs
+ output_index = 0
+ op.output_desc.add().name = "y"
+ y = Tensor(op, output_index)
+ output_index += 1
+ op.output_desc.add().name = "norm_out"
+ norm_out = Tensor(op, output_index)
+ output_index += 1
+
+ return y, norm_out
+ MatmulAllReduceAddRmsNorm = MatmulAllReduceAddRmsNormV1
+else:
+ raise ValueError("torch_npu_api_version unsupport")
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/op_builder/npu_moe_token_permute_builder.py b/model/train/yoco_moe/mindspeed/op_builder/npu_moe_token_permute_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..09a2913f3901bc568229d97a32ea93239ee593ca
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/npu_moe_token_permute_builder.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class MoeTokenPermuteOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "npu_moe_token_permute"
+
+ def __init__(self):
+ super(MoeTokenPermuteOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/cann/npu_moe_token_permute.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc']
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
+ ]
+ return args
diff --git a/model/train/yoco_moe/mindspeed/op_builder/npu_moe_token_unpermute_builder.py b/model/train/yoco_moe/mindspeed/op_builder/npu_moe_token_unpermute_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6b15a544750f87516932125b27b9ad71399bb80
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/npu_moe_token_unpermute_builder.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class MoeTokenUnpermuteOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "npu_moe_token_unpermute"
+
+ def __init__(self):
+ super(MoeTokenUnpermuteOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/cann/npu_moe_token_unpermute.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc']
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
+ ]
+ return args
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/op_builder/npu_ring_attention_update_builder.py b/model/train/yoco_moe/mindspeed/op_builder/npu_ring_attention_update_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1960d86c2af993d7828033d3771150f7572aa5b0
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/npu_ring_attention_update_builder.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class RingAttentionUpdateOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "npu_ring_attention_update"
+
+ def __init__(self):
+ super(RingAttentionUpdateOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/cann/npu_ring_attention_update.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc']
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
+ ]
+ return args
diff --git a/model/train/yoco_moe/mindspeed/op_builder/npu_rotary_position_embedding_builder.py b/model/train/yoco_moe/mindspeed/op_builder/npu_rotary_position_embedding_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..eac82e537380e5e44647af1ab8f179d13b24e4cb
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/npu_rotary_position_embedding_builder.py
@@ -0,0 +1,26 @@
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class RotaryPositionEmbeddingOpBuilder(MindSpeedOpBuilder):
+ OP_NAME = "npu_rotary_position_embedding"
+
+ def __init__(self):
+ super(RotaryPositionEmbeddingOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/cann/npu_rotary_position_embedding.cpp']
+
+ def include_paths(self):
+ paths = super().include_paths()
+ paths += ['ops/csrc/cann/inc']
+ return paths
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += [
+ '-Wno-sign-compare',
+ '-Wno-deprecated-declarations',
+ '-Wno-return-type',
+ "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
+ ]
+ return args
diff --git a/model/train/yoco_moe/mindspeed/op_builder/quant_gmm_builder.py b/model/train/yoco_moe/mindspeed/op_builder/quant_gmm_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..34632a61ce29a372d23783683c36ac7c308769b3
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/quant_gmm_builder.py
@@ -0,0 +1,69 @@
+from typing import Optional
+from functools import partial
+import torch
+from torch.library import impl
+
+from mindspeed.op_builder.builder import AS_LIBRARY
+from mindspeed.op_builder.gmm_builder import GMMOpBuilderPublic, GroupedMatmul, fill_empty_tensor
+from mindspeed.op_builder.gmm_builder import ge, Tensor, TensorSpec, DataType, register_fx_node_ge_converter
+
+
+class QuantGMMOpBuilder(GMMOpBuilderPublic):
+ OP_NAME = "quant_grouped_matmul"
+ OP_PROTO = (
+ "npu_quant_gmm(Tensor x, Tensor weight, Tensor scale, *, Tensor? offset=None, Tensor? per_token_scale=None, \
+ Tensor? bias=None, Tensor? group_list=None, int? group_list_type=0, ScalarType? output_dtype=None, \
+ int? act_type=0) -> Tensor"
+ )
+
+ def __init__(self):
+ super(QuantGMMOpBuilder, self).__init__(self.OP_NAME)
+ self.register_op_proto(self.OP_PROTO)
+ self.register_op_ir()
+
+ def sources(self):
+ return ['ops/csrc/cann/quant_gmm.cpp']
+
+ def register_op_ir(self):
+ @impl(AS_LIBRARY, "npu_quant_gmm", "Meta")
+ def npu_quant_gmm_forward(x, weight, scale, *, offset=None, per_token_scale=None, bias=None, group_list=None,
+ group_list_type=0, output_dtype=None, act_type=0):
+ BM = x.shape[0]
+ N = weight.shape[-1]
+ output_dtype = output_dtype or torch.float16
+ return x.new_empty((BM, N), dtype=output_dtype)
+
+ @register_fx_node_ge_converter(torch.ops.mindspeed.npu_quant_gmm.default)
+ def conveter_npu_quant_gmm(
+ x: Tensor,
+ weight: Tensor,
+ scale: Tensor,
+ *,
+ offset: Optional[Tensor] = None,
+ per_token_scale: Optional[Tensor] = None,
+ bias: Optional[Tensor] = None,
+ group_list: Optional[Tensor] = None,
+ group_list_type: Optional[int] = 0,
+ output_dtype: Optional[DataType] = None,
+ act_type: Optional[int] = 0,
+ meta_outputs: TensorSpec = None,
+ ):
+ bias = bias or fill_empty_tensor(DataType.DT_INT32)
+ offset = offset or fill_empty_tensor(DataType.DT_FLOAT)
+ antiquant_scale = fill_empty_tensor(DataType.DT_FLOAT16)
+ antiquant_offset = fill_empty_tensor(DataType.DT_FLOAT16)
+
+ y_dtype = 0
+ if output_dtype is None or output_dtype == torch.float16:
+ y_dtype = 0
+ elif output_dtype == torch.bfloat16:
+ y_dtype = 1
+ elif output_dtype == torch.int8:
+ raise ValueError("output_dtype not support int8 yet for graph mode")
+ else:
+ raise ValueError(f"output_dtype should be int8, float16 or bfloat16, "
+ f"otherwise it should be None, but got {output_dtype}")
+
+ return GroupedMatmul([x], [weight], [bias], [scale], [offset], [antiquant_scale], [antiquant_offset],
+ group_list, per_token_scale, split_item=3, group_type=0,
+ group_list_type=group_list_type, dtype=y_dtype, act_type=act_type)[0]
diff --git a/model/train/yoco_moe/mindspeed/op_builder/rms_norm_builder.py b/model/train/yoco_moe/mindspeed/op_builder/rms_norm_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b9ae8ce11712579c3a5aa52d25a1efbc4b6c897
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/rms_norm_builder.py
@@ -0,0 +1,13 @@
+from mindspeed.op_builder.atb_builder import AtbOpBuilder
+
+
+class RmsNormOpBuilder(AtbOpBuilder):
+ OP_NAME = "rms_norm"
+
+ def __init__(self):
+ super(RmsNormOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/atb/rms_norm.cpp',
+ 'ops/csrc/atb/utils/atb_adapter.cpp']
+
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/op_builder/smart_swap_builder.py b/model/train/yoco_moe/mindspeed/op_builder/smart_swap_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d44b25030a41a280d1693d90516db19b3516854
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/smart_swap_builder.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+import os
+import torch
+from mindspeed.op_builder.builder import MindSpeedOpBuilder
+
+
+class SmartSwapBuilder(MindSpeedOpBuilder):
+ NAME = "smart_swap"
+ _torch_path = None
+ _python_path = None
+
+ def __init__(self):
+ from sysconfig import get_paths
+
+ self._torch_path = os.path.dirname(os.path.abspath(torch.__file__))
+ self._python_path = os.path.dirname(os.path.abspath(get_paths().get("include")))
+ super(SmartSwapBuilder, self).__init__(self.NAME)
+
+ def sources(self):
+ root_path = "ops/csrc/pluggable_allocator/smart_swap"
+ files = [
+ "./SwapBindings.cpp",
+ "./NPUSwapManager.cpp",
+ "./DeviceCachingAllocator.cpp",
+ "./CachingAllocatorConfig.cpp",
+ "./EventPool.cpp",
+ "./PluggableAllocator.cpp",
+ "./common.cpp",
+ ]
+ return [os.path.join(root_path, item) for item in files]
+
+ def cxx_args(self):
+ args = super().cxx_args()
+ args += ["-D_GLIBCXX_USE_CXX11_ABI=0", "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"]
+
+ remove_arg = "-fvisibility=hidden"
+ while remove_arg in args:
+ args.remove(remove_arg)
+
+ return args
+
+ def extra_ldflags(self):
+ flags = ["-L" + os.path.join(self._torch_npu_path, "lib"), "-ltorch_npu"]
+ return flags
diff --git a/model/train/yoco_moe/mindspeed/op_builder/swiglu_builder.py b/model/train/yoco_moe/mindspeed/op_builder/swiglu_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6b06e99e0b347dc98c6530e12ff6ffc6253a843
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/swiglu_builder.py
@@ -0,0 +1,12 @@
+from mindspeed.op_builder.atb_builder import AtbOpBuilder
+
+
+class SwigluOpBuilder(AtbOpBuilder):
+ OP_NAME = "swiglu"
+
+ def __init__(self):
+ super(SwigluOpBuilder, self).__init__(self.OP_NAME)
+
+ def sources(self):
+ return ['ops/csrc/atb/swiglu.cpp',
+ 'ops/csrc/atb/utils/atb_adapter.cpp']
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/op_builder/weight_quant_gmm_builder.py b/model/train/yoco_moe/mindspeed/op_builder/weight_quant_gmm_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4619edb704b65130d402f4adcff37a1d1f9eaf37
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/op_builder/weight_quant_gmm_builder.py
@@ -0,0 +1,61 @@
+from typing import Optional
+from functools import partial
+import torch
+from torch.library import impl
+
+from mindspeed.op_builder.builder import AS_LIBRARY
+from mindspeed.op_builder.gmm_builder import GMMOpBuilderPublic, GroupedMatmul, fill_empty_tensor
+from mindspeed.op_builder.gmm_builder import ge, Tensor, TensorSpec, DataType, register_fx_node_ge_converter
+
+
+class WeightQuantGMMOpBuilder(GMMOpBuilderPublic):
+ OP_NAME = "weight_quant_grouped_matmul"
+ OP_PROTO = (
+ "npu_weight_quant_gmm(Tensor x, Tensor weight, Tensor antiquant_scale, *, Tensor? antiquant_offset=None, \
+ Tensor? bias=None, Tensor? group_list=None, int? group_list_type=0, int? act_type=0) -> Tensor"
+ )
+
+ def __init__(self):
+ super(WeightQuantGMMOpBuilder, self).__init__(self.OP_NAME)
+ self.register_op_proto(self.OP_PROTO)
+ self.register_op_ir()
+
+ def sources(self):
+ return ['ops/csrc/cann/weight_quant_gmm.cpp']
+
+ def register_op_ir(self):
+ @impl(AS_LIBRARY, "npu_weight_quant_gmm", "Meta")
+ def npu_weight_quant_gmm_forward(x, weight, antiquant_scale, *, antiquant_offset=None, bias=None,
+ group_list=None, group_list_type=0, act_type=0):
+ BM = x.shape[0]
+ N = weight.shape[-1]
+ output_dtype = x.dtype
+ return x.new_empty((BM, N), dtype=output_dtype)
+
+ @register_fx_node_ge_converter(torch.ops.mindspeed.npu_weight_quant_gmm.default)
+ def conveter_npu_weight_quant_gmm(
+ x: Tensor,
+ weight: Tensor,
+ antiquant_scale: Tensor,
+ *,
+ antiquant_offset: Optional[Tensor] = None,
+ bias: Optional[Tensor] = None,
+ group_list: Optional[Tensor] = None,
+ group_list_type: Optional[int] = 0,
+ act_type: Optional[int] = 0,
+ meta_outputs: TensorSpec = None,
+ ):
+ x_dtype = x.dtype
+ if bias is None:
+ if x_dtype == DataType.DT_BF16:
+ bias = fill_empty_tensor(DataType.DT_FLOAT)
+ elif x_dtype == DataType.DT_FLOAT16:
+ bias = fill_empty_tensor(DataType.DT_FLOAT16)
+ antiquant_offset = antiquant_offset or fill_empty_tensor(antiquant_scale.dtype)
+ scale = fill_empty_tensor(DataType.DT_UINT64)
+ offset = fill_empty_tensor(DataType.DT_FLOAT)
+
+
+ return GroupedMatmul([x], [weight], [bias], [scale], [offset], [antiquant_scale], [antiquant_offset],
+ group_list, split_item=3, group_type=0,
+ group_list_type=group_list_type, act_type=act_type)[0]
diff --git a/model/train/yoco_moe/mindspeed/ops/__init__.py b/model/train/yoco_moe/mindspeed/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/train/yoco_moe/mindspeed/ops/csrc/algorithm/adaptive_cp/adaptive_cp.cpp b/model/train/yoco_moe/mindspeed/ops/csrc/algorithm/adaptive_cp/adaptive_cp.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1bbdc0479293e348959df5996c7da8192f47752d
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/ops/csrc/algorithm/adaptive_cp/adaptive_cp.cpp
@@ -0,0 +1,454 @@
+// Copyright (c) 2024 Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+
+namespace py = pybind11;
+
+
+bool is_block_all_one(const uint64_t* dataPtr, int rowBlockSize, int colBlockSize, int splitNum)
+{
+ for (size_t i = 0; i < rowBlockSize; i++) {
+ for (size_t j = 0; j < colBlockSize; j++) {
+ if (*(dataPtr++) != 0x0101010101010101) {
+ return false;
+ }
+ }
+ dataPtr += colBlockSize * (splitNum - 1);
+ }
+ return true;
+}
+
+void sub_coarsen_mask(const uint64_t *dataPtr, int rowBlockSize, int colBlockSize, int splitNum,
+ at::Tensor &output, int blockIdxStart, int blockIdxEnd)
+{
+ if (splitNum == 0) {
+ throw std::runtime_error("Split Number must be a positive integer.");
+ }
+ auto outputPtr = (uint8_t *) output.data_ptr();
+ outputPtr += blockIdxStart;
+ for (size_t i = blockIdxStart; i < blockIdxEnd; i++) {
+ int blockRowIdx = std::floor(i / splitNum);
+ int blockColIdx = i % splitNum;
+ int grid_val = is_block_all_one(
+ dataPtr + (blockRowIdx * rowBlockSize) * (splitNum * colBlockSize) + (blockColIdx * colBlockSize),
+ rowBlockSize, colBlockSize, splitNum);
+ *(outputPtr++) = grid_val;
+ }
+}
+
+void coarsen_mask(const at::Tensor& input, const int splitNum, at::Tensor& output)
+{
+ int rowDim = input.size(0);
+ int colDim = input.size(1);
+ if (splitNum == 0) {
+ throw std::runtime_error("Split number must be a positive integer.");
+ }
+ if (rowDim % splitNum != 0 || colDim % splitNum != 0) {
+ throw std::runtime_error("Both dims of the input 2-dim matrix must be divisible by split num.");
+ }
+ int rowBlockSize = rowDim / splitNum;
+ int colBlockSize = colDim / splitNum;
+ int sizeRatioInt64ToBool = sizeof(uint64_t) / sizeof(bool);
+ if (rowBlockSize % sizeRatioInt64ToBool != 0 || colBlockSize % sizeRatioInt64ToBool != 0) {
+ throw std::runtime_error("Both dims of the input 2-dim matrix must be divisible by 8 * split_num, to iterate "
+ "data pointer in uint64 instead of bool.");
+ }
+ auto dataPtr = (uint64_t*) input.data_ptr();
+ colBlockSize /= sizeRatioInt64ToBool;
+ std::vector threads;
+ int totalNumBlocks = splitNum * splitNum;
+ int numThreads = std::thread::hardware_concurrency();
+ if (numThreads == 0) {
+ throw std::runtime_error("Number of threads must be a positive integer.");
+ }
+ if (totalNumBlocks < numThreads) {
+ numThreads = totalNumBlocks;
+ }
+ int blockNumPerThread = totalNumBlocks / numThreads;
+ for (size_t i = 0; i < numThreads; ++i) {
+ int blockIdxStart = i * blockNumPerThread;
+ threads.emplace_back(sub_coarsen_mask, dataPtr, rowBlockSize, colBlockSize, splitNum, std::ref(output),
+ blockIdxStart, blockIdxStart + blockNumPerThread);
+ }
+ // 等待所有线程完成
+ for (auto& t : threads) {
+ t.join();
+ }
+}
+
+void sub_select_perm_mask(const at::Tensor &input, const std::vector indList, at::Tensor &output, int subIndCnt,
+ int subStartIdx)
+{
+ uint64_t seqLen = input.size(0);
+ uint64_t indCnt = indList.size();
+ auto maskTensorPtr = (uint8_t *) input.data_ptr();
+ auto outputTensorPtr = (uint8_t *) output.data_ptr();
+ uint8_t *subOutputPtr = outputTensorPtr + subStartIdx * indCnt;
+ std::vector rowStartIdxList(subIndCnt);
+ for (size_t i = 0; i < subIndCnt; i++) {
+ rowStartIdxList[i] = ((uint64_t) indList[subStartIdx + i] * seqLen);
+ }
+
+ for (size_t i = 0; i < subIndCnt; i++) {
+ uint64_t rowStartIdx = rowStartIdxList[i];
+ for (size_t j = 0; j < indCnt; j++) {
+ uint64_t colIdx = indList[j];
+ uint8_t extractedValue = *(maskTensorPtr + (rowStartIdx + colIdx));
+ *(subOutputPtr++) = extractedValue;
+ }
+ }
+}
+
+void select_perm_mask(const at::Tensor &input, const std::vector indList, at::Tensor &output)
+{
+ if (input.dim() != 2 || input.size(0) != input.size(1)) {
+ throw std::runtime_error("Input mask must be 2-dimensional squared tensor.");
+ }
+ if (input.scalar_type() != torch::kBool) {
+ throw std::runtime_error("The datatype of input mask must be bool.");
+ }
+ uint64_t indCnt = indList.size();
+ std::vector threads;
+ int numThreads = std::thread::hardware_concurrency();
+ if (numThreads == 0) {
+ throw std::runtime_error("Number of threads must be a positive integer.");
+ }
+ if (indCnt % numThreads != 0 || numThreads > indCnt) {
+ numThreads = indCnt;
+ }
+ int subIndCnt = indCnt / numThreads;
+ for (size_t i = 0; i < numThreads; ++i) {
+ int subStartIdx = i * subIndCnt;
+ threads.emplace_back(sub_select_perm_mask, input, indList, std::ref(output), subIndCnt, subStartIdx);
+ }
+ // 等待所有线程完成
+ for (auto& t : threads) {
+ t.join();
+ }
+}
+
+// Function to calculate the Euclidean distance between two points
+float euclidean_distance(const std::vector& point1, const std::vector& point2)
+{
+ float sum = 0.0f;
+ for (size_t i = 0; i < point1.size(); ++i) {
+ sum += (point1[i] - point2[i]) * (point1[i] - point2[i]);
+ }
+ return std::sqrt(sum);
+}
+
+// Function to calculate distances between each point and all centroids
+std::vector> calculate_distances(
+ const std::vector>& data,
+ const std::vector>& centroids)
+{
+ std::vector> distances(data.size(), std::vector(centroids.size()));
+ for (size_t i = 0; i < data.size(); ++i) {
+ for (size_t j = 0; j < centroids.size(); ++j) {
+ distances[i][j] = euclidean_distance(data[i], centroids[j]);
+ }
+ }
+ return distances;
+}
+
+// Function to find the index of the minimum element in a vector
+size_t argmin(const std::vector& dataVec)
+{
+ return std::distance(dataVec.begin(), std::min_element(dataVec.begin(), dataVec.end()));
+}
+
+// FUnction to update centroids
+std::vector> update_centroids(
+ const std::vector>& data,
+ const std::vector& labels,
+ size_t numClusters,
+ size_t dimensionSize)
+{
+ std::vector> newCentroids(numClusters, std::vector(dimensionSize, 0.0f));
+ std::vector counts(numClusters, 0);
+
+ for (size_t i = 0; i < data.size(); ++i) {
+ for (size_t j = 0; j < dimensionSize; ++j) {
+ newCentroids[labels[i]][j] += data[i][j];
+ }
+ counts[labels[i]]++;
+ }
+
+ for (size_t i = 0; i < numClusters; ++i) {
+ if (counts[i] > 0) {
+ for (size_t j = 0; j < dimensionSize; ++j) {
+ newCentroids[i][j] /= counts[i];
+ }
+ } else {
+ // Reinitialize centroid randomly if no points are assigned to this cluster
+ newCentroids[i] = data[std::rand() % data.size()];
+ }
+ }
+
+ return newCentroids;
+}
+
+bool allClose(const std::vector>& centroids,
+ const std::vector>& newCentroids,
+ float rtol = 1e-5, float atol = 1e-8)
+{
+ // Check if the dimensions match
+ if (centroids.size() != newCentroids.size()) {
+ return false;
+ }
+
+ for (size_t i = 0; i < centroids.size(); ++i) {
+ if (centroids[i].size() != newCentroids[i].size()) {
+ return false;
+ }
+
+ for (size_t j = 0; j < centroids[i].size(); ++j) {
+ float diff = std::fabs(centroids[i][j] - newCentroids[i][j]);
+ float tol = atol + rtol * std::fabs(newCentroids[i][j]);
+ if (diff > tol) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+// Function to check if centroids have converged
+bool centroids_converged(
+ const std::vector>& centroids,
+ const std::vector>& newCentroids)
+{
+ return allClose(centroids, newCentroids);
+}
+
+std::vector get_num_tasks_on_device(const torch::Tensor& gridMask)
+{
+ int P = gridMask.size(0);
+ std::vector numTaskList(P, 0);
+
+ // 计算每行和每列中0的数量
+ for (int i = 0; i < P; ++i) {
+ int rowZeroCnt = 0;
+ int colZeroCnt = 0;
+
+ // 计算第i行中0的数量
+ for (int j = 0; j < P; ++j) {
+ if (gridMask[i][j].item() == 0) {
+ rowZeroCnt++;
+ }
+ }
+
+ // 计算第i列中0的数量
+ for (int j = 0; j < P; ++j) {
+ if (gridMask[j][i].item() == 0) {
+ colZeroCnt++;
+ }
+ }
+
+ // 第i行和第i列的0的数量之和
+ numTaskList[i] = rowZeroCnt + colZeroCnt - (gridMask[i][i].item() == 0 ? 1 : 0);
+ }
+
+ return numTaskList;
+}
+
+std::pair get_score(const at::Tensor& mask, size_t cpSize, at::Tensor &gridMask)
+{
+ if (cpSize == 0) {
+ throw std::runtime_error("CP size must be a positive integer.");
+ }
+ size_t maskSize = mask.size(0);
+ coarsen_mask(mask, cpSize, gridMask);
+ float totalTaskDensity = 1 - (gridMask.sum().item() / (cpSize * cpSize));
+ std::vector numTaskList = get_num_tasks_on_device(gridMask);
+ float taskNumDev = 0.0f;
+ if (!numTaskList.empty()) {
+ float mean = std::accumulate(numTaskList.begin(), numTaskList.end(), 0.0f) / numTaskList.size();
+ float sum = 0.0f;
+ for (const auto& num : numTaskList) {
+ sum += (num - mean) * (num - mean);
+ }
+ taskNumDev = std::sqrt(sum / numTaskList.size());
+ }
+ return {totalTaskDensity, taskNumDev};
+}
+
+// Kmeans function
+std::pair>, std::vector> kmeans(
+ const std::vector>& data,
+ size_t numClusters,
+ size_t numIters)
+{
+ size_t seqLen = data.size();
+ size_t dimensionSize = data[0].size();
+ // Initialize centroids randomly
+ std::vector> centroids(numClusters);
+ std::srand(0);
+ std::vector indices(seqLen);
+ std::iota(indices.begin(), indices.end(), 0);
+ std::random_shuffle(indices.begin(), indices.end());
+ for (size_t i = 0; i < numClusters; ++i) {
+ centroids[i] = data[indices[i]];
+ }
+ std::vector labels(seqLen);
+ for (size_t iterIdx = 0; iterIdx < numIters; ++iterIdx) {
+ // Calculate distances between each point and centroids
+ std::vector> distances = calculate_distances(data, centroids);
+ // Assign labels based on nearest centroid
+ for (size_t i = 0; i < seqLen; ++ i) {
+ labels[i] = argmin(distances[i]);
+ }
+ // Update centroids
+ std::vector> newCentroids = update_centroids(data, labels, numClusters, dimensionSize);
+ // Check for convergence
+ if (centroids_converged(centroids, newCentroids)) {
+ break;
+ }
+ centroids = newCentroids;
+ }
+ return {centroids, labels};
+}
+
+std::vector search_kmeans(
+ const at::Tensor& attnMask,
+ const std::vector>& reducedMask,
+ at::Tensor &tmpAttnMask,
+ at::Tensor &tmpGridMask,
+ at::Tensor &optGridMask,
+ at::Tensor &optAttnMask,
+ py::list optNumCluster,
+ size_t cpSize,
+ size_t numIters)
+{
+ std::vector optSeq(attnMask.size(0));
+ std::iota(optSeq.begin(), optSeq.end(), 0);
+ auto [minTaskDensity, optTaskDev] = get_score(attnMask, cpSize, optGridMask);
+ for (int numClusters = 2; numClusters < 9 ; ++numClusters) {
+ auto [centroids, labels] = kmeans(reducedMask, numClusters, numIters);
+ // Sort indices based on labels
+ std::vector sortedSeq(labels.size());
+ std::iota(sortedSeq.begin(), sortedSeq.end(), 0);
+ std::sort(sortedSeq.begin(), sortedSeq.end(), [&labels](size_t i, size_t j) {
+ return labels[i] < labels[j];
+ });
+ select_perm_mask(attnMask, sortedSeq, tmpAttnMask);
+ auto [taskDensity, taskNumDev] = get_score(tmpAttnMask, cpSize, tmpGridMask);
+ if (taskDensity < minTaskDensity) {
+ minTaskDensity = taskDensity;
+ optAttnMask.copy_(tmpAttnMask);
+ optNumCluster[0] = numClusters;
+ optTaskDev = taskNumDev;
+ optSeq = sortedSeq;
+ optGridMask.copy_(tmpAttnMask);
+ } else if (taskDensity == minTaskDensity && taskNumDev < optTaskDev) {
+ optAttnMask.copy_(tmpAttnMask);
+ optNumCluster[0] = numClusters;
+ optTaskDev = taskNumDev;
+ optSeq = sortedSeq;
+ optGridMask.copy_(tmpGridMask);
+ }
+ }
+ return optSeq;
+}
+
+void get_mask_list_with_remap(const at::Tensor& attnMask, at::Tensor& output, std::vector rowIdxSeq, std::vector colIdxSeq)
+{
+ size_t maskColLen = attnMask.size(1);
+ size_t rowIdxLen = rowIdxSeq.size();
+ size_t colIdxLen = colIdxSeq.size();
+ if (rowIdxLen > output.size(0) || colIdxLen > output.size(1)) {
+ throw std::runtime_error("Row or colum index length large than size of attention mask");
+ }
+ uint8_t *inputPtr = (uint8_t *) attnMask.data_ptr();
+ uint8_t *outputPtr = (uint8_t *) output.data_ptr();
+
+ for (size_t i = 0; i < rowIdxLen; i++) {
+ uint8_t *inputRowStartPtr = inputPtr + rowIdxSeq[i] * maskColLen;
+ for (size_t j = 0; j < colIdxLen; j++) {
+ *(outputPtr++) = *(inputRowStartPtr + colIdxSeq[j]);
+ }
+ }
+}
+
+void get_mask_list_without_remap(const at::Tensor& attnMask, at::Tensor& output, std::vector blockIdx, int cpSize)
+{
+ if (cpSize == 0) {
+ throw std::runtime_error("CP size must be a positive integer.");
+ }
+ int sizeRatioInt64ToBool = sizeof(uint64_t) / sizeof(bool);
+ int rowGridSize = attnMask.size(0) / cpSize;
+ int colGridSize = rowGridSize / sizeRatioInt64ToBool;
+ if (rowGridSize % sizeRatioInt64ToBool != 0) {
+ throw std::runtime_error("Sequence length on each cp rank must be a multiple of 8");
+ }
+ int rowStartIdx = blockIdx[0] * rowGridSize;
+ int colStartIdx = blockIdx[1] * colGridSize;
+
+ uint64_t *inputPtr = (uint64_t*) attnMask.data_ptr();
+ uint64_t *outputPtr = (uint64_t*) output.data_ptr();
+
+ uint64_t *currPtr = inputPtr + rowStartIdx * (colGridSize * cpSize) + colStartIdx;
+ int numUnitToNextRow = cpSize * colGridSize;
+
+ uint64_t memmoveCnt = 0;
+ if (colGridSize > std::numeric_limits::max() / rowGridSize) {
+ throw std::runtime_error("sequence length too long or context parallel size too small");
+ }
+ uint64_t outputSize = static_cast(rowGridSize) * colGridSize;
+
+ for (size_t i = 0; i < rowGridSize; i++) {
+ if (memmoveCnt + colGridSize > outputSize) {
+ throw std::runtime_error("Memory move out of range.");
+ }
+ memmove(outputPtr, currPtr, colGridSize * sizeof(uint64_t));
+ memmoveCnt += colGridSize;
+ outputPtr += colGridSize;
+ currPtr += numUnitToNextRow;
+ }
+}
+
+PYBIND11_MODULE(adaptive_cp, m)
+{
+m.def("coarsen_mask",
+ &coarsen_mask,
+ "A function that coarse a bool tensor with given split number",
+ py::arg("input"), py::arg("splitNum"), py::arg("output"));
+m.def("search_kmeans",
+ &search_kmeans,
+ "Search optimal k-means clustering result among various number of clusters",
+ py::arg("attnMask"), py::arg("reduceMask"), py::arg("tmpAttnMask"), py::arg("tmpGridMask"),
+ py::arg("optGridMask"), py::arg("optAttnMask"), py::arg("optNumCluster"), py::arg("cpSize"),
+ py::arg("numIters"));
+m.def("get_mask_list_with_remap",
+ &get_mask_list_with_remap,
+ py::arg("attnMask"), py::arg("output"), py::arg("rowIdxSeq"), py::arg("colIdxSeq"));
+m.def("get_mask_list_without_remap",
+ &get_mask_list_without_remap,
+ py::arg("attnMask"), py::arg("output"), py::arg("blockIdx"), py::arg("cpSize"));
+}
diff --git a/model/train/yoco_moe/mindspeed/ops/csrc/algorithm/algorithm.cpp b/model/train/yoco_moe/mindspeed/ops/csrc/algorithm/algorithm.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..580746e49dbf236beda0d1611d142b5bd84a1cd5
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/ops/csrc/algorithm/algorithm.cpp
@@ -0,0 +1,24 @@
+#include
+
+void reuse_data_ptr(at::Tensor& des, at::Tensor& src, int64_t offset)
+{
+ TORCH_CHECK(
+ offset >= 0,
+ "Expect offset equal or greater than zero, got: ", offset);
+
+ TORCH_CHECK(
+ (offset + des.numel()) * des.element_size() <=
+ src.numel() * src.element_size(),
+ "Offsets overflow, got: ",
+ "offset ", offset * des.element_size(),
+ ", des storage size ", des.numel() * des.element_size(),
+ ", src storage size ", src.numel()* src.element_size());
+
+ char* data_ptr = static_cast(src.storage().data_ptr().get()) + offset * des.element_size();
+ at::DataPtr aim_data_ptr = at::DataPtr(data_ptr, des.storage().device());
+ des.storage().set_data_ptr(std::move(aim_data_ptr));
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("reuse_data_ptr", &reuse_data_ptr, "reuse tensor data ptr");
+}
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/ops/csrc/atb/groupmatmul_add.cpp b/model/train/yoco_moe/mindspeed/ops/csrc/atb/groupmatmul_add.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9f7335e1180e4f46833be369ab6f18a739f992e8
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/ops/csrc/atb/groupmatmul_add.cpp
@@ -0,0 +1,71 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#ifdef ENABLE_ATB
+#include "inc/atb_adapter.h"
+#include "atb/operation.h"
+#include "atb/train_op_params.h"
+#include "atb/infer_op_params.h"
+#include "../flop_counter/flop_counter.h"
+#endif
+
+using namespace std;
+using torch::autograd::AutogradContext;
+using torch::autograd::Function;
+namespace {
+
+
+void groupmatmul_add_fp32(const at::Tensor &x, const at::Tensor &weight, const at::Tensor &group_list, at::Tensor & grad)
+{
+#ifndef ENABLE_ATB
+ TORCH_CHECK(false, "ATB MatmulAdd not implemented");
+#else
+ atb::infer::GroupedMatmulInplaceAddParam param;
+ param.transposeA = true; // 是否转置A矩阵
+ param.transposeB = false; // 是否转置B矩阵
+
+ ParamSetter paramsetter;
+ paramsetter.Input(x)
+ .Input(weight)
+ .Input(group_list)
+ .Input(grad)
+ .Output(grad);
+ // 构造算子并执行
+ atb::Operation* op = nullptr;
+ atb::CreateOperation(param, &op);
+ TORCH_CHECK(op != nullptr, "GroupMatmulAdd get op failed!");
+ RunAtbCmd(op, paramsetter, "GroupedMatmulInplaceAddOperation");
+ #ifdef FLOP_COUNT
+ FLOP_COUNT(FlopCounter::gmm_add_flop, x, weight, group_list);
+ #endif
+ return ;
+#endif
+}
+} // namespace
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("npu_groupmatmul_add_fp32", &groupmatmul_add_fp32, "matmul_add on ascend device",
+ pybind11::arg("x"), pybind11::arg("weight"), pybind11::arg("group_list"), pybind11::arg("grad"));
+}
diff --git a/model/train/yoco_moe/mindspeed/ops/csrc/atb/inc/atb_adapter.h b/model/train/yoco_moe/mindspeed/ops/csrc/atb/inc/atb_adapter.h
new file mode 100644
index 0000000000000000000000000000000000000000..061f861899ca197942ec3cc7a5ad7b0c3799c0a9
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/ops/csrc/atb/inc/atb_adapter.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPEED_OPS_CSRC_ATB_INC_ATB_ADAPTER_H
+#define MINDSPEED_OPS_CSRC_ATB_INC_ATB_ADAPTER_H
+#include
+#include
+#include
+#include
+#include
+#include "atb/types.h"
+#include "atb/operation.h"
+#include "atb/utils.h"
+#if __has_include("torch_npu/csrc/flopcount/FlopCount.h")
+ #include "torch_npu/csrc/flopcount/FlopCount.h"
+#endif
+
+atb::Tensor AtTensor2Tensor(const at::Tensor atTensor);
+atb::Context* GetContext();
+at::Tensor GetWorkspaceTensor(uint64_t workspaceSize, atb::Operation *operation);
+uint64_t OperationSetup(atb::VariantPack variantPack, atb::Operation *operation, atb::Context* contextPtr);
+class ParamSetter {
+public:
+ ParamSetter& Input(const at::Tensor &tensor);
+ ParamSetter& Input(const c10::optional &tensor);
+ ParamSetter& Output(at::Tensor &tensor);
+ atb::VariantPack variantPack;
+};
+
+void RunAtbCmd(atb::Operation *op, const ParamSetter ¶msetter, const std::string &name);
+
+#endif
diff --git a/model/train/yoco_moe/mindspeed/ops/csrc/atb/lcal_coc.cpp b/model/train/yoco_moe/mindspeed/ops/csrc/atb/lcal_coc.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..885fd6cb56add521bbca63730999397705f9097c
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/ops/csrc/atb/lcal_coc.cpp
@@ -0,0 +1,283 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#ifdef ENABLE_ATB
+#include
+#include
+#include "inc/atb_adapter.h"
+#include "atb/operation.h"
+#include "atb/infer_op_params.h"
+#include "../flop_counter/flop_counter.h"
+#endif
+
+
+void matmul_all_reduce(const at::Tensor &input1, const at::Tensor &input2, const c10::optional &biasOpt,
+ at::Tensor &output, int rank, int rankSize, const std::string &commDomain)
+{
+ const at::Tensor &bias = biasOpt.value_or(at::Tensor());
+
+ atb::infer::LinearParallelParam param;
+ bool transB = input1.size(1) != input2.size(0);
+ param.transWeight = transB;
+ param.rank = rank;
+ param.rankSize = rankSize;
+ param.rankRoot = 0;
+ param.hasResidual = biasOpt.has_value();
+ param.backend = "lcoc";
+ param.commMode = atb::infer::CommMode::COMM_MULTI_PROCESS;
+ param.type = atb::infer::LinearParallelParam::ParallelType::LINEAR_ALL_REDUCE;
+ param.keepIntermediate = false;
+ param.commDomain = commDomain;
+
+ ParamSetter paramsetter;
+ paramsetter.Input(input1)
+ .Input(input2);
+ if (biasOpt.has_value()) {
+ paramsetter.Input(bias);
+ }
+ paramsetter.Output(output);
+
+ atb::Operation* op = nullptr;
+ atb::CreateOperation(param, &op);
+ TORCH_CHECK(op != nullptr, "lcal coc get op failed!");
+ RunAtbCmd(op, paramsetter, "matmul_all_reduce");
+#ifdef FLOP_COUNT
+ FLOP_COUNT(FlopCounter::coc_flop, input1, input2, transB, rankSize, false);
+#endif
+}
+
+
+void all_gather_matmul(const at::Tensor &input1, const at::Tensor &input2, const c10::optional &biasOpt,
+ at::Tensor &output, int rank, int rankSize, const std::string &commDomain)
+{
+ const at::Tensor &bias = biasOpt.value_or(at::Tensor());
+
+ atb::infer::LinearParallelParam param;
+ bool transB = input1.size(1) != input2.size(0);
+ param.transWeight = transB;
+ param.rank = rank;
+ param.rankSize = rankSize;
+ param.rankRoot = 0;
+ param.hasResidual = biasOpt.has_value();
+ param.backend = "lcoc";
+ param.commMode = atb::infer::CommMode::COMM_MULTI_PROCESS;
+ param.type = atb::infer::LinearParallelParam::ParallelType::ALL_GATHER_LINEAR;
+ param.keepIntermediate = false;
+ param.commDomain = commDomain;
+
+ ParamSetter paramsetter;
+ paramsetter.Input(input1)
+ .Input(input2);
+ if (biasOpt.has_value()) {
+ paramsetter.Input(bias);
+ }
+ paramsetter.Output(output);
+
+ atb::Operation* op = nullptr;
+ atb::CreateOperation(param, &op);
+ TORCH_CHECK(op != nullptr, "lcal coc get op failed!");
+ RunAtbCmd(op, paramsetter, "all_gather_matmul");
+#ifdef FLOP_COUNT
+ FLOP_COUNT(FlopCounter::coc_flop, input1, input2, transB, rankSize, true);
+#endif
+}
+
+
+void all_gather_matmul_v2(const at::Tensor &input1, const at::Tensor &input2, const c10::optional &biasOpt,
+ at::Tensor &output, at::Tensor &commOutput, int rank, int rankSize, const std::string &commDomain)
+{
+ const at::Tensor &bias = biasOpt.value_or(at::Tensor());
+
+ atb::infer::LinearParallelParam param;
+ bool transB = input1.size(1) != input2.size(0);
+ param.transWeight = transB;
+ param.rank = rank;
+ param.rankSize = rankSize;
+ param.rankRoot = 0;
+ param.hasResidual = biasOpt.has_value();
+ param.backend = "lcoc";
+ param.commMode = atb::infer::CommMode::COMM_MULTI_PROCESS;
+ param.type = atb::infer::LinearParallelParam::ParallelType::ALL_GATHER_LINEAR;
+ param.keepIntermediate = true;
+ param.commDomain = commDomain;
+
+ ParamSetter paramsetter;
+ paramsetter.Input(input1)
+ .Input(input2);
+ if (biasOpt.has_value()) {
+ paramsetter.Input(bias);
+ }
+ paramsetter.Output(output)
+ .Output(commOutput);
+
+ atb::Operation* op = nullptr;
+ atb::CreateOperation(param, &op);
+ TORCH_CHECK(op != nullptr, "lcal coc get op failed!");
+ RunAtbCmd(op, paramsetter, "all_gather_matmul_v2");
+#ifdef FLOP_COUNT
+ FLOP_COUNT(FlopCounter::coc_flop, input1, input2, transB, rankSize, true);
+#endif
+}
+
+
+void matmul_reduce_scatter(const at::Tensor &input1, const at::Tensor &input2, const c10::optional &biasOpt,
+ at::Tensor &output, int rank, int rankSize, const std::string &commDomain)
+{
+ const at::Tensor &bias = biasOpt.value_or(at::Tensor());
+
+ atb::infer::LinearParallelParam param;
+ bool transB = input1.size(1) != input2.size(0);
+ param.transWeight = transB;
+ param.rank = rank;
+ param.rankSize = rankSize;
+ param.rankRoot = 0;
+ param.hasResidual = biasOpt.has_value();
+ param.backend = "lcoc";
+ param.commMode = atb::infer::CommMode::COMM_MULTI_PROCESS;
+ param.type = atb::infer::LinearParallelParam::ParallelType::LINEAR_REDUCE_SCATTER;
+ param.keepIntermediate = false;
+ param.commDomain = commDomain;
+
+ ParamSetter paramsetter;
+ paramsetter.Input(input1)
+ .Input(input2);
+ if (biasOpt.has_value()) {
+ paramsetter.Input(bias);
+ }
+ paramsetter.Output(output);
+
+ atb::Operation* op = nullptr;
+ atb::CreateOperation(param, &op);
+ TORCH_CHECK(op != nullptr, "lcal coc get op failed!");
+ RunAtbCmd(op, paramsetter, "matmul_reduce_scatter");
+#ifdef FLOP_COUNT
+ FLOP_COUNT(FlopCounter::coc_flop, input1, input2, transB, rankSize, false);
+#endif
+}
+
+
+void pure_matmul(const at::Tensor &input1, const at::Tensor &input2, const c10::optional &biasOpt,
+ at::Tensor &output, int rank, int rankSize, const std::string &commDomain)
+{
+ const at::Tensor &bias = biasOpt.value_or(at::Tensor());
+
+ atb::infer::LinearParallelParam param;
+ bool transB = input1.size(1) != input2.size(0);
+ param.transWeight = transB;
+ param.rank = rank;
+ param.rankSize = rankSize;
+ param.rankRoot = 0;
+ param.hasResidual = biasOpt.has_value();
+ param.backend = "lcoc";
+ param.commMode = atb::infer::CommMode::COMM_MULTI_PROCESS;
+ param.type = atb::infer::LinearParallelParam::ParallelType::PURE_LINEAR;
+ param.keepIntermediate = false;
+ param.commDomain = commDomain;
+
+ ParamSetter paramsetter;
+ paramsetter.Input(input1)
+ .Input(input2);
+ if (biasOpt.has_value()) {
+ paramsetter.Input(bias);
+ }
+ paramsetter.Output(output);
+
+ atb::Operation* op = nullptr;
+ atb::CreateOperation(param, &op);
+ TORCH_CHECK(op != nullptr, "lcal coc get op failed!");
+ RunAtbCmd(op, paramsetter, "pure_matmul");
+#ifdef FLOP_COUNT
+ FLOP_COUNT(FlopCounter::coc_flop, input1, input2, transB, rankSize, false);
+#endif
+}
+
+template
+struct atb_support_all_gather_matmul_reduce_scatter_op : std::false_type {};
+
+template
+struct atb_support_all_gather_matmul_reduce_scatter_op> : std::true_type {};
+
+template
+void all_gather_matmul_reduce_scatter(const at::Tensor &input1, const at::Tensor &input2,
+ const c10::optional &biasOpt, at::Tensor &output, int rank,
+ int tpSize, const std::string &commDomain, int agDim, int rsDim, bool innerDimIsAg)
+{
+ if constexpr (atb_support_all_gather_matmul_reduce_scatter_op::value) {
+ const at::Tensor &bias = biasOpt.value_or(at::Tensor());
+ T param;
+ bool transB = input1.size(1) != input2.size(0);
+ param.transWeight = transB;
+ param.rank = rank;
+ param.rankSize = tpSize;
+ param.rankRoot = 0;
+ param.twoDimTPInfo.agDim = agDim;
+ param.twoDimTPInfo.rsDim = rsDim;
+ param.twoDimTPInfo.innerDimIsAg = innerDimIsAg;
+ param.hasResidual = biasOpt.has_value();
+ param.backend = "lcoc";
+ param.commMode = atb::infer::CommMode::COMM_MULTI_PROCESS;
+ param.type = T::ParallelType::ALL_GATHER_LINEAR_REDUCE_SCATTER;
+ param.keepIntermediate = false;
+ param.commDomain = commDomain;
+
+ ParamSetter paramsetter;
+ paramsetter.Input(input1)
+ .Input(input2);
+ if (biasOpt.has_value()) {
+ paramsetter.Input(bias);
+ }
+ paramsetter.Output(output);
+
+ atb::Operation* op = nullptr;
+ atb::CreateOperation(param, &op);
+ TORCH_CHECK(op != nullptr, "lcal coc get op failed!");
+ RunAtbCmd(op, paramsetter, "all_gather_matmul_reduce_scatter");
+#ifdef FLOP_COUNT
+ FLOP_COUNT(FlopCounter::coc_flop, input1, input2, transB, agDim, true);
+#endif
+ } else {
+ TORCH_CHECK(false, "Current version of ATB doesn't support the all_gather_matmul_reduce_scatter operator!");
+ }
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("matmul_all_reduce", &matmul_all_reduce, "matmul_all_reduce", pybind11::arg("input1"),
+ pybind11::arg("input2"), pybind11::arg("biasOpt"), pybind11::arg("output"), pybind11::arg("rank"), pybind11::arg("rankSize"), pybind11::arg("commDomain"));
+ m.def("all_gather_matmul", &all_gather_matmul, "all_gather_matmul", pybind11::arg("input1"),
+ pybind11::arg("input2"), pybind11::arg("biasOpt"), pybind11::arg("output"), pybind11::arg("rank"), pybind11::arg("rankSize"), pybind11::arg("commDomain"));
+ m.def("all_gather_matmul_v2", &all_gather_matmul_v2, "all_gather_matmul_v2", pybind11::arg("input1"),
+ pybind11::arg("input2"), pybind11::arg("biasOpt"), pybind11::arg("output"), pybind11::arg("commOutput"),
+ pybind11::arg("rank"), pybind11::arg("rankSize"), pybind11::arg("commDomain"));
+ m.def("matmul_reduce_scatter", &matmul_reduce_scatter, "matmul_reduce_scatter", pybind11::arg("input1"),
+ pybind11::arg("input2"), pybind11::arg("biasOpt"), pybind11::arg("output"), pybind11::arg("rank"), pybind11::arg("rankSize"), pybind11::arg("commDomain"));
+ m.def("pure_matmul", &pure_matmul, "pure_matmul", pybind11::arg("input1"), pybind11::arg("input2"),
+ pybind11::arg("biasOpt"), pybind11::arg("output"), pybind11::arg("rank"), pybind11::arg("rankSize"), pybind11::arg("commDomain"));
+ m.def("all_gather_matmul_reduce_scatter", &all_gather_matmul_reduce_scatter, "all_gather_matmul_reduce_scatter", pybind11::arg("input1"),
+ pybind11::arg("input2"), pybind11::arg("biasOpt"), pybind11::arg("output"), pybind11::arg("rank"), pybind11::arg("tpSize"), pybind11::arg("commDomain"), pybind11::arg("agDim"), pybind11::arg("rsDim"), pybind11::arg("innerDimIsAg"));
+}
diff --git a/model/train/yoco_moe/mindspeed/ops/csrc/atb/matmul_add.cpp b/model/train/yoco_moe/mindspeed/ops/csrc/atb/matmul_add.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2399e4dfc5395737857090c37e6f64c972ea38b2
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/ops/csrc/atb/matmul_add.cpp
@@ -0,0 +1,72 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#ifdef ENABLE_ATB
+#include "inc/atb_adapter.h"
+#include "atb/operation.h"
+#include "atb/train_op_params.h"
+#include "atb/infer_op_params.h"
+#include "../flop_counter/flop_counter.h"
+#endif
+
+using namespace std;
+using torch::autograd::AutogradContext;
+using torch::autograd::Function;
+namespace {
+
+
+void matmul_add_fp32(const at::Tensor &x, const at::Tensor &weight, at::Tensor & C)
+{
+#ifndef ENABLE_ATB
+ TORCH_CHECK(false, "ATB MatmulAdd not implemented");
+#else
+ atb::infer::LinearParam param;
+ param.transposeA = true; // 是否转置A矩阵
+ param.transposeB = false; // 是否转置B矩阵
+ param.hasBias = false;
+ param.enAccum = true;
+
+ ParamSetter paramsetter;
+ paramsetter.Input(x)
+ .Input(weight)
+ .Input(C)
+ .Output(C);
+ // 构造算子并执行
+ atb::Operation* op = nullptr;
+ atb::CreateOperation(param, &op);
+ TORCH_CHECK(op != nullptr, "MatmulAdd_forward get op failed!");
+ RunAtbCmd(op, paramsetter, "LinearOperation");
+ #ifdef FLOP_COUNT
+ FLOP_COUNT(FlopCounter::mm_flop, x, weight);
+ #endif
+ return ;
+#endif
+}
+} // namespace
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("npu_matmul_add_fp32", &matmul_add_fp32, "matmul_add on ascend device",
+ pybind11::arg("x"), pybind11::arg("weight"), pybind11::arg("C"));
+}
diff --git a/model/train/yoco_moe/mindspeed/ops/csrc/atb/rms_norm.cpp b/model/train/yoco_moe/mindspeed/ops/csrc/atb/rms_norm.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..81a7769e74f328b714df6572e965cc1aaf988ca3
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/ops/csrc/atb/rms_norm.cpp
@@ -0,0 +1,145 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#ifdef ENABLE_ATB
+#include "inc/atb_adapter.h"
+#include "atb/operation.h"
+#include "atb/train_op_params.h"
+#include "atb/infer_op_params.h"
+#endif
+
+using namespace std;
+using torch::autograd::AutogradContext;
+using torch::autograd::Function;
+namespace {
+const static int RMSNORM_LAYERTYPE = 1;
+const static int SAVE_X = 0;
+const static int SAVE_RSTD = 1;
+const static int SAVE_GAMMA = 2;
+const static int N = 32;
+
+void InferShapeRmsNorm(c10::SmallVector &size, const at::Tensor &self, const at::Tensor &gamma)
+{
+ int64_t rstd_dim = self.dim();
+ rstd_dim -= gamma.dim();
+ TORCH_CHECK(rstd_dim >= 0,
+ "RmsNorm intensor gamma dim error,gamma's dim should not greater than x's dim");
+ for (uint64_t i = 0; i < self.dim(); i++) {
+ if (i < rstd_dim) {
+ size.emplace_back(self.size(i));
+ } else {
+ size.emplace_back(1);
+ }
+ }
+}
+
+void CheckRmsNorm(const at::Tensor &x, const at::Tensor &gamma)
+{
+ TORCH_CHECK(x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::BFloat16 ||
+ x.scalar_type() == at::ScalarType::Float,
+ "Input x dtype ", x.scalar_type(), " invalid, should be float, float16 or bfloat16");
+ TORCH_CHECK(x.scalar_type() == gamma.scalar_type(),
+ "Input x dtype should be same with gamma, but got x ", x.scalar_type(), " gamma ", gamma.scalar_type());
+}
+
+class NPURmsNormFunction : public torch::autograd::Function {
+public:
+ static at::Tensor forward(
+ AutogradContext *ctx, const at::Tensor &x, const at::Tensor &gamma, float epsilon)
+ {
+#ifndef ENABLE_ATB
+ TORCH_CHECK(false, "ATB RmsNorm not implemented");
+#else
+ at::AutoNonVariableTypeMode g;
+ c10::SmallVector tensor_rstd_shape;
+ CheckRmsNorm(x, gamma);
+ InferShapeRmsNorm(tensor_rstd_shape, x, gamma);
+ // apply tensor
+ at::Tensor tensor_rstd = at::empty(at::IntArrayRef(tensor_rstd_shape), x.options().dtype(at::ScalarType::Float));
+ at::Tensor tensor_y = at::empty(x.sizes(), x.options());
+
+ atb::infer::RmsNormParam param;
+ param.layerType = (atb::infer::RmsNormParam::RmsNormType)RMSNORM_LAYERTYPE;
+ param.normParam.epsilon = epsilon;
+ param.normParam.rstd = true;
+
+ // set input and output
+ ParamSetter paramsetter;
+ paramsetter.Input(x)
+ .Input(gamma)
+ .Output(tensor_y)
+ .Output(tensor_rstd);
+
+ atb::Operation* op = nullptr;
+ atb::CreateOperation(param, &op);
+ TORCH_CHECK(op != nullptr, "RmsNorm get op failed!");
+ RunAtbCmd(op, paramsetter, "RmsNorm_forward");
+
+ ctx->save_for_backward({x, tensor_rstd, gamma});
+
+ return tensor_y;
+#endif
+ }
+
+ static std::vector backward(AutogradContext *ctx, std::vector grad_output)
+ {
+#ifndef ENABLE_ATB
+ TORCH_CHECK(false, "RmsNormBackward not implemented");
+#else
+ auto saved = ctx->get_saved_variables();
+ auto x = saved[SAVE_X];
+ auto rstd = saved[SAVE_RSTD];
+ auto gamma = saved[SAVE_GAMMA];
+ atb::train::RmsNormBackwardParam param;
+
+ at::Tensor tensor_x_grad = at::empty(x.sizes(), x.options());
+ at::Tensor tensor_gamma_grad = at::empty(gamma.sizes(), gamma.options().dtype(at::ScalarType::Float));
+
+ ParamSetter paramsetter;
+ paramsetter.Input(grad_output[0])
+ .Input(x)
+ .Input(rstd)
+ .Input(gamma)
+ .Output(tensor_x_grad)
+ .Output(tensor_gamma_grad);
+
+ atb::Operation* op = nullptr;
+ atb::CreateOperation(param, &op);
+ TORCH_CHECK(op != nullptr, "RmsNormBackward get op failed!");
+ RunAtbCmd(op, paramsetter, "RmsNorm_backward");
+
+ return {tensor_x_grad, tensor_gamma_grad, at::Tensor()};
+#endif
+ }
+};
+} // namespace
+
+at::Tensor npu_rms_norm(const at::Tensor &x, const at::Tensor &gamma, float epsilon)
+{
+ return NPURmsNormFunction::apply(x, gamma, epsilon);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("rms_norm", &npu_rms_norm, "rms_norm on ascend device",
+ pybind11::arg("x"), pybind11::arg("gamma"), pybind11::arg("epsilon")=1e-6);
+}
diff --git a/model/train/yoco_moe/mindspeed/ops/csrc/atb/swiglu.cpp b/model/train/yoco_moe/mindspeed/ops/csrc/atb/swiglu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1a364bb50acd7cab385ada7fd89ae67aebaa9f3f
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/ops/csrc/atb/swiglu.cpp
@@ -0,0 +1,136 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#ifdef ENABLE_ATB
+#include "inc/atb_adapter.h"
+#include "atb/operation.h"
+#include "atb/infer_op_params.h"
+#endif
+
+using namespace std;
+using torch::autograd::AutogradContext;
+using torch::autograd::Function;
+namespace {
+const static int N = 32;
+void InferSwigluForward(c10::SmallVector &out_tensor_shape, const at::Tensor &x, int32_t dim)
+{
+ int64_t split_dim = dim;
+ if (split_dim < 0) {
+ split_dim += x.dim();
+ }
+ TORCH_CHECK(split_dim >= 0 && split_dim < x.dim(), "Input dim range is invalid");
+ const int32_t split_num = 2;
+ out_tensor_shape[split_dim] = x.size(split_dim) / split_num;
+}
+
+void CheckSwigluForward(const at::Tensor &x)
+{
+ TORCH_CHECK(x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::BFloat16 ||
+ x.scalar_type() == at::ScalarType::Float, "Input tensor dtype ", x.scalar_type(),
+ " invalid, should be float32, float16 or bfloat16");
+}
+
+void CheckSwigluBackward(const at::Tensor &y_grad, const at::Tensor &x)
+{
+ TORCH_CHECK(y_grad.scalar_type() == at::ScalarType::Half || y_grad.scalar_type() == at::ScalarType::BFloat16 ||
+ y_grad.scalar_type() == at::ScalarType::Float, "Input y_grad tensor dtype ", y_grad.scalar_type(),
+ " invalid, should be float32, float16 or bfloat16");
+ TORCH_CHECK(x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::BFloat16 ||
+ x.scalar_type() == at::ScalarType::Float, "Input x tensor dtype ", x.scalar_type(),
+ " invalid, should be float32, float16 or bfloat16");
+ TORCH_CHECK(x.scalar_type() == y_grad.scalar_type(), "Input x tensor dtype is not equal to y_grad");
+}
+
+class NPUSwigluFunction : public torch::autograd::Function {
+public:
+ static at::Tensor forward(AutogradContext *ctx, const at::Tensor &x, int32_t dim = -1)
+ {
+#ifndef ENABLE_ATB
+ TORCH_CHECK(false, "swiglu_forward not implemented");
+#else
+ at::AutoNonVariableTypeMode g;
+ CheckSwigluForward(x);
+ c10::SmallVector out_tensor_shape{x.sizes()};
+ InferSwigluForward(out_tensor_shape, x, dim);
+ // apply tensor
+ at::Tensor y = at::empty(out_tensor_shape, x.options());
+
+ atb::infer::ActivationParam param;
+ param.activationType = atb::infer::ActivationType::ACTIVATION_SWIGLU_FORWARD;
+ param.dim = dim;
+
+ // set input and output
+ ParamSetter paramsetter;
+ paramsetter.Input(x)
+ .Output(y);
+
+ atb::Operation* op = nullptr;
+ atb::CreateOperation(param, &op);
+ TORCH_CHECK(op != nullptr, "swiglu_forward get op failed!");
+ RunAtbCmd(op, paramsetter, "swiglu_forward");
+
+ ctx->save_for_backward({x});
+ ctx->saved_data["dim"] = dim;
+ return y;
+#endif
+ }
+
+ static std::vector backward(AutogradContext *ctx, std::vector grad_output)
+ {
+#ifndef ENABLE_ATB
+ TORCH_CHECK(false, "swiglu_backward not implemented");
+#else
+ auto dim = ctx->saved_data["dim"].toInt();
+ auto saved = ctx->get_saved_variables();
+ auto x = saved[0];
+ CheckSwigluBackward(grad_output[0], x);
+ at::Tensor x_grad = at::empty(x.sizes(), x.options());
+
+ atb::infer::ActivationParam param;
+ param.activationType = atb::infer::ActivationType::ACTIVATION_SWIGLU_BACKWARD;
+ param.dim = dim;
+
+ ParamSetter paramsetter;
+ paramsetter.Input(grad_output[0])
+ .Input(x)
+ .Output(x_grad);
+
+ atb::Operation* op = nullptr;
+ atb::CreateOperation(param, &op);
+ TORCH_CHECK(op != nullptr, "swiglu_backward get op failed!");
+ RunAtbCmd(op, paramsetter, "swiglu_backward");
+
+ return {x_grad, at::Tensor()};
+#endif
+ }
+};
+} // namespace
+
+at::Tensor npu_swiglu(const at::Tensor &x, int32_t dim)
+{
+ return NPUSwigluFunction::apply(x, dim);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("swiglu", &npu_swiglu, "swiglu realization", pybind11::arg("x"), pybind11::arg("dim")=-1);
+}
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/ops/csrc/atb/utils/atb_adapter.cpp b/model/train/yoco_moe/mindspeed/ops/csrc/atb/utils/atb_adapter.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..835758e14b2bcfe0425ab9515075ccf2569bbaa5
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/ops/csrc/atb/utils/atb_adapter.cpp
@@ -0,0 +1,140 @@
+
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "../inc/atb_adapter.h"
+#include
+#include
+#include
+
+using namespace std;
+
+static atb::Context* msContext = nullptr;
+
+at::Tensor FormatTrans(const at::Tensor &at_tensor)
+{
+ if (at_tensor.defined()) {
+ TORCH_CHECK(torch_npu::utils::is_npu(at_tensor), "only npu tensor is supported");
+ return at_npu::native::npu_format_cast(at_tensor, ACL_FORMAT_ND);
+ }
+ return at_tensor;
+}
+
+atb::Tensor AtTensor2Tensor(const at::Tensor atTensor)
+{
+ static std::map dtypeMap = {
+ {at::ScalarType::Bool, ACL_BOOL}, {at::ScalarType::Byte, ACL_UINT8},
+ {at::ScalarType::Char, ACL_INT8}, {at::ScalarType::Half, ACL_FLOAT16},
+ {at::ScalarType::Float, ACL_FLOAT}, {at::ScalarType::Int, ACL_INT32},
+ {at::ScalarType::Long, ACL_INT64}, {at::ScalarType::BFloat16, ACL_BF16},
+ };
+
+ TORCH_CHECK(atTensor.is_contiguous(), "atTensor is not contiguous");
+ atb::Tensor tensor;
+ tensor.desc.format = ACL_FORMAT_ND;
+ tensor.deviceData = atTensor.data_ptr();
+
+ tensor.desc.shape.dimNum = atTensor.sizes().size();
+ for (uint64_t i = 0; i < atTensor.sizes().size(); i++) {
+ tensor.desc.shape.dims[i] = atTensor.sizes()[i];
+ }
+
+ auto it = dtypeMap.find(atTensor.scalar_type());
+ TORCH_CHECK(it != dtypeMap.end(), "not support dtype:");
+ tensor.desc.dtype = it->second;
+
+ tensor.dataSize = atb::Utils::GetTensorSize(tensor);
+
+ return tensor;
+}
+
+void RunAtbCmd(atb::Operation *op, const ParamSetter ¶msetter, const std::string &name)
+{
+ auto contextPtr = GetContext();
+ uint64_t workspaceSize = OperationSetup(paramsetter.variantPack, op, contextPtr);
+ auto workspaceTensor = GetWorkspaceTensor(workspaceSize, op);
+ const void *workspacePtr = nullptr;
+ workspacePtr = workspaceTensor.storage().data();
+ auto acl_call = [op, contextPtr, paramsetter, workspacePtr, workspaceSize]() -> int {
+ auto st = op->Execute(paramsetter.variantPack, (uint8_t *)workspacePtr, workspaceSize, contextPtr);
+ DestroyOperation(op);
+ return 0;
+ };
+ at_npu::native::OpCommand cmd;
+ cmd.Name(name);
+ cmd.SetCustomHandler(acl_call);
+ cmd.Run();
+}
+
+ParamSetter& ParamSetter::Input(const at::Tensor &tensor)
+{
+ if (!tensor.defined()) {
+ variantPack.inTensors.push_back(atb::Tensor());
+ return *this;
+ }
+ at::Tensor newTensor = FormatTrans(tensor);
+ if(!newTensor.is_contiguous()) {
+ newTensor = newTensor.contiguous();
+ }
+ auto AtTensor = AtTensor2Tensor(newTensor);
+
+ variantPack.inTensors.push_back(AtTensor);
+ return *this;
+}
+
+ParamSetter& ParamSetter::Input(const c10::optional &tensor)
+{
+ if (!tensor.has_value()) {
+ variantPack.inTensors.push_back(atb::Tensor());
+ return *this;
+ }
+ return Input(tensor.value());
+}
+
+ParamSetter& ParamSetter::Output(at::Tensor &output)
+{
+ auto AtTensor = AtTensor2Tensor(output);
+ variantPack.outTensors.push_back(AtTensor);
+ return *this;
+}
+
+uint64_t OperationSetup(atb::VariantPack variantPack, atb::Operation *operation, atb::Context* contextPtr)
+{
+ uint64_t workspaceSize = 0;
+ atb::Status status = operation->Setup(variantPack, workspaceSize, contextPtr);
+ TORCH_CHECK(status == 0, "setup failed!");
+ return workspaceSize;
+}
+
+at::Tensor GetWorkspaceTensor(uint64_t workspaceSize, atb::Operation *operation)
+{
+ at::TensorOptions options = at::TensorOptions(torch_npu::utils::get_npu_device_type());
+ at::Tensor workspaceTensor = at::empty(at::IntArrayRef(workspaceSize), options.dtype(at::kByte));
+ return workspaceTensor;
+}
+
+atb::Context* GetContext()
+{
+ if (msContext == nullptr) {
+ auto status = atb::CreateContext(&msContext);
+ TORCH_CHECK(status == 0, "create context failed!");
+ int32_t devId = 0;
+ aclrtGetDevice(&devId);
+ aclrtStream stream = c10_npu::getCurrentNPUStream(devId).stream(false);
+ TORCH_CHECK(stream != nullptr, "get current stream failed");
+ msContext->SetExecuteStream(stream);
+ }
+ return msContext;
+}
diff --git a/model/train/yoco_moe/mindspeed/ops/csrc/cann/ffn.cpp b/model/train/yoco_moe/mindspeed/ops/csrc/cann/ffn.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1886e6da6dcce34c1f41907b9d0e488a1f939a69
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/ops/csrc/cann/ffn.cpp
@@ -0,0 +1,96 @@
+// Copyright (c) 2024 Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include
+#include
+#include
+#include
+#include
+
+#include "inc/aclnn_common.h"
+
+const static int MIN_DIM = 2;
+const static int X_MAX_DIM = 8;
+
+using npu_preparation = at_npu::native::OpPreparation;
+
+namespace op_infer {
+constexpr int SIZE = 8;
+
+c10::SmallVector array_to_small_vector(c10::IntArrayRef shape)
+{
+ c10::SmallVector small_shape;
+ for (size_t i = 0; i < shape.size(); ++i) {
+ small_shape.emplace_back(shape[i]);
+ }
+ return small_shape;
+}
+}
+
+at::Tensor npu_ffn(const at::Tensor &x, const at::Tensor &weight1, const at::Tensor &weight2,
+ std::string activation, c10::optional expert_tokens, c10::optional expert_tokens_index,
+ const c10::optional &bias1, const c10::optional &bias2,
+ const c10::optional &scale, const c10::optional &offset,
+ const c10::optional &deq_scale1, const c10::optional &deq_scale2,
+ const c10::optional &antiquant_scale1, const c10::optional &antiquant_scale2,
+ const c10::optional &antiquant_offset1, const c10::optional &antiquant_offset2,
+ c10::optional inner_precise, c10::optional output_dtype)
+{
+ auto weight1_dim_num = weight1.dim();
+ auto weight2_dim_num = weight2.dim();
+ auto x_dim_num = x.dim();
+ TORCH_CHECK(x_dim_num >= MIN_DIM && x_dim_num <= X_MAX_DIM, "x shape dims should be 2~8, but it is ", x_dim_num);
+ auto x_k_dim = x.size(x.dim() - 1);
+ auto wight1_k_dim = weight1.size(weight1.dim() - 2);
+ TORCH_CHECK(x_k_dim == wight1_k_dim, "The k of x and weight should be equal. but x_k_dim is ",
+ x_k_dim, ", wight1_k_dim is ", wight1_k_dim);
+
+ TORCH_CHECK(!(expert_tokens.has_value() && expert_tokens_index.has_value()),
+ "expert_tokens and expert_tokens_index should not have the value simultaneously.");
+
+ char *activation_ptr = const_cast(activation.data());
+ const at::Tensor &bias1_real = bias1.value_or(at::Tensor());
+ const at::Tensor &bias2_real = bias2.value_or(at::Tensor());
+ const at::Tensor &scale_real = scale.value_or(at::Tensor());
+ const at::Tensor &offset_real = offset.value_or(at::Tensor());
+ const at::Tensor &deq_scale1_real = deq_scale1.value_or(at::Tensor());
+ const at::Tensor &deq_scale2_real = deq_scale2.value_or(at::Tensor());
+ const at::Tensor &antiquant_scale1_real = antiquant_scale1.value_or(at::Tensor());
+ const at::Tensor &antiquant_scale2_real = antiquant_scale2.value_or(at::Tensor());
+ const at::Tensor &antiquant_offset1_real = antiquant_offset1.value_or(at::Tensor());
+ const at::Tensor &antiquant_offset2_real = antiquant_offset2.value_or(at::Tensor());
+ auto output_size = op_infer::array_to_small_vector(x.sizes());
+ output_size[x.dim() - 1] = weight2.size(weight2.dim() - 1);
+ c10::TensorOptions options = x.options().dtype(x.scalar_type());
+ if (x.scalar_type() == at::kChar && weight1.scalar_type() == at::kChar && weight2.scalar_type() == at::kChar) {
+ options = x.options().dtype(output_dtype.value_or(at::kHalf));
+ }
+ at::Tensor result = at::empty(output_size, options);
+ int64_t inner_precise_val = inner_precise.has_value() ? inner_precise.value() : 0;
+
+ bool tokens_index_flag = expert_tokens_index.has_value();
+
+ const at::Tensor &expert_tokens_real = expert_tokens.has_value() ? expert_tokens.value() :
+ expert_tokens_index.value_or(at::Tensor());
+
+ ACLNN_CMD(aclnnFFNV3, x, weight1, weight2, expert_tokens_real, bias1_real, bias2_real,
+ scale_real, offset_real, deq_scale1_real, deq_scale2_real, antiquant_scale1_real, antiquant_scale2_real,
+ antiquant_offset1_real, antiquant_offset2_real, activation_ptr, inner_precise_val, tokens_index_flag, result);
+
+ return result;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("npu_ffn", &npu_ffn, "npu_ffn");
+}
\ No newline at end of file
diff --git a/model/train/yoco_moe/mindspeed/ops/csrc/cann/fusion_attention_v2.cpp b/model/train/yoco_moe/mindspeed/ops/csrc/cann/fusion_attention_v2.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3013b2fb5e32874b60117e89106d19a3d8d308ae
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/ops/csrc/cann/fusion_attention_v2.cpp
@@ -0,0 +1,458 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include
+#include "torch_npu/csrc/framework/utils/RandomOpAdapter.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/core/npu/NPUFormat.h"
+#include
+
+#include "inc/aclnn_common.h"
+#include "../flop_counter/flop_counter.h"
+
+const static int FLASH_THRESHOLD = 512;
+const static int N = 32;
+const static int64_t SOFTMAXMAX_LAST_DIMSHAPE = 8;
+using namespace at_npu::native;
+
+constexpr static int SIZE_8 = 8;
+
+enum class DropOutStatus {
+ DROPOUT_NORMAL = 0,
+ DROPOUT_NONE,
+ DROPOUT_ALL
+};
+
+enum class SparseMode {
+ NO_MASK = 0,
+ ALL_MASK,
+ LEFT_UP_CAUSAL,
+ RIGHT_DOWN_CAUSAL,
+ BAND,
+ PREFIX,
+ PREFIX_COMPRESS,
+ RIGHT_DOWN_CAUSAL_BAND,
+ BAND_LEFT_UP_CAUSAL
+};
+DropOutStatus get_dropout_status(double keep_prob)
+{
+ if (keep_prob == 0) {
+ return DropOutStatus::DROPOUT_ALL;
+ }
+ if (keep_prob == 1.) {
+ return DropOutStatus::DROPOUT_NONE;
+ }
+ return DropOutStatus::DROPOUT_NORMAL;
+}
+
+at::Tensor format_trans(const at::Tensor &at_tensor)
+{
+ if (at_tensor.defined()) {
+ TORCH_CHECK(torch_npu::utils::is_npu(at_tensor), "only npu tensor is supported");
+ return at_npu::native::npu_format_cast(at_tensor, ACL_FORMAT_ND);
+ }
+ return at_tensor;
+}
+
+at::Tensor dropout_gen_mask(const at::Tensor &query, const at::Tensor &key, double keep_prob, int64_t head_num, const std::string &input_layout,
+ bool gen_mask_parallel, bool sync, int64_t &seed, int64_t &offset, int64_t &numels)
+{
+ at::Tensor drop_mask;
+ if (input_layout == "BSH") {
+ numels = query.size(0) * head_num * query.size(1) * key.size(1); // [B,N,S,S]
+ } else if (input_layout == "SBH") {
+ numels = query.size(1) * head_num * query.size(0) * key.size(0); // [B,N,S,S]
+ } else if (input_layout == "BNSD") {
+ numels = query.size(0) * query.size(1) * query.size(2) * key.size(2); // [B,N,S,S]
+ } else if (input_layout == "BSND") {
+ numels = query.size(0) * query.size(2) * query.size(1) * key.size(1); // [B,N,S,S]
+ }
+ int64_t length = (numels + 128 - 1) / 128 * 128 / 8;
+ length += 32;
+ if (get_dropout_status(keep_prob) == DropOutStatus::DROPOUT_NORMAL) {
+ const auto gen = at_npu::detail::getDefaultNPUGenerator();
+ auto pair = at::check_generator(gen)->philox_engine_inputs(10);
+ seed = pair.first;
+ offset = pair.second;
+ drop_mask = at_npu::native::npu_dropout_gen_mask(query, at::IntArrayRef{ numels }, 1 - keep_prob,
+ seed, offset, gen_mask_parallel, sync);
+ } else if (get_dropout_status(keep_prob) == DropOutStatus::DROPOUT_ALL) {
+ drop_mask = at::zeros(at::IntArrayRef{length}, query.options().dtype(at::kByte));
+ }
+ return drop_mask;
+}
+
+std::tuple npu_fusion_attention_backward_v2(
+ const at::Tensor &query,
+ const at::Tensor &key,
+ const at::Tensor &value,
+ const at::Tensor &dy,
+ int64_t head_num,
+ const std::string &input_layout,
+ const c10::optional &pse,
+ const c10::optional &drop_mask,
+ const c10::optional &padding_mask,
+ const c10::optional &atten_mask,
+ const c10::optional &softmax_max,
+ const c10::optional &softmax_sum,
+ const c10::optional &softmax_in,
+ const c10::optional &attention_in,
+ double scale_value,
+ double keep_prob,
+ int64_t pre_tokens,
+ int64_t next_tokens,
+ int64_t inner_precise,
+ const c10::optional> &prefix,
+ const c10::optional> &actual_seq_qlen,
+ const c10::optional> &actual_seq_kvlen,
+ const c10::optional> &q_start_idx,
+ const c10::optional> &kv_start_idx,
+ int64_t sparse_mode,
+ int64_t pse_type)
+{
+ double scale = scale_value;
+
+ const at::Tensor &pse_const = pse.value_or(at::Tensor());
+ const at::Tensor &drop_mask_const = drop_mask.value_or(at::Tensor());
+ const at::Tensor &padding_mask_const = padding_mask.value_or(at::Tensor());
+ const at::Tensor &atten_mask_const = atten_mask.value_or(at::Tensor());
+ const at::Tensor &softmax_max_const = softmax_max.value_or(at::Tensor());
+ const at::Tensor &softmax_sum_const = softmax_sum.value_or(at::Tensor());
+ const at::Tensor &softmax_const = softmax_in.value_or(at::Tensor());
+ const at::Tensor &attention_const = attention_in.value_or(at::Tensor());
+ auto prefixN_tmp = prefix.value_or(std::vector{});
+ auto ac_seq_qlen_tmp = actual_seq_qlen.value_or(std::vector{});
+ auto ac_seq_kvlen_tmp = actual_seq_kvlen.value_or(std::vector{});
+ auto q_start_idx_val_tmp = q_start_idx.value_or(std::vector{});
+ auto kv_start_idx_val_tmp = kv_start_idx.value_or(std::vector{});
+
+ c10::optional prefixN(prefixN_tmp);
+ c10::optional ac_seq_qlen(ac_seq_qlen_tmp);
+ c10::optional ac_seq_kvlen(ac_seq_kvlen_tmp);
+ c10::optional q_start_idx_val(q_start_idx_val_tmp);
+ c10::optional kv_start_idx_val(kv_start_idx_val_tmp);
+
+ at::Tensor format_query = format_trans(query);
+ at::Tensor format_key = format_trans(key);
+ at::Tensor format_value = format_trans(value);
+ at::Tensor format_dy = format_trans(dy);
+
+ at::Tensor format_pse = format_trans(pse_const);
+ at::Tensor format_drop_mask = format_trans(drop_mask_const);
+ at::Tensor format_padding_mask = format_trans(padding_mask_const);
+ at::Tensor format_atten_mask = format_trans(atten_mask_const);
+ at::Tensor format_softmax_max = format_trans(softmax_max_const);
+ at::Tensor format_softmax_sum = format_trans(softmax_sum_const);
+ at::Tensor format_softmax = format_trans(softmax_const);
+ at::Tensor format_attention = format_trans(attention_const);
+ at::Tensor dq = at::empty(format_query.sizes(), format_query.options());
+ at::Tensor dk = at::empty(format_key.sizes(), format_key.options());
+ at::Tensor dv = at::empty(format_value.sizes(), format_value.options());
+ char* input_layout_ptr = const_cast(input_layout.c_str());
+ at::Tensor dpse;
+ if (format_pse.defined()) {
+ dpse = at::empty(format_pse.sizes(), format_pse.options());
+ } else {
+ dpse = at::empty({0}, query.options());
+ }
+
+ if (!ac_seq_qlen_tmp.empty() && !ac_seq_kvlen_tmp.empty()) {
+ ACLNN_CMD(
+ aclnnFlashAttentionUnpaddingScoreGradV2, format_query, format_key, format_value, format_dy,
+ format_pse, format_drop_mask, format_padding_mask, format_atten_mask, format_softmax_max,
+ format_softmax_sum, format_softmax, format_attention, prefixN, ac_seq_qlen, ac_seq_kvlen, q_start_idx_val, kv_start_idx_val,
+ scale_value, keep_prob, pre_tokens, next_tokens, head_num, input_layout_ptr, inner_precise, sparse_mode, pse_type,
+ dq, dk, dv, dpse);
+ } else {
+ ACLNN_CMD(
+ aclnnFlashAttentionScoreGradV2, format_query, format_key, format_value, format_dy,
+ format_pse, format_drop_mask, format_padding_mask, format_atten_mask, format_softmax_max,
+ format_softmax_sum, format_softmax, format_attention, prefixN, q_start_idx_val, kv_start_idx_val, scale_value, keep_prob,
+ pre_tokens, next_tokens, head_num, input_layout_ptr, inner_precise, sparse_mode, pse_type, dq, dk, dv, dpse);
+ }
+
+ if (!format_pse.defined()) {
+ at::Tensor dpse_required;
+ dpse = dpse_required;
+ }
+ #ifdef FLOP_COUNT
+ FLOP_COUNT(FlopCounter::flash_attention_backward_flop, query, key, value, dy, head_num, input_layout, actual_seq_qlen, actual_seq_kvlen);
+ #endif
+ return std::make_tuple(dq, dk, dv, dpse);
+}
+
+std::tuple npu_fusion_attention_grad_v2(
+ const at::Tensor &query,
+ const at::Tensor &key,
+ const at::Tensor &value,
+ const at::Tensor &dy,
+ int64_t head_num,
+ const std::string &input_layout,
+ const c10::optional &pse,
+ const c10::optional &padding_mask,
+ const c10::optional &atten_mask,
+ const c10::optional &softmax_max,
+ const c10::optional &softmax_sum,
+ const c10::optional &softmax_in,
+ const c10::optional &attention_in,
+ double scale_value,
+ double keep_prob,
+ int64_t pre_tokens,
+ int64_t next_tokens,
+ int64_t inner_precise,
+ int64_t seed,
+ int64_t offset,
+ int64_t numels,
+ const c10::optional> &prefix,
+ const c10::optional> &actual_seq_qlen,
+ const c10::optional> &actual_seq_kvlen,
+ int64_t sparse_mode,
+ bool gen_mask_parallel,
+ bool sync,
+ int64_t pse_type,
+ const c10::optional> &q_start_idx,
+ const c10::optional> &kv_start_idx)
+{
+ TORCH_CHECK(query.dim() == 3 || query.dim() == 4, "The shapes of the input query should be 3 or 4 dimensional, but got ",
+ query.dim(), "-dimensional");
+ TORCH_CHECK(key.dim() == 3 || key.dim() == 4, "The shapes of the input key should be 3 or 4 dimensional, but got ",
+ key.dim(), "-dimensional");
+ TORCH_CHECK(value.dim() == 3 || value.dim() == 4, "The shapes of the input value should be 3 or 4 dimensional, but got ",
+ value.dim(), "-dimensional");
+ TORCH_CHECK(dy.dim() == 3 || dy.dim() == 4, "The shapes of the input dy should be 3 or 4 dimensional, but got ", dy.dim(), "-dimensional");
+ TORCH_CHECK(keep_prob >= 0 && keep_prob <= 1, "The keep_prob value must be in range of [0, 1], but got ", keep_prob);
+ TORCH_CHECK(pse_type >= 0 && pse_type <= 3, "The pse_type value must be in range of [0, 3], but got ", pse_type);
+ std::string input_layout_str = std::string(input_layout);
+ if (input_layout_str == "TND") {
+ TORCH_CHECK((sparse_mode >= static_cast(SparseMode::NO_MASK) &&
+ sparse_mode < static_cast(SparseMode::PREFIX)) ||
+ (sparse_mode > static_cast(SparseMode::PREFIX) &&
+ sparse_mode <= static_cast(SparseMode::BAND_LEFT_UP_CAUSAL)),
+ "The sparse_mode value must be in range of [0,5) or (5,8], but got ",
+ sparse_mode);
+ } else {
+ TORCH_CHECK(sparse_mode >= static_cast(SparseMode::NO_MASK) &&
+ sparse_mode <= static_cast(SparseMode::PREFIX_COMPRESS),
+ "The sparse_mode value must be in range of [0,6], but got ",
+ sparse_mode);
+ }
+ for (auto &c : input_layout_str) {
+ c = toupper(c);
+ }
+ TORCH_CHECK(input_layout_str == "BSH" || input_layout_str == "SBH" || input_layout_str == "BNSD" ||
+ input_layout_str == "BSND" || input_layout_str == "TND",
+ "The input_layout should be BSH/SBH/BNSD/BSND/TND(case-insensitive), but got ", input_layout);
+
+ int64_t length = (numels + 128 - 1) / 128 * 128 / 8;
+ length += 32;
+ at::Tensor drop_mask;
+ if (get_dropout_status(keep_prob) == DropOutStatus::DROPOUT_NORMAL) {
+ drop_mask = at_npu::native::npu_dropout_gen_mask(query, at::IntArrayRef{ numels }, 1 - keep_prob,
+ seed, offset, gen_mask_parallel, sync);
+ } else if (get_dropout_status(keep_prob) == DropOutStatus::DROPOUT_ALL) {
+ drop_mask = at::zeros(at::IntArrayRef{length}, query.options().dtype(at::kByte));
+ }
+ auto result = npu_fusion_attention_backward_v2(query,
+ key, value, dy, head_num, input_layout_str, pse, drop_mask, padding_mask, atten_mask,
+ softmax_max, softmax_sum, softmax_in, attention_in, scale_value, keep_prob, pre_tokens,
+ next_tokens, inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen, q_start_idx, kv_start_idx, sparse_mode, pse_type);
+ if (!sync && get_dropout_status(keep_prob) != DropOutStatus::DROPOUT_NONE) {
+ c10::Device device = drop_mask.device();
+ c10::impl::VirtualGuardImpl impl(device.type());
+ impl.recordDataPtrOnStream(drop_mask.storage().data_ptr(), c10_npu::getCurrentNPUStream());
+ }
+ return result;
+}
+
+std::tuple npu_fusion_attention_v2(
+ const at::Tensor &query, const at::Tensor &key,
+ const at::Tensor &value, int64_t head_num, const std::string &input_layout,
+ const c10::optional &pse_opt, const c10::optional &padding_mask_opt,
+ const c10::optional &atten_mask_opt,
+ double scale, double keep_prob, int64_t pre_tokens, int64_t next_tokens, int64_t inner_precise,
+ const c10::optional> &prefix_opt, const c10::optional> &actual_seq_qlen,
+ const c10::optional> &actual_seq_kvlen, int64_t sparse_mode, bool gen_mask_parallel, bool sync,
+ int64_t pse_type, const c10::optional> &q_start_idx, const c10::optional> &kv_start_idx)
+{
+ const at::Tensor &pse = pse_opt.value_or(at::Tensor());
+ const at::Tensor &padding_mask = padding_mask_opt.value_or(at::Tensor());
+ const at::Tensor &atten_mask = atten_mask_opt.value_or(at::Tensor());
+ auto prefixN_tmp = prefix_opt.value_or(std::vector{});
+ auto ac_seq_qlen_tmp = actual_seq_qlen.value_or(std::vector{});
+ auto ac_seq_kvlen_tmp = actual_seq_kvlen.value_or(std::vector{});
+ auto q_start_idx_val_tmp = q_start_idx.value_or(std::vector{});
+ auto kv_start_idx_val_tmp = kv_start_idx.value_or(std::vector{});
+
+ c10::optional prefixN(prefixN_tmp);
+ c10::optional ac_seq_qlen(ac_seq_qlen_tmp);
+ c10::optional ac_seq_kvlen(ac_seq_kvlen_tmp);
+ c10::optional q_start_idx_val(q_start_idx_val_tmp);
+ c10::optional kv_start_idx_val(kv_start_idx_val_tmp);
+
+ TORCH_CHECK(head_num > 0, "head_num must > 0, but got ", head_num);
+ TORCH_CHECK(query.dim() == 3 || query.dim() == 4, "The shapes of the input query should be 3 or 4 dimensional, but got ",
+ query.dim(), "-dimensional");
+ TORCH_CHECK(key.dim() == 3 || key.dim() == 4, "The shapes of the input key should be 3 or 4 dimensional, but got ", key.dim(),
+ "-dimensional");
+ TORCH_CHECK(value.dim() == 3 || value.dim() == 4, "The shapes of the input value should be 3 or 4 dimensional, but got ",
+ value.dim(), "-dimensional");
+ TORCH_CHECK(keep_prob >= 0 && keep_prob <= 1, "The keep_prob value must be in range of [0, 1], but got ", keep_prob);
+ TORCH_CHECK(pse_type >= 0 && pse_type <= 3, "The pse_type value must be in range of [0, 3], but got ", pse_type);
+ std::string input_layout_str = std::string(input_layout);
+ if (input_layout_str == "TND") {
+ TORCH_CHECK((sparse_mode >= static_cast(SparseMode::NO_MASK) &&
+ sparse_mode < static_cast(SparseMode::PREFIX)) ||
+ (sparse_mode > static_cast(SparseMode::PREFIX) &&
+ sparse_mode <= static_cast(SparseMode::BAND_LEFT_UP_CAUSAL)),
+ "The sparse_mode value must be in range of [0,5) or (5,8], but got ",
+ sparse_mode);
+ } else {
+ TORCH_CHECK(sparse_mode >= static_cast(SparseMode::NO_MASK) &&
+ sparse_mode <= static_cast(SparseMode::PREFIX_COMPRESS),
+ "The sparse_mode value must be in range of [0,6], but got ",
+ sparse_mode);
+ }
+ for (auto &c : input_layout_str) {
+ c = toupper(c);
+ }
+ TORCH_CHECK(input_layout_str == "BSH" || input_layout_str == "SBH" || input_layout_str == "BNSD" ||
+ input_layout_str == "BSND" || input_layout_str == "TND",
+ "The input_layout should be BSH/SBH/BNSD/BSND/TND(case-insensitive), but got ", input_layout);
+
+ int64_t B = 0;
+ int64_t S0 = 0; // S for query
+ int64_t S1 = 0; // S for key & value
+ int64_t N = 0;
+ int64_t D = 0;
+ int64_t H = 0;
+ int64_t T = 0;
+ int64_t D2 = 0; // D2 for value head-dim
+ c10::SmallVector atten_score_shape;
+
+ if (input_layout_str == "BSH") {
+ B = query.size(0);
+ S0 = query.size(1);
+ S1 = key.size(1);
+ H = query.size(2);
+ D = H / head_num;
+ D2 = (!D || !key.size(2)) ? 0 : value.size(2) / (key.size(2) / D);
+ atten_score_shape = {B, S0, head_num * D2};
+ } else if (input_layout_str == "SBH") {
+ B = query.size(1);
+ S0 = query.size(0);
+ S1 = key.size(0);
+ H = query.size(2);
+ D = H / head_num;
+ D2 = (!D || !key.size(2)) ? 0 : value.size(2) / (key.size(2) / D);
+ atten_score_shape = {S0, B, head_num * D2};
+ } else if (input_layout_str == "BNSD") {
+ B = query.size(0);
+ N = query.size(1);
+ S0 = query.size(2);
+ S1 = key.size(2);
+ D = query.size(3);
+ D2 = value.size(3);
+ atten_score_shape = {B, N, S0, D2};
+ } else if (input_layout_str == "BSND") {
+ B = query.size(0);
+ N = query.size(2);
+ S0 = query.size(1);
+ S1 = key.size(1);
+ D = query.size(3);
+ D2 = value.size(3);
+ atten_score_shape = {B, S0, N, D2};
+ } else if (input_layout_str == "TND") {
+ T = query.size(0);
+ N = query.size(1);
+ D = query.size(2);
+ D2 = value.size(2);
+ atten_score_shape = {T, N, D2};
+ }
+
+ double scale_value = scale;
+
+ at::Tensor format_query = format_trans(query);
+ at::Tensor attention_score = at::empty(atten_score_shape, query.options());
+ at::Tensor format_key = format_trans(key);
+ at::Tensor format_value = format_trans(value);
+
+ at::Tensor format_pse = format_trans(pse);
+ at::Tensor format_padding_mask = format_trans(padding_mask);
+ at::Tensor format_atten_mask = format_trans(atten_mask);
+
+ int64_t seed;
+ int64_t offset;
+ int64_t numels;
+ //check
+ for(size_t i = 0; i < ac_seq_qlen_tmp.size(); i++){
+ TORCH_CHECK(ac_seq_qlen_tmp[i] <= 1000000 && ac_seq_kvlen_tmp[i] <= 1000000, "The sequence length should not greater than 1M, but got q", ac_seq_qlen_tmp[i],"kv", ac_seq_kvlen_tmp[i]);
+ }
+
+ if (input_layout_str == "TND" && ac_seq_qlen_tmp.size() == ac_seq_kvlen_tmp.size()) {
+ numels = N;
+ int64_t accum = ac_seq_qlen_tmp[0] * ac_seq_kvlen_tmp[0];
+ for (size_t i = 1; i < ac_seq_qlen_tmp.size(); i++) {
+ accum += ((ac_seq_qlen_tmp[i] - ac_seq_qlen_tmp[i - 1]) * (ac_seq_kvlen_tmp[i] - ac_seq_kvlen_tmp[i - 1]));
+ }
+ numels *= accum;
+ }
+
+ at::Tensor format_drop_mask = dropout_gen_mask(format_query, format_key, keep_prob, head_num, input_layout_str,
+ gen_mask_parallel, sync, seed, offset, numels);
+
+ at::Tensor softmax_max;
+ at::Tensor softmax_sum;
+ at::Tensor softmax_out;
+
+ if (input_layout_str != "TND") {
+ softmax_max = at::empty({B, head_num, S0, SOFTMAXMAX_LAST_DIMSHAPE}, query.options().dtype(at::kFloat)); // [B, N, S0, 8]
+ softmax_sum = at::empty({B, head_num, S0, SOFTMAXMAX_LAST_DIMSHAPE}, query.options().dtype(at::kFloat)); // [B, N, S0, 8]
+ } else {
+ softmax_max = at::empty({T, N, SOFTMAXMAX_LAST_DIMSHAPE}, query.options().dtype(at::kFloat)); // [T, N, 8]
+ softmax_sum = at::empty({T, N, SOFTMAXMAX_LAST_DIMSHAPE}, query.options().dtype(at::kFloat)); // [T, N, 8]
+ }
+ softmax_out = at::empty({0}, query.options());
+
+ char* input_layout_ptr = const_cast(input_layout_str.c_str());
+ if (!ac_seq_qlen_tmp.empty() && !ac_seq_kvlen_tmp.empty()) {
+ ACLNN_CMD(
+ aclnnFlashAttentionVarLenScoreV2, format_query, format_key, format_value,
+ format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN,
+ ac_seq_qlen, ac_seq_kvlen, q_start_idx_val, kv_start_idx_val, scale, keep_prob, pre_tokens, next_tokens, head_num,
+ input_layout_ptr, inner_precise, sparse_mode, pse_type, softmax_max, softmax_sum,
+ softmax_out, attention_score);
+ } else {
+ ACLNN_CMD(
+ aclnnFlashAttentionScoreV2, format_query, format_key, format_value,
+ format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN, q_start_idx_val, kv_start_idx_val,
+ scale, keep_prob, pre_tokens, next_tokens, head_num, input_layout_ptr, inner_precise,
+ sparse_mode, pse_type, softmax_max, softmax_sum, softmax_out, attention_score);
+ }
+ if (!sync && get_dropout_status(keep_prob) != DropOutStatus::DROPOUT_NONE) {
+ c10::Device device = format_drop_mask.device();
+ c10::impl::VirtualGuardImpl impl(device.type());
+ impl.recordDataPtrOnStream(format_drop_mask.storage().data_ptr(), c10_npu::getCurrentNPUStream());
+ }
+ #ifdef FLOP_COUNT
+ FLOP_COUNT(FlopCounter::flash_attention_forward_flop, query, key, value, head_num, input_layout, actual_seq_qlen, actual_seq_kvlen);
+ #endif
+ return std::make_tuple(attention_score, softmax_max, softmax_sum, softmax_out,
+ seed, offset, numels);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("npu_fusion_attention_v2", &npu_fusion_attention_v2, "fusion attention forward v2");
+ m.def("npu_fusion_attention_grad_v2", &npu_fusion_attention_grad_v2, "fusion attention backward v2");
+}
diff --git a/model/train/yoco_moe/mindspeed/ops/csrc/cann/gmm.cpp b/model/train/yoco_moe/mindspeed/ops/csrc/cann/gmm.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fc3a63f3fb29e21189a22213599c1cf2bd3819ed
--- /dev/null
+++ b/model/train/yoco_moe/mindspeed/ops/csrc/cann/gmm.cpp
@@ -0,0 +1,309 @@
+// Copyright (c) 2024 Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include
+#include
+#include
+#include
+#include
+
+#include "../flop_counter/flop_counter.h"
+#include "inc/aclnn_common.h"
+
+using npu_preparation = at_npu::native::OpPreparation;
+
+namespace op_infer {
+constexpr int SIZE = 8;
+
+c10::SmallVector array_to_small_vector(c10::IntArrayRef shape)
+{
+ c10::SmallVector small_shape;
+ for (size_t i = 0; i < shape.size(); ++i) {
+ small_shape.emplace_back(shape[i]);
+ }
+ return small_shape;
+}
+}
+
+void _check_dims(size_t num_x, const at::TensorList &weight, size_t num_group_list)
+{
+ size_t num_w = weight.size();
+ TORCH_CHECK(num_x > 0 && num_w > 0,
+ "Neither x nor weight could be empty.");
+ size_t dim_num_w = weight[0].sizes().size();
+ size_t dim_0_w = weight[0].sizes()[0];
+}
+
+void _create_new_tensor(std::vector