diff --git a/OWNERS b/OWNERS
index a150718b2a0efb31f9f4dbdb6a7831bf91689de1..dc8f1f095e6b8acad1d63cd20a327e388c1e6183 100644
--- a/OWNERS
+++ b/OWNERS
@@ -1,6 +1,7 @@
 approvers:
 - dilililiwhy
 - wujianping996
+- fengliangjun66
 reviewers:
 - xqiangx1991
 - sanjaychan
diff --git a/README.md b/README.md
index 67da4017e3aeb6580b12b4cb84800e5abf606e4a..b051ec0810c6d9ef0cfcbe2ec86259721f1b67d8 100644
--- a/README.md
+++ b/README.md
@@ -159,10 +159,10 @@ The plan for more tasks, like RLHF and RM, is under way ...
       7B
       1x8
       FP16
-      --
-      --
-      --
-      --
+      2460 tokens/p/s
+      2348 tokens/p/s
+      Loss
+      Train
diff --git a/examples/llama2/README.md b/examples/llama2/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..a74dd92271366529cb9226c84b2782f733ba6af8
--- /dev/null
+++ b/examples/llama2/README.md
@@ -0,0 +1,184 @@
+
+# Contents
+
+- [LLaMA2-7B](#llama2-7b)
+  - [Training](#training)
+    - [Script](#script)
+    - [Performance](#performance)
+      - [Machine performance](#machine-performance)
+      - [Accuracy of the loss](#accuracy-of-the-loss)
+
+
+# LLaMA2-7B
+
+## Training
+
+Here's a hardware summary of pre-training LLaMA2-7B:
+
+| Hardware |      Value      |
+| :------: | :-------------: |
+|   NPU    | 8 x Ascend NPUs |
+
+Here's a software summary of pre-training LLaMA2-7B:
+
+| Software                  | Version                              | Link |
+| :-----------------------: | :----------------------------------: | :--: |
+| Python                    | 3.7.16                               | -    |
+| driver                    | 23.0.RC3.B050                        | [link](https://support.huawei.com/enterprise/zh/ascend-computing/ascend-hdk-pid-252764743/software/261159045?idAbsPath=fixnode01%7C23710424%7C251366513%7C22892968%7C252764743) |
+| firmware                  | 7.0.t8.0.b214                        | [link](https://support.huawei.com/enterprise/zh/ascend-computing/ascend-hdk-pid-252764743/software/261159045?idAbsPath=fixnode01%7C23710424%7C251366513%7C22892968%7C252764743) |
+| CANN                      | Ascend-cann-toolkit-7.0.T8-linux     | [link](https://support.huawei.com/enterprise/zh/ascend-computing/cann-pid-251168373/software/261204647?idAbsPath=fixnode01%7C23710424%7C251366513%7C22892968%7C251168373) |
+| binary arithmetic package | Ascend-cann-kernels-XXX_7.0.T8_linux | [link](https://support.huawei.com/enterprise/zh/ascend-computing/cann-pid-251168373/software/261204647?idAbsPath=fixnode01%7C23710424%7C251366513%7C22892968%7C251168373) |
+| torch                     | 1.11.0                               | [link](https://gitee.com/ascend/pytorch/releases/tag/v5.0.rc2.2-pytorch1.11.0) |
+| torch_npu                 | 1.11.0.post4-20230915                | [link](https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/Daily/v1.11.0/20230915.2/pytorch_v1.11.0_py37.tar.gz) |
+
+### Script
+
+1. Clone the repository to your local server:
+
+   ```shell
+   git clone https://gitee.com/ascend/AscendSpeed.git
+   cd AscendSpeed
+   mkdir logs
+   mkdir ckpt
+   ```
+
+2. Build the environment
+
+   ```bash
+   # python3.7
+   conda create -n test python=3.7
+   conda activate test
+
+   # install torch and torch_npu
+   pip install torch-1.11.0-cp37-cp37m-manylinux2014_aarch64.whl
+   pip install torch_npu-1.11.0*-cp37-cp37m-linux_aarch64.whl
+   pip install apex-0.1_ascend*-cp37-cp37m-linux_aarch64.whl
+
+   # install megatron-core
+   pip3 install -e git+https://github.com/NVIDIA/Megatron-LM.git@23.05#egg=megatron-core
+
+   # install deepspeed and deepspeed_npu
+   pip install deepspeed==0.9.2
+   git clone https://gitee.com/ascend/DeepSpeed.git -b v0.9.2 deepspeed_npu
+   cd deepspeed_npu
+   pip3 install -e ./
+   cd ..
+
+   # install other packages
+   pip install -r requirements.txt
+   ```
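+
+   After the install steps above you can optionally sanity-check the stack before going further. This is a minimal sketch rather than part of the official recipe; it assumes `torch` 1.11.0 and `torch_npu` were installed as shown and that the driver, firmware and CANN versions from the table above are in place.
+
+   ```python
+   # quick environment check for the Ascend PyTorch adapter
+   import torch
+   import torch_npu  # registers the "npu" device type with torch
+
+   print(torch.__version__)          # expected: 1.11.0
+   print(torch.npu.is_available())   # True once driver, firmware and CANN are set up
+   print(torch.npu.device_count())   # expected: 8 on a single training node
+   ```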
+
+   *Note: if you want to train from the Hugging Face weights, first fix a DeepSpeed checkpoint-loading bug by changing `if zero_sd_list is None` to `if zero_sd_list is None or len(zero_sd_list) == 0` in the `_load_zero_checkpoint` function of `deepspeed/runtime/engine.py`.*
+
+   ```text
+   # original deepspeed/runtime/engine.py, around lines 2746-2748
+   zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
+   if zero_sd_list is None:
+       return False
+
+   # modified
+   zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
+   if zero_sd_list is None or len(zero_sd_list) == 0:
+       return False
+   ```
+
+3. Prepare pretrained weights and tokenizer
+
+   Download the LLaMA2-7B checkpoint from [here](https://huggingface.co/daryl149/llama-2-7b-hf/tree/main):
+
+   ```shell
+   #!/bin/bash
+   mkdir -p llama-2-7b-hf
+   cd llama-2-7b-hf
+   wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/config.json
+   wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/generation_config.json
+   wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/pytorch_model-00001-of-00002.bin
+   wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/pytorch_model-00002-of-00002.bin
+   wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/pytorch_model.bin.index.json
+   wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/special_tokens_map.json
+   wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/tokenizer.json
+   wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/tokenizer.model
+   wget https://huggingface.co/daryl149/llama-2-7b-hf/resolve/main/tokenizer_config.json
+   cd ..
+   ```
+
+   *Note: to train with the Hugging Face weights, run the weight conversion script first. The following converts the llama-2-7b weights to the DeepSpeed format as an example.*
+
+   ```bash
+   # modify the script according to your own ascend-toolkit path
+   source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+   # convert to deepspeed weights
+   python tools/ckpt_convert/llama/convert_weights_from_huggingface.py \
+       --input-model-dir llama-2-7b-hf \
+       --output-model-dir ckpt \
+       --tensor-model-parallel-size 1 \
+       --pipeline-model-parallel-size 1 \
+       --type 7B \
+       --deepspeed
+   ```
+
+4. Prepare dataset
+
+   Download the LLaMA2-7B dataset from [here](https://huggingface.co/datasets/tatsu-lab/alpaca/resolve/main/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet):
+
+   ```shell
+   # download datasets
+   mkdir dataset_llama2
+   cd ./dataset_llama2
+   wget https://huggingface.co/datasets/tatsu-lab/alpaca/resolve/main/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet
+   cd ..
+
+   # process datasets
+   python ./tools/preprocess_data.py \
+       --input ./dataset_llama2/train-00000-of-00001-a09b74b3ef9c3b56.parquet \
+       --tokenizer-name-or-path ./llama-2-7b-hf \
+       --output-prefix ./dataset_llama2/alpaca \
+       --workers 4 \
+       --log-interval 1000 \
+       --tokenizer-type PretrainedFromHF
+   ```
+
+5. Configure the LLaMA2-7B pre-training script: examples/llama2/pretrain_llama2_7b_zero_8p.sh
+
+   ```shell
+   # modify the script according to your own ascend-toolkit path
+   source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+   # modify the tokenizer path and the processed dataset path according to your own setup
+   TOKENIZER_PATH=./llama-2-7b-hf/  # tokenizer path
+   DATA_PATH=./dataset_llama2/alpaca_text_document  # processed dataset
+   ```
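+
+   Before launching, it can help to confirm that preprocessing actually produced the indexed dataset that `DATA_PATH` points to. A small sketch, assuming the `--output-prefix ./dataset_llama2/alpaca` used in step 4:
+
+   ```python
+   # check for the .bin/.idx pair behind DATA_PATH=./dataset_llama2/alpaca_text_document
+   from pathlib import Path
+
+   data_prefix = Path("./dataset_llama2/alpaca_text_document")
+   for suffix in (".bin", ".idx"):
+       f = data_prefix.with_suffix(suffix)
+       print(f, "OK" if f.exists() else "MISSING")
+   ```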
+6. Launch the LLaMA2-7B pre-training script: examples/llama2/pretrain_llama2_7b_zero_8p.sh
+
+   ```shell
+   bash examples/llama2/pretrain_llama2_7b_zero_8p.sh
+   ```
+
+### Performance
+
+#### Machine performance
+
+The performance of LLaMA2-7B on **Ascend NPU** versus the **Reference**:
+
+| Device    | Model     | total iterations | throughput rate (samples/s/p) | throughput rate (tokens/s/p) | single-step time (s/step) | floating-point operations (TFLOPs/s) |
+| :-------: | :-------: | :--------------: | :---------------------------: | :--------------------------: | :-----------------------: | :----------------------------------: |
+| NPUs      | LLaMA2-7B | 1024             | 4.804                         | 2459.648                     | 6.66                      | 147.42                               |
+| Reference | LLaMA2-7B | 1024             | 4.585                         | 2347.63                      | 6.99                      | 143.01                               |
+
+#### Accuracy of the loss
+
+NPU vs. Reference loss.
+
+The NPU training runs smoothly: resource usage is stable, no errors are reported during the run, the loss decreases steadily, and the convergence speed is as expected. The precision meets the requirements.
+
+![NPU-LOSS](../../sources/images/llama2/llama2_7b_shape_fp16_layer32_loss_with_weights.png)
+
+The relative error of the average loss is 0.0046, below the 2% threshold; the maximum relative error is 0.0852.
+
+![NPU-LOSS and NPU-Relative-Error](../../sources/images/llama2/llama2_7b_shape_fp16_layer32_loss_with_weights_comparison_relative.png)
+
+The absolute error of the average loss is 0.0009; the maximum absolute error is 0.0246.
+
+![NPU-LOSS and NPU-Absolute-Error](../../sources/images/llama2/llama2_7b_shape_fp16_layer32_loss_with_weights_comparison_absolute.png)
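+
+The error figures above can be reproduced along these lines; this is a minimal sketch that assumes the per-step loss values have already been extracted from the NPU and reference training logs into two equal-length text files (hypothetical file names):
+
+```python
+# compare the NPU loss curve against the reference loss curve
+import numpy as np
+
+npu_loss = np.loadtxt("npu_loss.txt")   # hypothetical: one loss value per line
+ref_loss = np.loadtxt("ref_loss.txt")
+
+abs_err = np.abs(npu_loss - ref_loss)
+rel_err = abs_err / np.abs(ref_loss)
+
+print(f"mean abs error: {abs_err.mean():.4f}  max abs error: {abs_err.max():.4f}")
+print(f"mean rel error: {rel_err.mean():.4f}  max rel error: {rel_err.max():.4f}")
+```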
diff --git a/examples/llama2/pretrain_llama2_7b_zero_8p.sh b/examples/llama2/pretrain_llama2_7b_zero_8p.sh
new file mode 100755
index 0000000000000000000000000000000000000000..f8c599604939e0583902209174b2afb961855a8e
--- /dev/null
+++ b/examples/llama2/pretrain_llama2_7b_zero_8p.sh
@@ -0,0 +1,107 @@
+# This is an example: training llama2 using ZeRO.
+
+# The number of parameters is not aligned
+export LD_LIBRARY_PATH=/usr/local/lib:/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
+export HCCL_CONNECT_TIMEOUT=1200
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+export COMBINED_ENABLE=1
+
+# Change for multinode config
+MASTER_ADDR=localhost
+MASTER_PORT=6001
+NNODES=1
+NODE_RANK=0
+NPUS_PER_NODE=8
+WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
+
+DATA_PATH=./dataset/llama_text_document
+CHECKPOINT_PATH=./ckpt
+
+DS_CONFIG=deepspeed_config_7B.json
+ZERO_STAGE=2
+GLOBAL_BATCH=32
+MICRO_BATCH=4
+
+cat <<EOT > $DS_CONFIG
+{
+  "fp16": {
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "initial_scale_power": 8,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+
+  "optimizer": {
+    "type": "Adam"
+  },
+
+  "zero_optimization": {
+    "stage": $ZERO_STAGE,
+    "allgather_partitions": true,
+    "allgather_bucket_size": 1e8,
+    "overlap_comm": true,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 1e8,
+    "contiguous_gradients": true
+  },
+
+  "gradient_accumulation_steps": 1,
+  "train_batch_size": $GLOBAL_BATCH,
+  "train_micro_batch_size_per_gpu": $MICRO_BATCH,
+  "zero_allow_untested_optimizer": true
+}
+EOT
+
+ds_args=""
+ds_args=" --deepspeed ${ds_args}"
+ds_args=" --no-pipeline-parallel ${ds_args}"
+ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}"
+ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}"
+ds_args=" --deepspeed-activation-checkpointing ${ds_args}"
+
+# Main script
+deepspeed pretrain_llama.py \
+    --checkpoint-activations \
+    --use-fused-rotary-pos-emb \
+    --triangle-attn \
+    --DDP-impl local \
+    --tensor-model-parallel-size 1 \
+    --pipeline-model-parallel-size 1 \
+    --num-layers 32 \
+    --hidden-size 4096 \
+    --ffn-hidden-size 11008 \
+    --num-attention-heads 32 \
+    --micro-batch-size $MICRO_BATCH \
+    --global-batch-size $GLOBAL_BATCH \
+    --seq-length 4096 \
+    --max-position-embeddings 4096 \
+    --train-iters 500000 \
+    --lr-decay-iters 320000 \
+    --save $CHECKPOINT_PATH \
+    --load $CHECKPOINT_PATH \
+    --data-path $DATA_PATH \
+    --tokenizer-name-or-path ./dataset/llama/ \
+    --tokenizer-not-use-fast \
+    --data-impl mmap \
+    --split 949,50,1 \
+    --distributed-backend nccl \
+    --lr 0.0003 \
+    --lr-decay-style cosine \
+    --min-lr 3.0e-5 \
+    --weight-decay 1.0e-1 \
+    --clip-grad 1.0 \
+    --lr-warmup-iters 5000 \
+    --log-interval 1 \
+    --save-interval 10000 \
+    --eval-interval 1000 \
+    --eval-iters 10 \
+    --adam-beta1 0.9 \
+    --adam-beta2 0.95 \
+    --adam-eps 1.0e-5 \
+    --initial-loss-scale 4096.0 \
+    --use-cpu-initialization \
+    $ds_args \
+    --fp16 | tee logs/NPU_llama2_7b_shape_fp16_layer32_8p_pretrain.out
diff --git a/sources/images/llama2/llama2_7b_shape_fp16_layer32_loss_with_weights.png b/sources/images/llama2/llama2_7b_shape_fp16_layer32_loss_with_weights.png
new file mode 100755
index 0000000000000000000000000000000000000000..4493cdee03b2bed1388f23949bb5eb238cbb64aa
Binary files /dev/null and b/sources/images/llama2/llama2_7b_shape_fp16_layer32_loss_with_weights.png differ
diff --git a/sources/images/llama2/llama2_7b_shape_fp16_layer32_loss_with_weights_comparison_absolute.png b/sources/images/llama2/llama2_7b_shape_fp16_layer32_loss_with_weights_comparison_absolute.png
new file mode 100755
index 0000000000000000000000000000000000000000..b9490cbadca464be327518bc48d03e1745109551
Binary files /dev/null and b/sources/images/llama2/llama2_7b_shape_fp16_layer32_loss_with_weights_comparison_absolute.png differ
diff --git a/sources/images/llama2/llama2_7b_shape_fp16_layer32_loss_with_weights_comparison_relative.png b/sources/images/llama2/llama2_7b_shape_fp16_layer32_loss_with_weights_comparison_relative.png
new file mode 100755
index 0000000000000000000000000000000000000000..4ff4e13387cc0f68e7e3c43e3bc7d680661d7836
Binary files /dev/null and b/sources/images/llama2/llama2_7b_shape_fp16_layer32_loss_with_weights_comparison_relative.png differ
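A final note on the pre-training script added above: DeepSpeed requires that `train_batch_size` equal `train_micro_batch_size_per_gpu × gradient_accumulation_steps × world_size`, so `GLOBAL_BATCH`, `MICRO_BATCH`, `NNODES` and `NPUS_PER_NODE` must be changed together. A minimal sketch of that consistency check, mirroring the values in the script:

```python
# mirror of the batch settings in pretrain_llama2_7b_zero_8p.sh
npus_per_node = 8
nnodes = 1
world_size = npus_per_node * nnodes      # 8

micro_batch = 4                          # MICRO_BATCH / "train_micro_batch_size_per_gpu"
grad_accum_steps = 1                     # "gradient_accumulation_steps" in the DS config
global_batch = 32                        # GLOBAL_BATCH / "train_batch_size"

# DeepSpeed rejects the config unless this holds
assert global_batch == micro_batch * grad_accum_steps * world_size, "inconsistent batch settings"
print("effective tokens per step:", global_batch * 4096)   # seq-length 4096 -> 131072
```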