From 0ca38cb9b08f76db3e8ae25fe8f7d100fae470b3 Mon Sep 17 00:00:00 2001 From: zhangyihuiben Date: Tue, 17 Jun 2025 10:26:54 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E6=8E=A5=E5=8F=A3=E7=B2=BE=E5=BA=A6?= =?UTF-8?q?=E5=AF=B9=E6=AF=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../accuracy_comparison.md | 295 ++++++++++++++++++ .../example/accuracy_comparison/example.sh | 130 ++++++++ .../example/accuracy_comparison/example.yaml | 222 +++++++++++++ docs/mindformers/docs/source_zh_cn/index.rst | 5 + 4 files changed, 652 insertions(+) create mode 100644 docs/mindformers/docs/source_zh_cn/advanced_development/accuracy_comparison.md create mode 100644 docs/mindformers/docs/source_zh_cn/example/accuracy_comparison/example.sh create mode 100644 docs/mindformers/docs/source_zh_cn/example/accuracy_comparison/example.yaml diff --git a/docs/mindformers/docs/source_zh_cn/advanced_development/accuracy_comparison.md b/docs/mindformers/docs/source_zh_cn/advanced_development/accuracy_comparison.md new file mode 100644 index 0000000000..d8e9f72c23 --- /dev/null +++ b/docs/mindformers/docs/source_zh_cn/advanced_development/accuracy_comparison.md @@ -0,0 +1,295 @@ +# Parallel Core 训练模型精度对比 + +## 1. 概述 + +在大模型训练系统中,模型层级的数值精度验证是保障训练稳定性和结果可信度的关键环节。随着训练任务日益复杂,模型结构日趋庞大,确保不同实现之间在模型整体行为上的对齐,显得尤为重要。 + +Megatron-LM 是一个面向大规模训练任务的成熟框架,具备高度模块化与良好的可扩展性,广泛应用于高性能训练场景。MindSpore Transformers r1.6.0 版本在模型构建方面引入了名为 Parallel Core 的**全新架构**,以**ModuleSpec** 配置方式搭建模型,使得模型结构定义更加**灵活**且**易于复用**,极大提升了开发效率。同时在 NPU 环境下提供了全面优化的训练支持,能够充分发挥 NPU 架构优势。 + +本文档聚焦于两者在模型层面的精度一致性验证。通过构建等价的模型结构与配置,使用统一的输入,比较其前向输出、损失值、梯度行为等关键训练过程中的表现差异,以此验证 MindSpore Transformers 在 NPU 环境下实现的可靠性与精度可控性。 + +## 2. 环境说明 + +本节说明精度对比实验的推荐基础运行环境,包括: + +### 驱动版本 + +| GPU | 版本 | NPU | 版本 | +|------|------|------|---------| +| CUDA | 12.1 | CANN | 8.1.RC1 | + +### 重要库和依赖版本 + +| GPU | 版本 | NPU | 版本 | +|--------------------|--------------|------------------------|---------| +| Megatron-LM | core_r0.12.0 | MindSpore Transformers | dev | +| Python | \>=3.10 | Python | \>=3.10 | +| PyTorch | 2.7.0 | MindSpore | 2.6.0 | +| NumPy | 1.26.4 | NumPy | 1.26.4 | +| Transformer Engine | 2.1.0 | | | +| Apex | 0.1 | | | + +### 镜像链接 + +上表中的 **GPU / NPU** 相关依赖版本为参考信息,实际环境请以对应官方镜像为准: +> - **Megatron-LM** + :参考 [Megatron-LM 文档](https://github.com/NVIDIA/Megatron-LM/tree/core_r0.12.0?tab=readme-ov-file#setup) +> - **MindSpore Transformers**: + :参考 [MindSpore Transformers 文档](https://gitee.com/mindspore/mindformers/blob/dev/README_CN.md) + +## 3. 精度对比流程 + +本节介绍 MindSpore Transformers 在 NPU 环境下与业界主流实现 Megatron-LM 进行模型级别的精度对齐验证流程。本流程旨在指导用户完成从模型配置、数据输入、前向输出到梯度反向传播的全流程对齐,最终评估两个框架在相同任务下的数值一致性。 + +### 3.1 配置对齐 + +精度对比流程的第一步是确保两个框架使用**完全一致的模型配置**。为此,本小节提供了 [Megatron-LM](https://gitee.com/mindspore/docs/blob/master/docs/mindformers/docs/source_zh_cn/example/accuracy_comparison/example.sh) 与 [MindSpore Transformers](https://gitee.com/mindspore/docs/blob/master/docs/mindformers/docs/source_zh_cn/example/accuracy_comparison/example.yaml) 的对应配置文件,分别定义了模型结构、并行策略以及关键训练超参数。 + +配置对齐的目标是保证两个系统在初始化状态下尽可能一致,从而使得后续的前向输出、梯度反向传播等比对具有可比性。 + +下表列出了提供的配置文件中涉及的关键模型参数,及其在两套系统中的配置对照情况: + +| Megatron-LM | 含义 | MindSpore Transformers | 含义 | +|-------------------------------------|-------------------------|------------------------|----------------------------------------| +| use-mcore-models | 是否使用 mcore 模型 | 无 | | +| disable-bias-linear | 不在线性层使用bias | add_bias_linear | 在线性层使用 bias | +| seq-length | 序列长度 | seq_length | 序列长度 | +| num-layers | 解码器层数 | num_layers | 解码器层数 | +| hidden-size | 隐藏状态的维度 | hidden_size | 隐藏状态的维度 | +| ffn-hidden-size | 前馈神经网络的隐藏层维度 | intermediate_size | 前馈神经网络的隐藏层维度 | +| num-attention-heads | 注意力头数 | num_heads | 注意力头数 | +| init-method-std | 模型参数初始化时使用的正态分布的标准差 | 无 | 同名配置,默认为0.1 | +| attention-dropout | 多头自注意力机制里应用的 Dropout 概率 | attention_dropout | 多头自注意力机制里应用的 Dropout 概率 | +| hidden-dropout | 隐藏层的 Dropout 概率 | hidden_dropout | 隐藏层的 Dropout 概率 | +| Normalization | 归一化操作类型 | 无 | 同名配置,默认为 RMSNorm | +| norm-epsilon | 归一化数值稳定因子 | rms_norm_eps | RMSNorm 稳定因子 | +| position-embedding-type | 位置编码类型 | 无 | 同名配置,默认为 "rope" | +| swiglu | 激活函数类型是否为 SwiGLU | 无 | 无需配置,默认为 SwiGLU | +| untie-embeddings-and-output-weights | 输入嵌入层和输出投影层是否共享权重 | 无 | 同名配置,默认为`True` | +| num-query-groups | Query 分组数量 | 无 | 同名配置,默认为注意力头数 | +| no-masked-softmax-fusion | 关闭 Masked Softmax 融合 | 无 | 配置名为 masked_softmax_fusion ,默认为`False` | +| mtp-num-layers | MoE 层的数量 | mtp_depth | MoE 层的数量 | +| mtp-loss-scaling-factor | MoE 架构中的损失缩放 | mtp_loss_factor | MoE 架构中的损失缩放 | +| q-lora-rank | Q-LoRA 的秩 | q_lora_rank | Q-LoRA 的秩 | +| kv-lora-rank | KV-LoRA 的秩 | kv_lora_rank | KV-LoRA 的秩 | +| qk-pos-emb-head-dim | Query/Key 位置编码每头维度 | qk_rope_head_dim | Query/Key 位置编码每头维度 | +| v-head-dim | Value 每个注意力头的维度 | v_head_dim | Value 每个注意力头的维度 | +| qk-head-dim | Query/Key 每个注意力头的维度 | qk_nope_head_dim | Query/Key 每个注意力头的维度 | +| qk-layernorm | 启用Query/Key 层归一化 | 无 | 无需配置,默认启用 | +| vocab-size | 词汇表的总大小 | vocab_size | 词汇表的总大小 | +| use-flash-attn | 启用 FlashAttention | use_flash_attention | 启用 FlashAttention | +| multi-latent-attention | 启用 多隐变量注意力机制 | 无 | 同名配置,默认启用 | +| moe-layer-freq | 指定一个列表形式的 MoE 层频率 | first_k_dense_replace | 前k层不为 MoE 层 | +| num-experts | 专家的总数 | expert_num | 专家的总数 | +| moe-router-topk | 选择激活的专家数量 K | num_experts_chosen | 选择激活的专家数量 | +| moe-aux-loss-coeff | MoE 辅助损失系数 | aux_loss_factors | MoE 辅助损失系数 | +| moe-ffn-hidden-size | MoE 前馈网络隐藏层维度大小 | moe_intermediate_size | MoE 前馈网络隐藏层维度大小 | +| lr | 学习率 | learning_rate | 学习率 | +| lr-decay-style | 学习率衰减策略 | type | 学习率衰减策略 | +| adam-beta1 和 adam-beta2 的组合 | Adam 优化器的 beta 参数 | betas | Adam 优化器的 beta 参数 | +| adam-eps | Adam 优化器的 epsilon 参数 | eps | Adam 优化器的 epsilon 参数 | + +**注意**: 其他未在上表提及的参数,均为另一个框架无需配置项。 + +### 3.2 数据集对齐 + +精度对比流程中,必须确保两个框架使用完全一致的数据输入。该小节将介绍如何对齐 Megatron-LM 与 MindSpore Transformers 的数据集制作和配置,从而保证输入样本的一致性,为后续权重加载与精度验证提供基础。 + +#### 3.2.1 数据集准备 + +两个框架均支持加载 Megatron 数据集,该数据集通常经过预处理,序列化为二进制格式(例如`.bin`或`.idx`文件),并配套特定索引机制,便于在分布式集群环境下高效并行加载与数据切分。 + +- 数据集下载:[wikitext-103数据集](https://dagshub.com/DagsHub/WIkiText-103/src/main/dataset/tokens) + +- 分词模型下载:分词模型[tokenizer.json](https://huggingface.co/deepseek-ai/DeepSeek-V3/resolve/main/tokenizer.json?download=true) + +#### 3.2.2 数据集处理 + +数据集处理可参考[Megatron数据集-数据预处理](https://www.mindspore.cn/mindformers/docs/zh-CN/dev/feature/dataset.html#%E6%95%B0%E6%8D%AE%E9%A2%84%E5%A4%84%E7%90%86) + +- 生成Megatron BIN格式文件 + + 将数据集文件`wiki.train.tokens`和分词模型文件`tokenizer.json`放置在`../dataset`下。 + + 使用以下命令将数据集文件转换为BIN格式文件。 + + ```shell + cd $MINDFORMERS_HOME + python research/deepseek3/wikitext_to_bin.py \ + --input ../dataset/wiki.train.tokens \ + --output-prefix ../dataset/wiki_4096 \ + --vocab-file ../dataset/tokenizer.json \ + --seq-length 4096 \ + --workers 1 + ``` + +- 构建Megatron BIN数据集模块 + + 执行如下命令构建Megatron BIN数据集模块。如使用提供的镜像请跳过此操作。 + + ```shell + pip install pybind11 + cd $MINDFORMERS_HOME/mindformers/dataset/blended_datasets + make + ``` + + 其中,`$MINDFORMERS_HOME` 指 Mindspore Transformers 源代码所在的目录。 + +#### 3.2.2 数据集配置 + +本小节会将两个框架配置文件中的数据集配置项,进行对比和说明。 + +- Megatron-LM: + + ```shell + TOKENIZER_MODEL="/path/to/tokenizer.json" + DATA_PATH="/path/to/wiki_text_document" + + DATA_ARGS=( + --tokenizer-type HuggingFaceTokenizer + --tokenizer-model ${TOKENIZER_MODEL} + --data-path $DATA_PATH + --split 1,0,0 + ) + ``` + + 其中, + + - `tokenizer-type`为分词模型文件类型 + - `tokenizer-model`为分词模型文件`tokenizer.json`的所在位置,精确到完整文件名 + - `data-path`为处理好的数据集的所在位置,精确到`.bin`或`.idx`文件的前缀 + - `split`为数据集的采样比例 + +- MindSpore Transformers: + + ```yaml + train_dataset: &train_dataset + data_loader: + type: BlendedMegatronDatasetDataLoader + datasets_type: "GPTDataset" + sizes: + - 4000 # 训练集数据样本数 + - 0 # 测试集数据样本数,当前不支持配置 + - 0 # 评测集数据样本数,当前不支持配置 + config: # GPTDataset配置项 + seed: 1234 # 数据采样随机种子 + split: "1, 0, 0" # 训练、测试、评测集使用比例,当前不支持配置 + seq_length: 4096 # 数据集返回数据的序列长度 + eod_mask_loss: False # 是否在eod处计算loss + reset_position_ids: False # 是否在eod处重置position_ids + create_attention_mask: True # 是否返回attention_mask + reset_attention_mask: False # 是否在eod处重置attention_mask,返回阶梯状attention_mask + create_compressed_eod_mask: False # 是否返回压缩后的attention_mask + eod_pad_length: 128 # 设置压缩后attention_mask的长度 + eod: 0 # 数据集中eod的token id + pad: 1 # 数据集中pad的token id + + data_path: # Megatron数据集采样比例以及路径 + - '1' + - "/home/to/wiki_text_document" + + input_columns: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"] + construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"] + + num_parallel_workers: 8 + python_multiprocessing: False + drop_remainder: True + numa_enable: False + prefetch_size: 1 + seed: 1234 + ``` + + 其中,需要注意的是`data_path`的第一个参数是数据集采样比例,第二个参数是处理好的数据集的所在位置,精确到`.bin`或`.idx`文件的前缀 + +### 3.3 权重对齐 + +为了实现不同框架间模型行为的一致性,需将训练得到的权重精确映射到 MindSpore Transformers 和 Megatron-LM 中对应位置,通过合理的权重转换和切分实现。 + +#### 3.3.1 权重转换 + +由于 MindSpore Transformers 和 Megatron-LM 使用的权重格式、参数命名方式及张量排列存在差异,直接加载权重通常会导致不兼容。因此,需要通过专门的转换脚本将源框架导出的模型权重转换为目标框架可识别的格式。 + +1. 生成 MinSpore Transformers 初始权重 + + 通过修改 `example.yaml` 文件并执行[查看结果](#34-查看结果)中提供的命令,即可通过预训练在`example.yaml`中的`output_dir`的`checkpoints`下获得一份初始权重,修改内容如下: + + ```yaml + # Before (example.yaml) + load_checkpoint: '/path/to/checkpoints/' + + callbacks: + - type: MFLossMonitor + per_print_times: 1 + - type: TopkBiasBalanceCallback + balance_via_topk_bias: *balance_via_topk_bias + topk_bias_update_rate: *topk_bias_update_rate + expert_num: *expert_num + num_layers: *num_layers + mtp_depth: *mtp_depth + micro_batch_num: *micro_batch_num + ``` + + ```yaml + # After (example.yaml) + load_checkpoint: '' + + callbacks: + - type: MFLossMonitor + per_print_times: 1 + - type: TopkBiasBalanceCallback + balance_via_topk_bias: *balance_via_topk_bias + topk_bias_update_rate: *topk_bias_update_rate + expert_num: *expert_num + num_layers: *num_layers + mtp_depth: *mtp_depth + micro_batch_num: *micro_batch_num + - type: CheckpointMonitor + prefix: "deepseekv3" + save_checkpoint_steps: 1 + keep_checkpoint_max: 2 + integrated_save: False + async_save: False + checkpoint_format: "safetensors" + - type: TrainCallBack + stop_step: 1 + ``` + + **注意**:获得权重之后,需要将`example.yaml`反向修改复原。 + +2. MindSpore Transformers to Megatron-LM + + 为了将 MindSpore Transformers 的权重精确映射为 Megatron-LM 可加载的等价权重,我们将会提供转换权重脚本,执行权重转换脚本即可获得等价权重。 + +### 3.4 查看结果 + +完成以上步骤后,即可进行训练,从日志中输出的结果中提取关键数据查看精度对比结果。 + +- Megatron-LM + + 将`example.sh`文件放到 Megatron-LM 代码目录下,执行以下代码: + + ```shell + bash example.sh + ``` + +- MindSpore Transformers + + 在 MindSpore Transformer 代码目录下,执行以下代码: + + ```shell + bash scripts/msrun_launcher.sh "run_mindformer.py \ + --config /path/to/example.yaml" + ``` + + 其中,`config`是模型的配置文件,文件在 MindSpore Transformers 代码仓中 config 目录下 + +- 结果对比 + + 分别查看二者的输出日志,Megatron-LM 的日志位置为`example.sh`中的`logs/${logtime}.log`, MindSpore Transformer 的日志位置为`example.yaml`中的`output_dir`的`msrun_log/worker_0.log`。结果对比参考下表: + + | Megatron-LM | MindSpore Transformers | 含义 | + |-----------------|------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------| + | `iteration` | `epoch` 与 `step` 的组合 | 表示训练过程中的全局迭代次数。MindSpore Transformers 通常以 `(epoch, step)` 表示当前训练位置,而 Megatron-LM 使用单一的 `iteration` 表示。两者关系为:`iteration = (epoch - 1) * steps_per_epoch + step` | + | `lm loss` | `loss` | 训练损失,精度对比核心指标 | + | `learning rate` | `lr` | 学习率,精度对比参考指标 | + | `grand norm` | `global norm` | 全局梯度范数,精度对比参考指标 | \ No newline at end of file diff --git a/docs/mindformers/docs/source_zh_cn/example/accuracy_comparison/example.sh b/docs/mindformers/docs/source_zh_cn/example/accuracy_comparison/example.sh new file mode 100644 index 0000000000..251aa9688b --- /dev/null +++ b/docs/mindformers/docs/source_zh_cn/example/accuracy_comparison/example.sh @@ -0,0 +1,130 @@ +#!/bin/bash + +# Runs Mixtral 8x7B model +export PYTHONPATH=/home/work/projects/deepseekv3/Megatron-LM:$PYTHONPATH +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +GPUS_PER_NODE=4 +# Change for multinode config +MASTER_ADDR=${MASTER_ADDR:-"localhost"} +MASTER_PORT=${MASTER_PORT:-"6000"} +NNODES=${SLURM_NNODES:-"1"} +NODE_RANK=${RANK:-"0"} +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +LOAD_PATH="/path/to/checkpoints" +TOKENIZER_MODEL="/path/to/tokenizer.json" +DATA_PATH="/path/to/wiki_text_document" + +TP=1 +PP=4 +EP=1 + +DISTRIBUTED_ARGS=( + --nproc_per_node $GPUS_PER_NODE + --nnodes $NNODES + --node_rank $NODE_RANK + --master_addr $MASTER_ADDR + --master_port $MASTER_PORT +) + +MODEL_ARGS=( + --use-mcore-models + --disable-bias-linear + --seq-length 4096 + --max-position-embeddings 163840 + --num-layers 4 + --hidden-size 2048 + --ffn-hidden-size 6144 + --num-attention-heads 8 + --init-method-std 0.01 + --attention-dropout 0.0 + --hidden-dropout 0.0 + --normalization RMSNorm + --norm-epsilon 1e-6 + --position-embedding-type rope + --no-rope-fusion + --swiglu + --untie-embeddings-and-output-weights + --num-query-groups 8 + --no-masked-softmax-fusion + --mtp-num-layers 1 + --mtp-loss-scaling-factor 0.3 + --q-lora-rank 1536 + --kv-lora-rank 512 + --qk-pos-emb-head-dim 64 + --v-head-dim 192 + --qk-head-dim 128 + --qk-layernorm + --vocab-size 129280 + --make-vocab-size-divisible-by 129280 + --use-flash-attn + --multi-latent-attention + --attention-backend flash +) + +MOE_ARGS=( + --moe-layer-freq '[0]+[1]*3' + --num-experts 16 + --moe-router-topk 8 + --moe-router-load-balancing-type seq_aux_loss + --moe-aux-loss-coeff 0 + --moe-grouped-gemm + --moe-token-dispatcher-type alltoall + --overlap-param-gather + --overlap-grad-reduce + --moe-shared-expert-intermediate-size 2048 + --moe-ffn-hidden-size 2048 + --moe-router-group-topk 0 + --moe-router-topk-scaling-factor 1.5 + --moe-router-score-function sigmoid + --moe-router-dtype fp32 +) + +DATA_ARGS=( + --tokenizer-type HuggingFaceTokenizer + --tokenizer-model ${TOKENIZER_MODEL} + --data-path $DATA_PATH + --split 1,0,0 +) + +TRAINING_ARGS=( + --micro-batch-size 1 + --global-batch-size 4 + --train-iters 1000 + --lr 1.e-6 + --lr-decay-style constant + --adam-beta1 0.9 + --adam-beta2 0.95 + --adam-eps 1e-8 + --clip-grad 1.0 + --bf16 + --finetune +) + +MODEL_PARALLEL_ARGS=( + --tensor-model-parallel-size ${TP} + --pipeline-model-parallel-size ${PP} + --expert-model-parallel-size ${EP} + --use-distributed-optimizer +) + +LOGGING_ARGS=( + --log-interval 1 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 1000 \ + --no-load-optim \ + --no-load-rng \ + --ckpt-format torch \ + --load $LOAD_PATH +) + +logtime=$(date +%Y%m%d)_$(date +%H%M%S) +torchrun ${DISTRIBUTED_ARGS[@]} /path/to/Megatron-LM/pretrain_gpt.py \ + ${MODEL_ARGS[@]} \ + ${MOE_ARGS[@]} \ + ${DATA_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${LOGGING_ARGS[@]} 2>&1 | tee logs/${logtime}.log \ No newline at end of file diff --git a/docs/mindformers/docs/source_zh_cn/example/accuracy_comparison/example.yaml b/docs/mindformers/docs/source_zh_cn/example/accuracy_comparison/example.yaml new file mode 100644 index 0000000000..da63724d12 --- /dev/null +++ b/docs/mindformers/docs/source_zh_cn/example/accuracy_comparison/example.yaml @@ -0,0 +1,222 @@ +seed: 1234 +output_dir: './output' # path to save checkpoint/strategy +load_checkpoint: '/path/to/checkpoints/' +load_ckpt_format: 'safetensors' # format of checkpoint files +src_strategy_path_or_dir: '' +auto_trans_ckpt: True # If true, auto transform load_checkpoint to load in distributed model +only_save_strategy: False +resume_training: False +use_parallel: True +run_mode: 'train' + +# trainer config +trainer: + type: CausalLanguageModelingTrainer + model_name: 'deepseekV3' + +# runner config +runner_config: + epochs: 1 + batch_size: 1 + sink_mode: True + sink_size: 1 + +# optimizer +optimizer: + type: AdamW + betas: [0.9, 0.95] + eps: 1.e-8 + +lr_schedule: + type: ConstantWarmUpLR + learning_rate: 1.e-6 + warmup_steps: 0 + total_steps: -1 # -1 means it will load the total steps of the dataset + +# dataset +train_dataset: &train_dataset + data_loader: + type: BlendedMegatronDatasetDataLoader + datasets_type: "GPTDataset" + sizes: + - 4000 # 训练集数据样本数 + - 0 # 测试集数据样本数,当前不支持配置 + - 0 # 评测集数据样本数,当前不支持配置 + config: # GPTDataset配置项 + seed: 1234 # 数据采样随机种子 + split: "1, 0, 0" # 训练、测试、评测集使用比例,当前不支持配置 + seq_length: 4096 # 数据集返回数据的序列长度 + eod_mask_loss: False # 是否在eod处计算loss + reset_position_ids: False # 是否在eod处重置position_ids + create_attention_mask: True # 是否返回attention_mask + reset_attention_mask: False # 是否在eod处重置attention_mask,返回阶梯状attention_mask + create_compressed_eod_mask: False # 是否返回压缩后的attention_mask + eod_pad_length: 128 # 设置压缩后attention_mask的长度 + eod: 0 # 数据集中eod的token id + pad: 1 # 数据集中pad的token id + + data_path: # Megatron数据集采样比例以及路径 + - '1' + - "/home/to/wiki_text_document" + + input_columns: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"] + construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"] + + num_parallel_workers: 8 + python_multiprocessing: False + drop_remainder: True + numa_enable: False + prefetch_size: 1 + seed: 1234 + +train_dataset_task: + type: CausalLanguageModelDataset + dataset_config: *train_dataset + +# mindspore context init config +context: + mode: 0 #0--Graph Mode; 1--Pynative Mode + device_target: "Ascend" + max_call_depth: 10000 + max_device_memory: "58GB" + save_graphs: False + save_graphs_path: "./graph" + jit_config: + jit_level: "O1" + ascend_config: + parallel_speed_up_json_path: /path/to/mindformers/research/deepseek3/parallel_speed_up.json + +# parallel config for device num = 1024 +parallel_config: + data_parallel: &dp 4 + model_parallel: 2 + pipeline_stage: 1 + expert_parallel: 1 + micro_batch_num: µ_batch_num 1 + vocab_emb_dp: True + use_seq_parallel: False + gradient_aggregation_group: 4 +# when model parallel is greater than 1, we can set micro_batch_interleave_num=2, that may accelerate the train process. +micro_batch_interleave_num: 1 + +# parallel context config +parallel: + parallel_mode: 1 # 0-data parallel, 1-semi-auto parallel, 2-auto parallel, 3-hybrid parallel + gradients_mean: False + enable_alltoall: True + full_batch: False + dataset_strategy: [[*dp, 1], [*dp, 1], [*dp, 1], [*dp, 1], [*dp, 1, 1, 1]] # [*dp, 1, 1, 1] + search_mode: "sharding_propagation" + enable_parallel_optimizer: True + strategy_ckpt_config: + save_file: "./ckpt_strategy.ckpt" + only_trainable_params: False + parallel_optimizer_config: + gradient_accumulation_shard: False + parallel_optimizer_threshold: 8 + +# recompute config +recompute_config: + recompute: False + select_recompute: False + parallel_optimizer_comm_recompute: True + mp_comm_recompute: True + recompute_slice_activation: True + +# model config +use_legacy: False +model: + model_config: + model_type: "deepseekv3" + architectures: "DeepseekV3ForCausalLM" + + # gpt model + num_layers: &num_layers 4 + vocab_size: 129280 # --make-vocab-size-divisible-by 129280 + seq_length: 4096 + hidden_size: 2048 + intermediate_size: 6144 # --ffn-hidden-size + num_heads: 8 + max_position_embeddings: 4096 + use_flash_attention: True + use_eod_reset: False + add_bias_linear: False + rms_norm_eps: 1.e-6 + attention_dropout: 0.0 + hidden_dropout: 0.0 + + # rope module -> use magatron default values, cannot be changed + extend_method: "yarn" + scaling_factor: 40.0 + beta_fast: 32.0 + beta_slow: 1.0 + mscale: 0.707 + mscale_all_dim: 0.707 + theta: 10000.0 + + # mla module + kv_lora_rank: 512 + n_kv_heads: 128 + q_lora_rank: 1536 + qk_rope_head_dim: 64 + v_head_dim: 192 + qk_nope_head_dim: 128 + + # mtp module + mtp_depth: &mtp_depth 1 + mtp_loss_factor: 0.3 + + # params dtypes + param_init_type: "float32" + compute_dtype: "bfloat16" + layernorm_compute_type: "float32" + softmax_compute_type: "float32" + rotary_dtype: "float32" + router_dense_type: "float32" + + # other options + offset: 0 + ignore_token_id: -100 + input_sliced_sig: True # only used to pass parameter check when applying BlendedMegatronDatasetDataLoader or CommonDataLoader + batch_size: 1 # add for increase predict + +# moe config +moe_config: + expert_num: &expert_num 16 + num_experts_chosen: 8 + balance_via_topk_bias: &balance_via_topk_bias False + topk_bias_update_rate: &topk_bias_update_rate 0.001 + shared_expert_num: 1 + routed_scaling_factor: 2.5 + norm_topk_prob: True + first_k_dense_replace: 1 + moe_intermediate_size: 2048 # --moe-shared-expert-intermediate-size / --moe-ffn-hidden-size + aux_loss_factors: [0.0] + aux_loss_types: ["expert"] + z_loss_factor: 0.0 + expert_model_parallel: 1 + +# callbacks +callbacks: + - type: MFLossMonitor + per_print_times: 1 + - type: TopkBiasBalanceCallback + balance_via_topk_bias: *balance_via_topk_bias + topk_bias_update_rate: *topk_bias_update_rate + expert_num: *expert_num + num_layers: *num_layers + mtp_depth: *mtp_depth + micro_batch_num: *micro_batch_num + +# wrapper cell config +runner_wrapper: + type: MFTrainOneStepCell + scale_sense: 1.0 + use_clip_grad: True + +profile: False +profile_start_step: 1 +profile_stop_step: 10 +init_start_profile: False +profile_communication: False +profile_memory: True \ No newline at end of file diff --git a/docs/mindformers/docs/source_zh_cn/index.rst b/docs/mindformers/docs/source_zh_cn/index.rst index dd142e8086..a4edce1fba 100644 --- a/docs/mindformers/docs/source_zh_cn/index.rst +++ b/docs/mindformers/docs/source_zh_cn/index.rst @@ -141,6 +141,10 @@ MindSpore Transformers功能特性说明 - `开发迁移 `_ - `多模态理解模型开发 `_ +- 精度对比 + + - `Parallel Core精度对比 `_ + 环境变量 ------------------------------------ @@ -210,6 +214,7 @@ FAQ advanced_development/performance_optimization advanced_development/dev_migration advanced_development/multi_modal_dev + advanced_development/accuracy_comparison advanced_development/api .. toctree:: -- Gitee