diff --git a/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/README.md b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fa6884f1cd8dbe1a153eabc794bf92912d06f271 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/README.md @@ -0,0 +1,167 @@ +# LayoutXLM + +- [概述](#ZH-CN_TOPIC_0000001172161501) + +- [环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + +- [模型推理精度](#ZH-CN_TOPIC_0000001172201573) + + ****** + + + +# 概述 + +LayoutXLM是一种用于多语言文档理解的多模态预训练模型,其是LayoutLMv2在多语种场景下的拓展,预训练数据包含53种语种文档。 + + +- 参考实现: + + ``` + url=https://github.com/microsoft/unilm/blob/master/layoutxlm/README.md + ``` + +## 输入输出数据 + +- 输入数据 + + | 输入数据 | 数据类型 | 大小 | 数据排布格式 | + |----------------|----------|---------------------------|--------| + | input_ids | int64 | batchsize x 512 | ND | + | bbox | int64 | batchsize x 512 x 4 | ND | + | image | int64 | batchsize x 3 x 224 x 224 | NCHW | + | attention_mask | int64 | batchsize x 512 | ND | + + +- 输出数据 + + | 输出数据 | 大小 | 数据类型 | 数据排布格式 | + | ------- |---------------------| -------- | ------------ | + | output | batchsize x 512 x 7 | FLOAT32 | ND | + + + +# 推理环境准备 + +- 该模型需要以下依赖 + + **表 1** 版本配套表 + +| 配套 | 版本 | +|-----------------------|-----------------| +| CANN | 6.3.RC2.alpha002 | +| Python | 3.9.0 | +| PyTorch | 2.0.1 | +| torchvision | 0.15.2 | +| datasets | 2.13.1 | +| transformers | 4.30.2 | +| seqeval | 1.2.2 | +| accelerate | 0.21.0 | +| detectron2 | 0.6 | +| Ascend-cann-torch-aie | - | +| Ascend-cann-aie | - | +| 芯片类型 | Ascend310P3 | + +# 快速上手 +## 安装CANN包 + + ``` + chmod +x Ascend-cann-toolkit_6.3.RC2.alpha002_linux-aarch64.run +./Ascend-cann-toolkit_6.3.RC2.alpha002_linux-aarch64.run --install + ``` +下载Ascend-cann-torch-aie和Ascend-cann-aie得到run包和压缩包 +## 安装Ascend-cann-aie + ``` + chmod +x Ascend-cann-aie_6.3.T200_linux-aarch64.run + ./Ascend-cann-aie_6.3.T200_linux-aarch64.run --install + cd Ascend-cann-aie
source set_env.sh + ``` +## 安装Ascend-cann-torch-aie + ``` + tar -zxvf Ascend-cann-torch-aie-6.3.T200-linux_aarch64.tar.gz + pip3 install torch-aie-6.3.T200-linux_aarch64.whl + ``` + +## 安装模型依赖 +``` +pip3 install torch==2.0.1 +pip3 install torchvision==0.15.2 +pip3 install datasets==2.13.1 +pip3 install transformers==4.30.2 +pip3 install seqeval==1.2.2 +pip3 install accelerate==0.21.0 +python -m pip install 'git+https://github.com/facebookresearch/detectron2.git' +``` + +## 训练模型 + +1. 获取作者模型源代码 + + ``` + git clone https://github.com/microsoft/unilm.git + ``` +2. 准备模型权重文件 + ``` + cd unilm + cd layoutlmft + mkdir microsoft + cd microsoft + git lfs install + git clone https://huggingface.co/microsoft/layoutxlm-base + ``` +3. 准备数据集 + ``` + cd .. + mkdir xfun + cd xfun + wget https://github.com/doc-analysis/XFUND/releases/download/v1.0/zh.train.json + wget https://github.com/doc-analysis/XFUND/releases/download/v1.0/zh.train.zip + wget https://github.com/doc-analysis/XFUND/releases/download/v1.0/zh.val.json + wget https://github.com/doc-analysis/XFUND/releases/download/v1.0/zh.val.zip + wget https://raw.githubusercontent.com/huggingface/datasets/1.6.2/metrics/seqeval/seqeval.py + + cd .. + git clone https://huggingface.co/xlm-roberta-base + ``` + +4. 代码适配 + +由于transformers版本变更原因,需要对一些代码进行适配 + +``` + 下载layoutXLM_change.patch + git apply layoutXLM_change.patch + ``` +5. 模型训练 +``` +pip install . +python examples/run_xfun_ser.py --model_name_or_path microsoft/layoutxlm-base \ + --output_dir ./test-ner \ + --do_train \ + --do_eval \ + --lang zh +``` +6.
替换已训练的权重文件 + + 将训练好的权重文件替换layoutxlm-base文件夹下的同名文件:pytorch_model.bin和sentencepiece.bpe.model + +## 模型推理 +执行推理脚本 +``` +python layout_xlm.py --model_name_or_path microsoft/layoutxlm-base \ + --output_dir ./test-ner \ + --do_eval \ + --lang zh +``` + +# 模型推理性能及精度 + +调用torch-aie推理计算,精度参考下列数据。 + +| 芯片型号 | Batch Size | 数据集 | 精度 | +|-------|------------|------|-------------------| +| 310P3 | 5 | XFUN | accuracy : 82.48% | + diff --git a/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/layoutXLM_change.patch b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/layoutXLM_change.patch new file mode 100644 index 0000000000000000000000000000000000000000..8de0246be67a47b86e87f03b9c6dc953272aa3b0 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/layoutXLM_change.patch @@ -0,0 +1,69 @@ +Subject: [PATCH] layoutXLM change +--- +Index: layoutlmft/layoutlmft/data/datasets/xfun.py +IDEA additional info: +Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP +<+>UTF-8 +=================================================================== +diff --git a/layoutlmft/layoutlmft/data/datasets/xfun.py b/layoutlmft/layoutlmft/data/datasets/xfun.py +--- a/layoutlmft/layoutlmft/data/datasets/xfun.py (revision f4695ed0244a275201fff00bee495f76670fbe70) ++++ b/layoutlmft/layoutlmft/data/datasets/xfun.py (date 1692760553752) +@@ -9,7 +9,7 @@ + from transformers import AutoTokenizer + + +-_URL = "https://github.com/doc-analysis/XFUN/releases/download/v1.0/" ++_URL = "../../../xfun/" + + _LANG = ["zh", "de", "es", "fr", "en", "it", "ja", "pt"] + logger = logging.getLogger(__name__) +Index: layoutlmft/examples/run_xfun_ser.py +IDEA additional info: +Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP +<+>UTF-8 +=================================================================== +diff --git a/layoutlmft/examples/run_xfun_ser.py b/layoutlmft/examples/run_xfun_ser.py +--- a/layoutlmft/examples/run_xfun_ser.py (revision f4695ed0244a275201fff00bee495f76670fbe70) ++++ 
b/layoutlmft/examples/run_xfun_ser.py (date 1692760657546) +@@ -189,7 +189,7 @@ + ) + + # Metrics +- metric = load_metric("seqeval") ++ metric = load_metric("../xfun/seqeval.py") + + def compute_metrics(p): + predictions, labels = p +Index: layoutlmft/layoutlmft/__init__.py +IDEA additional info: +Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP +<+>UTF-8 +=================================================================== +diff --git a/layoutlmft/layoutlmft/__init__.py b/layoutlmft/layoutlmft/__init__.py +--- a/layoutlmft/layoutlmft/__init__.py (revision f4695ed0244a275201fff00bee495f76670fbe70) ++++ b/layoutlmft/layoutlmft/__init__.py (date 1692759816428) +@@ -2,7 +2,8 @@ + + from transformers import CONFIG_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_NAMES_MAPPING, TOKENIZER_MAPPING + from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, BertConverter, XLMRobertaConverter +-from transformers.models.auto.modeling_auto import auto_class_factory ++from transformers.models.auto.modeling_auto import _BaseAutoModelClass, auto_class_update ++import types + + from .models.layoutlmv2 import ( + LayoutLMv2Config, +@@ -37,10 +38,8 @@ + [(LayoutLMv2Config, LayoutLMv2ForRelationExtraction), (LayoutXLMConfig, LayoutXLMForRelationExtraction)] + ) + +-AutoModelForTokenClassification = auto_class_factory( +- "AutoModelForTokenClassification", MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification" +-) ++cls = types.new_class("AutoModelForTokenClassification", (_BaseAutoModelClass,)) ++cls._model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING ++cls.__name__ = "AutoModelForTokenClassification" + +-AutoModelForRelationExtraction = auto_class_factory( +- "AutoModelForRelationExtraction", MODEL_FOR_RELATION_EXTRACTION_MAPPING, head_doc="relation extraction" +-) ++AutoModelForTokenClassification = auto_class_update(cls, head_doc="token classification") diff --git a/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/layout_xlm.py 
b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/layout_xlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecd1e82f8ec85309b6a9687fd6c09425c22e6b1c
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/layout_xlm.py
@@ -0,0 +1,207 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluate a fine-tuned LayoutXLM checkpoint on XFUN (SER) with torch-aie.
+
+The model is traced to TorchScript with one validation batch, compiled by
+torch-aie through the trainer, and scored with the seqeval metric.
+"""
+
+import os
+import sys
+import numpy as np
+
+import torch
+from transformers import (
+    AutoConfig,
+    AutoModelForTokenClassification,
+    AutoTokenizer,
+    HfArgumentParser,
+    PreTrainedTokenizerFast,
+    TrainingArguments,
+    set_seed,
+)
+from datasets import ClassLabel, load_dataset, load_metric
+
+import layoutlmft.data.datasets.xfun
+from layoutlmft.data import DataCollatorForKeyValueExtraction
+from layoutlmft.data.data_args import XFUNDataTrainingArguments
+from layoutlmft.models.model_args import ModelArguments
+from torch_aie_trainer import XfunSerTorchAieTrainer
+import torch_npu
+
+# Evaluation batch size; the compiled model uses the same static batch dimension.
+EVAL_BATCH_SIZE = 5
+# TorchScript file written by trace_ts_model() and compiled by the trainer.
+TS_MODEL_PATH = "layoutXLM_base.ts"
+# Label value marking positions that are excluded from the metric.
+IGNORE_INDEX = -100
+
+# Filled in by load_model(); compute_metrics() uses it to map ids to names.
+label_list = []
+
+
+def get_label_list(labels):
+    """Return the sorted unique labels found in an iterable of label sequences."""
+    unique_labels = set()
+    for label in labels:
+        unique_labels.update(label)
+    return sorted(unique_labels)
+
+
+def load_model():
+    """Parse CLI arguments and load the dataset, config, tokenizer and model.
+
+    Also fills the module-level ``label_list``.
+
+    Returns:
+        Tuple ``(model, tokenizer, model_args, data_args, training_args, datasets)``.
+    """
+    parser = HfArgumentParser((ModelArguments, XFUNDataTrainingArguments, TrainingArguments))
+    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+    # Overrides any CLI value so the dataloader matches the compiled batch size.
+    training_args.per_device_eval_batch_size = EVAL_BATCH_SIZE
+    datasets = load_dataset(
+        os.path.abspath(layoutlmft.data.datasets.xfun.__file__),
+        f"xfun.{data_args.lang}"
+    )
+    features = datasets["validation"].features
+    label_column_name = "labels"
+
+    global label_list
+    if isinstance(features[label_column_name].feature, ClassLabel):
+        label_list = features[label_column_name].feature.names
+    else:
+        label_list = get_label_list(datasets["train"][label_column_name])
+    num_labels = len(label_list)
+
+    config = AutoConfig.from_pretrained(
+        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+        num_labels=num_labels,
+        finetuning_task=data_args.task_name,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        use_fast=True
+    )
+    model = AutoModelForTokenClassification.from_pretrained(
+        model_args.model_name_or_path,
+        config=config,
+    )
+    print("load model done.")
+    return model, tokenizer, model_args, data_args, training_args, datasets
+
+
+def compute_metrics(p):
+    """Compute seqeval precision, recall, F1 and accuracy.
+
+    Args:
+        p: Pair ``(predictions, labels)``. Predictions are reduced with argmax
+            over axis 2 before being mapped to label names.
+
+    Returns:
+        Dict with the keys ``precision``, ``recall``, ``f1`` and ``accuracy``.
+    """
+    metric = load_metric("xfun/seqeval.py")
+    predictions, labels = p
+    predictions = np.argmax(predictions, axis=2)
+
+    # Skip ignored positions and translate the remaining ids to label names.
+    true_predictions = [
+        [label_list[pred] for pred, gold in zip(pred_row, gold_row) if gold != IGNORE_INDEX]
+        for pred_row, gold_row in zip(predictions, labels)
+    ]
+    true_labels = [
+        [label_list[gold] for _, gold in zip(pred_row, gold_row) if gold != IGNORE_INDEX]
+        for pred_row, gold_row in zip(predictions, labels)
+    ]
+
+    results = metric.compute(predictions=true_predictions, references=true_labels)
+    return {
+        "precision": results["overall_precision"],
+        "recall": results["overall_recall"],
+        "f1": results["overall_f1"],
+        "accuracy": results["overall_accuracy"],
+    }
+
+
+def trace_ts_model(model, tokenizer, model_args, data_args, training_args, datasets):
+    """Build the evaluation trainer and save a TorchScript trace of ``model``.
+
+    The trace is recorded with the first validation batch and written to
+    ``TS_MODEL_PATH``.
+
+    Returns:
+        The ``XfunSerTorchAieTrainer`` used for the evaluation.
+
+    Raises:
+        RuntimeError: If the validation dataloader yields no batch.
+    """
+    print("start trace model")
+    padding = "max_length" if data_args.pad_to_max_length else False
+    data_collator = DataCollatorForKeyValueExtraction(
+        tokenizer,
+        pad_to_multiple_of=8 if training_args.fp16 else None,
+        padding=padding,
+        max_length=512,
+    )
+    eval_dataset = datasets["validation"]
+    if data_args.max_val_samples is not None:
+        eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+    trainer = XfunSerTorchAieTrainer(
+        model=model,
+        args=training_args,
+        # The script only evaluates, so the dataset is needed even without --do_eval.
+        eval_dataset=eval_dataset,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        compute_metrics=compute_metrics,
+    )
+    # One real batch is enough to record the graph.
+    first_batch = next(iter(trainer.get_eval_dataloader()), None)
+    if first_batch is None:
+        raise RuntimeError("validation dataloader is empty, cannot trace the model")
+    example = trainer.prepare_inputs(first_batch)
+    model.eval()
+    with torch.no_grad():
+        trace_model = torch.jit.trace(
+            model,
+            (
+                example["input_ids"],
+                example["bbox"],
+                example["image"].tensor,
+                example["attention_mask"],
+            ),
+            strict=False,
+            check_trace=False,
+        )
+    trace_model.save(TS_MODEL_PATH)
+    print("trace model done")
+    return trainer
+
+
+def main():
+    """Trace, compile and evaluate the model on NPU device 0."""
+    torch_npu.set_device(0)
+    model, tokenizer, model_args, data_args, training_args, datasets = load_model()
+    trainer = trace_ts_model(model, tokenizer, model_args, data_args, training_args, datasets)
+    trainer.prepare_aie_model(TS_MODEL_PATH)
+    metrics = trainer.evaluate()
+    for key, value in metrics.items():
+        print(key, " : ", value)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/torch_aie_trainer.py b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/torch_aie_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a870eaddced45d32eea478a3df8a0f4f97ee9bd
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/torch_aie_trainer.py
@@ -0,0 +1,158 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Trainer that runs LayoutXLM evaluation through a torch-aie compiled module."""
+
+import os
+import collections
+from typing import Any, Dict, List, Optional, Tuple, Union
+from collections.abc import Mapping
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader, Dataset
+from layoutlmft.trainers.funsd_trainer import FunsdTrainer
+import torch_aie
+
+
+class XfunSerTorchAieTrainer(FunsdTrainer):
+    """``FunsdTrainer`` whose prediction step calls a torch-aie compiled model.
+
+    ``prepare_aie_model`` has to be called before ``evaluate``; until then
+    ``self.result`` is ``None``.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        # Compiled torch-aie module, set by prepare_aie_model().
+        self.result = None
+
+    def prepare_aie_model(self, ts_model_path="layoutXLM_base.ts", batch_size=None,
+                          aie_model_path="aie_res.pt"):
+        """Compile a saved TorchScript trace with torch-aie and keep the result.
+
+        Args:
+            ts_model_path: TorchScript file to load.
+            batch_size: Static batch dimension to compile for. Defaults to
+                ``self.args.eval_batch_size``, the size the evaluation
+                dataloader uses.
+            aie_model_path: File the compiled module is saved to.
+        """
+        if batch_size is None:
+            batch_size = self.args.eval_batch_size
+        print("start compile")
+        trace_model = torch.jit.load(ts_model_path)
+        # Order matches the positional inputs used for tracing:
+        # input_ids, bbox, image, attention_mask.
+        # NOTE(review): the sequence length (512) and image size (224) are still
+        # fixed here; a trailing batch smaller than batch_size presumably does
+        # not match the compiled shapes - confirm against torch-aie.
+        compile_input = [
+            torch_aie.Input((batch_size, 512), dtype=torch_aie.dtype.INT64),
+            torch_aie.Input((batch_size, 512, 4), dtype=torch_aie.dtype.INT64, tensor_domain=([1, 2])),
+            torch_aie.Input((batch_size, 3, 224, 224), dtype=torch_aie.dtype.INT64),
+            torch_aie.Input((batch_size, 512), dtype=torch_aie.dtype.INT64),
+        ]
+        torch_aie.set_device(0)
+        self.result = torch_aie.compile(trace_model, inputs=compile_input)
+        torch.jit.save(self.result, aie_model_path)
+        print("compile success")
+
+    def nested_detach(self, tensors):
+        """Detach ``tensors``, recursing into lists, tuples and mappings."""
+        if isinstance(tensors, (list, tuple)):
+            return type(tensors)(self.nested_detach(t) for t in tensors)
+        elif isinstance(tensors, Mapping):
+            return type(tensors)({k: self.nested_detach(t) for k, t in tensors.items()})
+        return tensors.detach()
+
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """Run one evaluation batch through the compiled model.
+
+        ``model`` and ``prediction_loss_only`` are part of the signature but are
+        not used here, and no loss is computed.
+
+        Returns:
+            ``(None, logits, labels)``; ``labels`` is ``None`` when the batch has
+            no labels.
+
+        Raises:
+            RuntimeError: If ``prepare_aie_model`` has not been called yet.
+        """
+        if self.result is None:
+            raise RuntimeError("prepare_aie_model() must be called before evaluation")
+
+        has_labels = bool(self.label_names) and all(inputs.get(k) is not None for k in self.label_names)
+        return_loss = inputs.get("return_loss", None)
+        if return_loss is None:
+            return_loss = self.can_return_loss
+        loss_without_labels = not self.label_names and bool(return_loss)
+
+        inputs = self.prepare_inputs(inputs)
+        if ignore_keys is None:
+            if hasattr(self.model, "config"):
+                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+            else:
+                ignore_keys = []
+
+        if has_labels or loss_without_labels:
+            labels = self.nested_detach(tuple(inputs.get(name) for name in self.label_names))
+            if len(labels) == 1:
+                labels = labels[0]
+        else:
+            labels = None
+
+        with torch.no_grad():
+            with self.compute_loss_context_manager():
+                # NOTE(review): .int() yields int32 while prepare_aie_model declares
+                # INT64 inputs - confirm torch-aie converts between the two.
+                outputs = self.result.forward(
+                    inputs["input_ids"].int().to("npu"),
+                    inputs["bbox"].int().to("npu"),
+                    inputs["image"].tensor.int().to("npu"),
+                    inputs["attention_mask"].int().to("npu"),
+                )
+        if isinstance(outputs, dict):
+            logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+        else:
+            logits = outputs
+
+        logits = self.nested_detach(logits)
+        # Only unwrap containers: len() of a bare tensor is its first dimension,
+        # so a batch of one would otherwise lose its batch axis.
+        if isinstance(logits, (list, tuple)) and len(logits) == 1:
+            logits = logits[0]
+
+        return (None, logits, labels)
+
+    def prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
+        """Move the values of ``inputs`` to ``self.args.device`` in place.
+
+        Only values exposing both ``to`` and ``device`` are moved. When
+        ``past_index`` is set and a past state is cached, it is added as ``mems``.
+        """
+        for k, v in inputs.items():
+            if hasattr(v, "to") and hasattr(v, "device"):
+                inputs[k] = v.to(self.args.device)
+
+        if self.args.past_index >= 0 and self._past is not None:
+            inputs["mems"] = self._past
+
+        return inputs