diff --git a/README.md b/README.md
index 7f758d77bda93295d5aa6e5470cd39e8dda6c579..6dbc9e005ce4ab6a4d8179d75e3b84cce9a2873f 100644
--- a/README.md
+++ b/README.md
@@ -11,11 +11,11 @@
 mkdir logs
 mkdir ckpt
 ```
-2. Download [dataset](https://onebox.huawei.com/p/8cb41da9ed8ae1579fbc50d2a2bdebd7):
+2. Download the GPT or LLaMA dataset:
 ```bash
 mv dataset AscendSpeed
 cd dataset
-tar -xzvf enwiki-gpt.tar.gz
+tar -xzvf enwiki-gpt.tar.gz # or llama-data.tar.gz
 ```
 
 3. Build conda envs:
@@ -24,8 +24,7 @@
 conda create -n py37 python=3.7
 conda activate py37
 ```
-4. Install pytorch and dependency: \
-(pytorch should be selected according to your server, X86 or Arm, "[2023060213](https://onebox.huawei.com/p/d79535eac8fb159399febca1afa81f95)" version is used here)
+4. Install pytorch and dependencies:
 ```bash
 pip install apex-0.1_ascend_XXXX-cp37-cp37m-linux_XXXX.whl
 pip install torch-1.11.0+cpu-cp37-cp37m-linux_XXXXX.whl
@@ -35,7 +34,6 @@
 pip install -r requirements
 ```
 5. Install deepspeed and adapter:
-(Deepspeed adapter v0.9.2 can also be downloaded from [here](https://onebox.huawei.com/p/6cdf01273f5d4edbd38a6da7b972c8f9))
 ```bash
 pip install deepspeed==0.9.2
 git clone https://gitee.com/ascend/DeepSpeed.git (Deepspeed adapter v0.9.2, not master!)
@@ -50,7 +48,19 @@
 sh examples/gpt_task/pretrain_gpt_ptd_8p.sh (Demon gpt with 200M parameters)
 sh examples/llama_task/pretrain_llama_ptd_8p.sh (Demon llama with 250M parameters)
 ```
 
+# Using HuggingFace Tokenizer
+Llama Tokenizer
+------
+1. Set `--tokenizer-name-or-path` in the training script:
+```bash
+# examples/llama_task/pretrain_llama_ptd_8p.sh
+--tokenizer-name-or-path ./dataset/llama/ \
+--tokenizer-not-use-fast \
+```
+2. Remove the `--vocab-file` and `--merge-file` arguments.
+3. Make sure `tokenizer_type` in `args_defaults` of `pretrain_llama.py` is set to `PretrainedFromHF`.
+
diff --git a/ascendspeed/arguments.py b/ascendspeed/arguments.py
index 89c1b8fad7a1ca8200d9554580b24f75140d6c7b..221c396d73d943dff242afeedeba2dfea240102f 100644
--- a/ascendspeed/arguments.py
+++ b/ascendspeed/arguments.py
@@ -783,6 +783,8 @@ def _add_data_args(parser):
                        help='What type of tokenizer to use.')
     group.add_argument("--tokenizer-name-or-path", type=str, default=None,
                        help="Name or path of the huggingface tokenizer.")
+    group.add_argument("--tokenizer-not-use-fast", action='store_false',
+                       help="Do not use the fast version of the HuggingFace tokenizer.")
     group.add_argument('--data-impl', type=str, default='infer',
                        choices=['lazy', 'cached', 'mmap', 'infer'],
                        help='Implementation of indexed datasets.')
diff --git a/ascendspeed/tokenizer/tokenizer.py b/ascendspeed/tokenizer/tokenizer.py
index 3787d039dfbb61e561d20a185a8ff4f6cdcd5e35..4adf97b2dee219968df2356dac2b34d3c23f830c 100644
--- a/ascendspeed/tokenizer/tokenizer.py
+++ b/ascendspeed/tokenizer/tokenizer.py
@@ -55,7 +55,11 @@ def build_tokenizer(args):
         if args.rank == 0:
             print(" vocab file is un-used. loading tokenizer from pre-trained model")
loading tokenizer from pre-trained model") - tokenizer = _AutoTokenizer(args.tokenizer_name_or_path, vocab_extra_ids=args.vocab_extra_ids) + tokenizer = _AutoTokenizer( + args.tokenizer_name_or_path, + vocab_extra_ids=args.vocab_extra_ids, + model_max_length=args.seq_length, + use_fast=args.tokenizer_not_use_fast) else: raise NotImplementedError('{} tokenizer is not ' 'implemented.'.format(args.tokenizer_type)) @@ -321,13 +325,16 @@ class _GPT2BPETokenizer(AbstractTokenizer): class _AutoTokenizer(AbstractTokenizer): """AutoTokenizer for Hf Pretrained model loading.""" - def __init__(self, tokenizer_name_or_path, vocab_extra_ids): + def __init__(self, tokenizer_name_or_path, vocab_extra_ids, model_max_length, use_fast): name = tokenizer_name_or_path super().__init__(name) hf_tokenizer_kwargs = {} if vocab_extra_ids > 0: # TODO @thomasw21 we might need to concatenate to a pre-existing list? hf_tokenizer_kwargs["additional_special_tokens"] = [f"" for _id in range(vocab_extra_ids)] + + hf_tokenizer_kwargs["model_max_length"] = model_max_length + hf_tokenizer_kwargs["use_fast"] = use_fast self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, **hf_tokenizer_kwargs) self.encoder = self.tokenizer.get_vocab() self.decoder = {v: k for k, v in self.encoder.items()} @@ -399,4 +406,4 @@ class _AutoTokenizer(AbstractTokenizer): def _check_token_candidate(candidate): if candidate is None: raise AttributeError("Token doesn't exist") - return candidate \ No newline at end of file + return candidate diff --git a/examples/llama_task/pretrain_llama_1p.sh b/examples/llama_task/pretrain_llama_1p.sh index 735d0c8dfbaa4367f521c68831553d3564928155..26487653767d22eb39822304679302e4ee21edf4 100644 --- a/examples/llama_task/pretrain_llama_1p.sh +++ b/examples/llama_task/pretrain_llama_1p.sh @@ -6,7 +6,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh RANK=0 WORLD_SIZE=1 -DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +DATA_PATH=./dataset/llama_text_document CHECKPOINT_PATH=./ckpt export LOCAL_RANK=0 @@ -26,8 +26,8 @@ python pretrain_llama.py \ --save $CHECKPOINT_PATH \ --load $CHECKPOINT_PATH \ --data-path $DATA_PATH \ - --vocab-file ./dataset/gpt2-vocab.json \ - --merge-file ./dataset/gpt2-merges.txt \ + --tokenizer-name-or-path ./dataset/llama/ \ + --tokenizer-not-use-fast \ --data-impl mmap \ --split 949,50,1 \ --distributed-backend nccl \ diff --git a/examples/llama_task/pretrain_llama_ptd_8p.sh b/examples/llama_task/pretrain_llama_ptd_8p.sh index 73efe04e20a081220f9182ec54d5eddc0e806348..f6bdedbd7deb554479f7f2601963b637a63ef358 100644 --- a/examples/llama_task/pretrain_llama_ptd_8p.sh +++ b/examples/llama_task/pretrain_llama_ptd_8p.sh @@ -11,7 +11,7 @@ NNODES=1 NODE_RANK=0 WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) -DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +DATA_PATH=./dataset/llama_text_document CHECKPOINT_PATH=./ckpt DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" @@ -34,8 +34,8 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --save $CHECKPOINT_PATH \ --load $CHECKPOINT_PATH \ --data-path $DATA_PATH \ - --vocab-file ./dataset/gpt2-vocab.json \ - --merge-file ./dataset/gpt2-merges.txt \ + --tokenizer-name-or-path ./dataset/llama/ \ + --tokenizer-not-use-fast \ --data-impl mmap \ --split 949,50,1 \ --distributed-backend nccl \ diff --git a/examples/llama_task/pretrain_llama_td_8p.sh b/examples/llama_task/pretrain_llama_td_8p.sh index 
index 9cfec293819fdb9ed2f3c8936dd73fc8a851003a..2a73a99637edbc877fc29045e8af9782c897ca2f 100644
--- a/examples/llama_task/pretrain_llama_td_8p.sh
+++ b/examples/llama_task/pretrain_llama_td_8p.sh
@@ -11,7 +11,7 @@ NNODES=1
 NODE_RANK=0
 WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
 
-DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence
+DATA_PATH=./dataset/llama_text_document
 CHECKPOINT_PATH=./ckpt
 
 DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
@@ -33,8 +33,8 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --save $CHECKPOINT_PATH \
        --load $CHECKPOINT_PATH \
        --data-path $DATA_PATH \
-       --vocab-file ./dataset/gpt2-vocab.json \
-       --merge-file ./dataset/gpt2-merges.txt \
+       --tokenizer-name-or-path ./dataset/llama/ \
+       --tokenizer-not-use-fast \
        --data-impl mmap \
        --split 949,50,1 \
        --distributed-backend nccl \
diff --git a/pretrain_llama.py b/pretrain_llama.py
index e060dcbff66ea1bc04fc0adacd1825ef3aa90d77..01021048e9c843a017d3c6ce691713870b9a25f9 100644
--- a/pretrain_llama.py
+++ b/pretrain_llama.py
@@ -213,5 +213,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 
 if __name__ == "__main__":
     pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
-             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
-             data_post_process=data_post_process)
\ No newline at end of file
+             args_defaults={'tokenizer_type': 'PretrainedFromHF'},
+             data_post_process=data_post_process)
diff --git a/requirements.txt b/requirements.txt
index 0e2c615a83e2d0aade025fff3c617ae9c4883b67..9e1e3a7b41dcd9f2a82265d97314dc96d758b7e8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,4 +9,5 @@ attrs
 expecttest
 pytest
 wrapt_timeout_decorator
-transformers
\ No newline at end of file
+transformers
+sentencepiece
\ No newline at end of file
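For reference, a minimal standalone sketch of what the `PretrainedFromHF` path above ends up doing: `build_tokenizer` now forwards `--tokenizer-name-or-path`, the sequence length, and the fast/slow choice straight to `AutoTokenizer.from_pretrained` (the slow LLaMA tokenizer is the reason `sentencepiece` is added to `requirements.txt`). The tokenizer directory `./dataset/llama/` and the sequence length of 2048 below are illustrative assumptions, not values fixed by this change:

```python
# Hypothetical standalone check, mirroring the kwargs the new _AutoTokenizer builds.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "./dataset/llama/",      # assumed location of the HF tokenizer files
    model_max_length=2048,   # the training code passes args.seq_length here
    use_fast=False,          # what passing --tokenizer-not-use-fast selects
)

ids = tokenizer("hello AscendSpeed")["input_ids"]
print(len(tokenizer.get_vocab()), ids[:8])
```

Note the flag semantics: because `--tokenizer-not-use-fast` is declared with `action='store_false'`, `args.tokenizer_not_use_fast` defaults to `True` (fast tokenizer) and flips to `False` only when the flag is passed, and that value is handed to `use_fast`.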