From 4efa330eaac89efceb6725368b409333935f25f3 Mon Sep 17 00:00:00 2001
From: zouying
Date: Thu, 14 Dec 2023 21:44:54 +0800
Subject: [PATCH] bert large ner

---
 .../built-in/nlp/Bert_Large_NER/README.md   | 120 +++++------
 .../nlp/Bert_Large_NER/patchfile.patch      | 187 ++++++++++++++++++
 .../nlp/Bert_Large_NER/requirements.txt     |   7 +-
 .../nlp/Bert_Large_NER/run_torch_aie.py     |   4 +-
 4 files changed, 243 insertions(+), 75 deletions(-)
 create mode 100644 AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/patchfile.patch

diff --git a/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/README.md b/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/README.md
index 96f5626400..f50bf1aa0b 100644
--- a/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/README.md
+++ b/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/README.md
@@ -15,14 +15,22 @@
 Specifically, this model is a *bert-large-cased* model fine-tuned on the English portion of the standard [CoNLL-2003 named entity recognition](https://www.aclweb.org/anthology/W03-0419.pdf) dataset. For a smaller BERT model fine-tuned on the same dataset, the [**bert-base-NER**](https://huggingface.co/dslim/bert-base-NER/) version is also available.
 
-* Reference implementation:
+* Model weights:
 
   ```
   url = https://huggingface.co/dslim/bert-large-NER
   commit_id = 95c62bc0d4109bd97d0578e5ff482e6b84c2b8b9
   model_name = bert-large-NER
   ```
 
+* Reference implementation:
+
+  ```
+  git clone https://github.com/huggingface/transformers
+  cd transformers
+  git checkout -b v4.24.0 v4.24.0
+  ```
+
 ## 1.1. Input and Output Data
 
 - Input data
@@ -65,40 +73,25 @@
 
 ## 3.1. Obtain the Model
 
-1. Obtain the open-source model.
-   1. Download the model weights and configuration files:
-      ```
-      git clone https://huggingface.co/dslim/bert-large-NER
-      ```
-      The following files are obtained:
+### 3.1.1. Obtain the open-source model
+   Download the model weights and configuration files from https://huggingface.co/dslim/bert-large-NER/tree/95c62bc0d4109bd97d0578e5ff482e6b84c2b8b9
+
+   The file structure is as follows:
+
    ```
    bert-large-NER/
+   ├── README.md
    ├── config.json
-   ├── dslim_bert-large-NER #U00b7 Hugging Face_files
-   │   ├── 1655075923870-5e7565183d77a72421292d00.png
-   │   ├── analytics.js.#U4e0b#U8f7d
-   │   ├── css2
-   │   ├── css2(1)
-   │   ├── huggingface_logo-noborder.svg
-   │   ├── inner.html
-   │   ├── js
-   │   ├── katex.min.css
-   │   ├── m-outer-27c67c0d52761104439bb051c7856ab1.html
-   │   ├── m-outer-6576085ca35ee42f2f484cda6763e4aa.js.#U4e0b#U8f7d
-   │   ├── out-4.5.43.js.#U4e0b#U8f7d
-   │   ├── saved_resource
-   │   ├── script.js.#U4e0b#U8f7d
-   │   └── style.css
-   ├── dslim_bert-large-NER #U00b7 Hugging Face.html
+   ├── flax_model.msgpack
+   ├── gitattributes
    ├── pytorch_model.bin
-   ├── README.md
    ├── special_tokens_map.json
+   ├── tf_model.h5
    ├── tokenizer_config.json
    └── vocab.txt
    ```
-   Only `pytorch_model.bin` needs to be downloaded.
 
-2. Install dependencies.
+### 3.1.2. Install dependencies
 
    ```
    pip install -r requirements.txt
@@ -122,59 +115,46 @@
 ├── test.txt
 └── valid.txt
 ```
+ Modify the dataset-loading script so that it reads the offline dataset (at line 193, point `downloaded_file` to the local copy):
+```commandline
+vim /root/.cache/huggingface/modules/datasets_modules/datasets/conll2003/{95c62bc0d4109bd97d0578e5ff482e6b84c2b8b9...}/conll2003.py
+193
+downloaded_file="path/to/conll2003"
+```
 
 ## 3.3. Model Inference
 
-1. Export the TorchScript model:
-   ```
-   python3 export_trace_model.py
-   ```
-   This produces the exported TorchScript model `bert_large_ner.pt`.
+### 3.3.1. Run inference
+```
+git apply patchfile.patch
+cd transformers
+pip install .
+cd examples/pytorch/token-classification
+python run_ner.py --model_name_or_path /path/to/bert-large-NER --dataset_name conll2003 --output_dir /tmp/test-ner --do_predict --overwrite_output_dir --no_cuda --jit_mode_eval --pad_to_max_length --max_seq_length 512 --torch_aie_enable --dataloader_drop_last --per_device_eval_batch_size 1
 
-2. Modify `bert_large_ner.pt`
-   > [Note] In the `attention_mask` computation of the `bert_large_ner` model, the second input of the `mul` operator, `CONSTANTS.c0`, is initialized to the minimum value of the `float` data type.
-   > That value is below the smallest value representable in `fp16`, so it underflows; without the modification below, model accuracy drops (measured acc = 76%).
-   > After the modification below, accuracy reaches 87.43%, which still trails the om inference version; the gap is under investigation.
-
-   1. Unzip `bert_large_ner.pt`:
-      ```
-      unzip -q bert_large_ner.pt
-      ```
-      This produces the `bert_large_ner` folder.
-   2. Modify line 36 of `bert_large_ner/code/__torch__/transformers/models/bert/modeling_bert.py`.
-      Before:
-      ```
-      attention_mask0 = torch.mul(torch.rsub(_4, 1.), CONSTANTS.c0)
-      ```
-      After:
-      ```
-      attention_mask0 = torch.mul(torch.rsub(_4, 1.), -1000.0)
-      ```
-      Save the change.
-   3. Re-zip to obtain the modified `bert_large_ner.pt`:
-      ```
-      zip -r -q bert_large_ner.pt bert_large_ner/
-      ```
-
-3. Run inference:
-   ```
-   python3 run_torch_aie.py
-   ```
+# Additional option: max_predict_samples limits the number of samples used for inference
+```
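+
+The following is a minimal sketch of what the patched `Trainer` does internally when `--torch_aie_enable` is set: the traced TorchScript module is compiled with `torch_aie.compile` and then called with `input_ids`, `attention_mask` and `token_type_ids` on the NPU. The shapes, FP16 precision policy and `Ascend310P3` SoC string mirror `patchfile.patch`; the `bert_large_ner.ts` file name and the all-ones dummy batch are placeholders for illustration only.
+
+```python
+import torch
+import torch_aie
+
+torch_aie.set_device(0)
+
+# Placeholder path: any TorchScript trace of bert-large-NER taking the three BERT inputs works here.
+ts_model = torch.jit.load("bert_large_ner.ts")
+
+batch_size, max_seq_length = 1, 512
+input_info = [
+    torch_aie.Input((batch_size, max_seq_length), dtype=torch.int64),  # input_ids
+    torch_aie.Input((batch_size, max_seq_length), dtype=torch.int64),  # attention_mask
+    torch_aie.Input((batch_size, max_seq_length), dtype=torch.int64),  # token_type_ids
+]
+compiled_model = torch_aie.compile(
+    ts_model,
+    inputs=input_info,
+    precision_policy=torch_aie.PrecisionPolicy.FP16,
+    truncate_long_and_double=True,
+    soc_version="Ascend310P3",
+)
+
+# Dummy batch; real inputs come from the tokenized CoNLL-2003 dataloader.
+input_ids = torch.ones(batch_size, max_seq_length, dtype=torch.int64).to("npu:0")
+attention_mask = torch.ones(batch_size, max_seq_length, dtype=torch.int64).to("npu:0")
+token_type_ids = torch.zeros(batch_size, max_seq_length, dtype=torch.int64).to("npu:0")
+outputs = compiled_model(input_ids, attention_mask, token_type_ids)
+```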
 
 # 4. Inference Performance & Accuracy
 
-1. Performance comparison
+## 4.1. Performance comparison
 
-| Batch Size | om inference (ONNX graph rewritten to use the fused BERT big-kernel operator) | torch-aie inference | torch-aie/om |
-| :--------: | :----------------------------------------------------------------------------: | :-----------------: | :----------: |
-| 1          | 63.3062 it/s                                                                    | 39.69 it/s          | 0.6269       |
+| Batch Size | om inference (ONNX graph rewritten to use the fused BERT big-kernel operator) | torch-aie inference | torch-aie/om |
+| :--------: | :----------------------------------------------------------------------------: | :-----------------: | :----------: |
+| 1          | 63.3062 it/s                                                                    | 43.5007 it/s        | 0.6871       |
+| 4          | 70.6858 it/s                                                                    | 42.2668 it/s        | 0.5980       |
+| 8          | 74.5971 it/s                                                                    | 37.7312 it/s        | 0.5058       |
+| 16         | 73.7290 it/s                                                                    | 35.8528 it/s        | 0.4863       |
+| 32         | 73.6084 it/s                                                                    | 35.1328 it/s        | 0.4773       |
+| 64         | 71.1308 it/s                                                                    | 36.2688 it/s        | 0.5099       |
 
 > There is still room for performance improvement; the BERT optimization pass has yet to be enabled through AIE.
 
-1. Accuracy comparison
+## 4.2. Accuracy comparison
 
 | Model          | Batch Size | om inference | torch-aie inference |
 | :------------: | :--------: | :----------: | :-----------------: |
-| bert_large_NER | 1          | 90.74%       | 87.43%              |
-
-> The original model tested with the PyTorch CPU framework reaches 87.43% accuracy, which matches the accuracy of the PT plugin.
+| bert_large_NER | 1          | 90.74%       | 90.90%              |
+| bert_large_NER | 4          | 90.74%       | 90.91%              |
+| bert_large_NER | 8          | 90.74%       | 90.92%              |
+| bert_large_NER | 16         | 90.74%       | 90.89%              |
+| bert_large_NER | 32         | 90.74%       | 90.88%              |
+| bert_large_NER | 64         | 90.74%       | 90.85%              |
diff --git a/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/patchfile.patch b/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/patchfile.patch
new file mode 100644
index 0000000000..2417a1babf
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/patchfile.patch
@@ -0,0 +1,187 @@
+diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
+index 152e52fe2..67b7a6bb7 100644
+--- a/src/transformers/modeling_utils.py
++++ b/src/transformers/modeling_utils.py
+@@ -788,7 +788,7 @@ class ModuleUtilsMixin:
+         # Since we are adding it to the raw scores before the softmax, this is
+         # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility +- extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min ++ extended_attention_mask = (1.0 - extended_attention_mask) * (-1000.0) + return extended_attention_mask + + def get_head_mask( +diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py +index ac166f002..7c869f002 100755 +--- a/src/transformers/trainer.py ++++ b/src/transformers/trainer.py +@@ -53,6 +53,7 @@ from .integrations import ( # isort: split + + import numpy as np + import torch ++import torch_aie + import torch.distributed as dist + from packaging import version + from torch import nn +@@ -204,7 +205,8 @@ TRAINER_STATE_NAME = "trainer_state.json" + OPTIMIZER_NAME = "optimizer.pt" + SCHEDULER_NAME = "scheduler.pt" + SCALER_NAME = "scaler.pt" +- ++NPU_DEVICE = 'npu:0' ++torch_aie.set_device(int(NPU_DEVICE.split(':')[-1])) + + class Trainer: + """ +@@ -327,6 +329,10 @@ class Trainer: + # force device and distributed setup init explicitly + args._setup_devices + ++ # calculate torch aie inference time ++ self.count = 0 ++ self.inference_times = [] ++ + if model is None: + if model_init is not None: + self.model_init = model_init +@@ -1257,17 +1263,24 @@ class Trainer: + example_tensor = torch.ones_like(example_batch[key]) + jit_inputs.append(example_tensor) + jit_inputs = tuple(jit_inputs) ++ is_jit_model = False + try: + jit_model = model.eval() + with ContextManagers([self.autocast_smart_context_manager(), torch.no_grad()]): +- jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False) ++ jit_model = torch.jit.trace(jit_model, jit_inputs[:3], strict=False) + jit_model = torch.jit.freeze(jit_model) +- jit_model(**example_batch) ++ example_batch.pop("labels") ++ input_ids = example_batch["input_ids"] ++ attention_mask = example_batch["attention_mask"] ++ token_type_ids = example_batch["token_type_ids"] ++ jit_model(input_ids, attention_mask, token_type_ids) + model = jit_model ++ is_jit_model = True ++ logger.info("torch jit model success.") + except (RuntimeError, TypeError) as e: + logger.warning(f"failed to use PyTorch jit mode due to: {e}.") + +- return model ++ return model, is_jit_model + + def ipex_optimize_model(self, model, training=False, dtype=torch.float32): + if not is_ipex_available(): +@@ -1291,13 +1304,37 @@ class Trainer: + + return model + ++ def torch_aie_compile(self, dataloader, ts_model): ++ max_seq_length = 512 ++ input_info = [ ++ torch_aie.Input((dataloader.batch_size, max_seq_length), dtype=torch.int64), ++ torch_aie.Input((dataloader.batch_size, max_seq_length), dtype=torch.int64), ++ torch_aie.Input((dataloader.batch_size, max_seq_length), dtype=torch.int64) ++ ] ++ model = torch_aie.compile( ++ ts_model, ++ inputs=input_info, ++ precision_policy=torch_aie.PrecisionPolicy.FP16, ++ truncate_long_and_double=True, ++ require_full_compilation=False, ++ allow_tensor_replace_int=True, ++ torch_executed_ops=[], ++ soc_version="Ascend310P3", ++ optimization_level=0 ++ ) ++ model.eval() ++ return model ++ + def _wrap_model(self, model, training=True, dataloader=None): + if self.args.use_ipex: + dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 + model = self.ipex_optimize_model(model, training, dtype=dtype) + + if self.args.jit_mode_eval: +- model = self.torch_jit_model_eval(model, dataloader, training) ++ model, is_jit_model = self.torch_jit_model_eval(model, dataloader, training) ++ if is_jit_model and self.args.torch_aie_enable: ++ torch_aie.set_device(0) ++ 
model = self.torch_aie_compile(dataloader, model) + + if is_sagemaker_mp_enabled(): + # Wrapping the base model twice in a DistributedModel will raise an error. +@@ -2994,6 +3031,7 @@ class Trainer: + else nested_concat(inputs_host, inputs_decode, padding_index=-100) + ) + if logits is not None: ++ logits = logits.view(batch_size, 512, 9) + logits = self._pad_across_processes(logits) + logits = self._nested_gather(logits) + if self.preprocess_logits_for_metrics is not None: +@@ -3024,6 +3062,11 @@ class Trainer: + + # Set back to None to begin a new accumulation + losses_host, preds_host, inputs_host, labels_host = None, None, None, None ++ if args.torch_aie_enable: ++ average_inference_time = np.mean(self.inference_times) ++ logger.info(f"Pure model inference performance per sample = {average_inference_time * 1000} ms") ++ logger.info(f"Performance= {batch_size * 1.0 / average_inference_time} it/s") ++ logger.info(f"inference time count/total count: {len(self.inference_times)} / {self.count}") + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of the evaluation loop +@@ -3211,6 +3254,24 @@ class Trainer: + else: + logits_mb = raw_outputs + logits = smp_nested_concat(logits_mb) ++ elif self.args.torch_aie_enable and has_labels: ++ loss = None ++ logger.info("torch aie model forward... ") ++ inputs_npu1 = inputs["input_ids"].to(NPU_DEVICE) ++ inputs_npu2 = inputs["attention_mask"].to(NPU_DEVICE) ++ inputs_npu3 = inputs["token_type_ids"].to(NPU_DEVICE) ++ stream = torch_aie.npu.Stream(NPU_DEVICE) ++ with torch_aie.npu.stream(stream): ++ inf_start = time.time() ++ outputs = model(inputs_npu1, inputs_npu2, inputs_npu3) ++ stream.synchronize() ++ inf_end = time.time() ++ inf_time = inf_end - inf_start ++ self.count += 1 ++ if self.count > 5: ++ self.inference_times.append(inf_time) ++ logits = outputs["logits"] ++ logger.info(f"step: {self.count} infer time: {inf_time}s") + else: + if has_labels: + with self.compute_loss_context_manager(): +@@ -3506,11 +3567,11 @@ class Trainer: + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device +- if not self.is_in_train: +- if args.fp16_full_eval: +- model = model.to(dtype=torch.float16, device=args.device) +- elif args.bf16_full_eval: +- model = model.to(dtype=torch.bfloat16, device=args.device) ++ # if not self.is_in_train: ++ # if args.fp16_full_eval: ++ # model = model.to(dtype=torch.float16, device=args.device) ++ # elif args.bf16_full_eval: ++ # model = model.to(dtype=torch.bfloat16, device=args.device) + + batch_size = dataloader.batch_size + num_examples = self.num_examples(dataloader) +diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py +index fc5ace752..19df2cfce 100644 +--- a/src/transformers/training_args.py ++++ b/src/transformers/training_args.py +@@ -995,6 +995,7 @@ class TrainingArguments: + "help": "Overrides the default timeout for distributed training (value should be given in seconds)." + }, + ) ++ torch_aie_enable: bool = field(default=True, metadata={"help": "Whether or not to use torch aie."}) + + def __post_init__(self): + # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then). 
diff --git a/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/requirements.txt b/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/requirements.txt index b708412456..bd26a20eec 100644 --- a/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/requirements.txt +++ b/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/requirements.txt @@ -1,5 +1,6 @@ -numpy==1.26.1 +numpy==1.26.2 torch==2.0.1+cpu torchvision==0.15.2+cpu -transformers==4.34.0 -tqdm==4.64.0 \ No newline at end of file +datasets==2.15.0 +tqdm==4.66.1 +decorator==4.3.0 \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/run_torch_aie.py b/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/run_torch_aie.py index 14394ec893..c7f487bc39 100644 --- a/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/run_torch_aie.py +++ b/AscendIE/TorchAIE/built-in/nlp/Bert_Large_NER/run_torch_aie.py @@ -105,12 +105,12 @@ def run_test(torchaie_model, test_dataset): tag = test_dataset['tags'][i] input_ids_npu = torch.tensor(np.array(input_id)).to(NPU_DEVICE) - token_type_ids_npu = torch.tensor(np.array(token_type_id)).to(NPU_DEVICE) attention_mask_npu = torch.tensor(np.array(attention_mask)).to(NPU_DEVICE) + token_type_ids_npu = torch.tensor(np.array(token_type_id)).to(NPU_DEVICE) stream = torch_aie.npu.Stream(NPU_DEVICE) with torch_aie.npu.stream(stream): inf_start = time.time() - result = torchaie_model(input_ids_npu, token_type_ids_npu, attention_mask_npu) + result = torchaie_model(input_ids_npu, attention_mask_npu, token_type_ids_npu) stream.synchronize() inf_end = time.time() inf_time = inf_end - inf_start -- Gitee
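
Background on the `(-1000.0)` substitution in `patchfile.patch`: `torch.finfo(torch.float32).min` lies far below the most negative `fp16` value, so once the model runs under `PrecisionPolicy.FP16` the additive attention-mask bias underflows to `-inf` and degrades accuracy, which is the same issue the earlier TorchScript-editing workflow worked around. Any large negative constant that still fits in `fp16` serves the purpose; the sketch below only demonstrates the underflow and is not part of the patch.

```python
import torch

# float32 minimum (about -3.4e38) is far below the most negative fp16 value
# (about -65504), so the cast becomes -inf:
mask_bias = torch.tensor(torch.finfo(torch.float32).min)
print(mask_bias.to(torch.float16))              # tensor(-inf, dtype=torch.float16)

# -1000.0 is representable in fp16 and still drives masked logits to
# near-zero probability after the softmax:
print(torch.tensor(-1000.0).to(torch.float16))  # tensor(-1000., dtype=torch.float16)
```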