diff --git a/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/README.md b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fa6884f1cd8dbe1a153eabc794bf92912d06f271
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/README.md
@@ -0,0 +1,167 @@
+# LayoutXLM
+
+- [概述](#概述)
+
+- [推理环境准备](#推理环境准备)
+
+- [快速上手](#快速上手)
+
+- [模型推理性能及精度](#模型推理性能及精度)
+
+ ******
+
+
+
+# 概述
+
+LayoutXLM是一种用于多语言文档理解的多模态预训练模型,其是LayoutLMv2在多语种场景下的拓展,预训练数据包含53种语种文档。
+
+
+- 参考实现:
+
+ ```
+ url=https://github.com/microsoft/unilm/blob/master/layoutxlm/README.md
+ ```
+
+## 输入输出数据
+
+- 输入数据
+
+ | 输入数据 | 数据类型 | 大小 | 数据排布格式 |
+ |----------------|----------|---------------------------|--------|
+ | input_ids | int64 | batchsize x 512 | ND |
+ | bbox | int64 | batchsize x 512 x 4 | ND |
+ | image | int64 | batchsize x 3 x 224 x 224 | NCHW |
+ | attention_mask | int64 | batchsize x 512 | ND |
+
+
+- 输出数据
+
+ | 输出数据 | 大小 | 数据类型 | 数据排布格式 |
+ | ------- |---------------------| -------- | ------------ |
+  | output | batchsize x 512 x 7 | FLOAT32 | ND |
+
+
+
+# 推理环境准备
+
+- 该模型需要以下依赖
+
+ **表 1** 版本配套表
+
+| 配套 | 版本 |
+|-----------------------|-----------------|
+| CANN | 6.3.RC2.alpha002 |
+| Python | 3.9.0 |
+| PyTorch | 2.0.1 |
+| torchvision | 0.15.2 |
+| datasets | 2.13.1 |
+| transformers | 4.30.2 |
+| seqeval | 1.2.2 |
+| accelerate | 0.21.0 |
+| detectron2 | 0.6 |
+| Ascend-cann-torch-aie | - |
+| Ascend-cann-aie | - |
+| 芯片类型 | Ascend310P3 |
+
+# 快速上手
+## 安装CANN包
+
+ ```
+ chmod +x Ascend-cann-toolkit_6.3.RC2.alpha002_linux-aarch64.run
+./Ascend-cann-toolkit_6.3.RC2.alpha002_linux-aarch64.run --install
+ ```
+下载Ascend-cann-torch-aie和Ascend-cann-aie得到run包和压缩包
+## 安装Ascend-cann-aie
+ ```
+ chmod +x Ascend-cann-aie_6.3.T200_linux-aarch64.run
+ ./Ascend-cann-aie_6.3.T200_linux-aarch64.run --install
+ cd Ascend-cann-aie
+ source set_env.sh
+ ```
+## 安装Ascend-cann-torch-aie
+ ```
+ tar -zxvf Ascend-cann-torch-aie-6.3.T200-linux_aarch64.tar.gz
+ pip3 install torch-aie-6.3.T200-linux_aarch64.whl
+ ```
+
+## 安装模型依赖
+```
+pip3 install torch==2.0.1
+pip3 install torchvision==0.15.2
+pip3 install datasets==2.13.1
+pip3 install transformers==4.30.2
+pip3 install seqeval==1.2.2
+pip3 install accelerate==0.21.0
+python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
+```
+
+## 训练模型
+
+1. 获取作者模型源代码
+
+ ```
+ git clone https://github.com/microsoft/unilm.git
+ ```
+2. 准备模型权重文件
+ ```
+ cd unilm
+ cd layoutlmft
+ mkdir microsoft
+ cd microsoft
+ git lfs install
+ git clone https://huggingface.co/microsoft/layoutxlm-base
+ ```
+3. 准备数据集
+ ```
+ cd ..
+ mkdir xfun
+ cd xfun
+ wget https://github.com/doc-analysis/XFUND/releases/download/v1.0/zh.train.json
+ wget https://github.com/doc-analysis/XFUND/releases/download/v1.0/zh.train.zip
+ wget https://github.com/doc-analysis/XFUND/releases/download/v1.0/zh.val.json
+ wget https://github.com/doc-analysis/XFUND/releases/download/v1.0/zh.val.zip
+ wget https://raw.githubusercontent.com/huggingface/datasets/1.6.2/metrics/seqeval/seqeval.py
+
+ cd ..
+ git clone https://huggingface.co/xlm-roberta-base
+ ```
+
+4. 代码适配
+
+由于transformers版本变更,需要对部分代码进行适配
+
+```
+ # 先将本目录提供的layoutXLM_change.patch拷贝到当前目录
+ git apply layoutXLM_change.patch
+ ```
+5. 模型训练
+```
+pip install .
+python examples/run_xfun_ser.py --model_name_or_path microsoft/layoutxlm-base \
+ --output_dir ./test-ner \
+ --do_train \
+ --do_eval \
+ --lang zh
+```
+6. 替换已训练的权重文件
+
+ 将训练好的权重文件替换layoutxlm-base文件夹下的同名文件:pytorch_model.bin和sentencepiece.bpe.model
+
+## 模型推理
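+将layout_xlm.py和torch_aie_trainer.py拷贝到unilm/layoutlmft目录下,并在该目录执行推理脚本(脚本通过相对路径读取microsoft/layoutxlm-base和xfun/seqeval.py)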
+```
+python layout_xlm.py --model_name_or_path microsoft/layoutxlm-base \
+ --output_dir ./test-ner \
+ --do_eval \
+ --lang zh
+```
+
+# 模型推理性能及精度
+
+调用torch-aie推理计算,精度参考下列数据。
+
+| 芯片型号 | Batch Size | 数据集 | 精度 |
+|-------|------------|------|-------------------|
+| 310P3 | 5 | XFUN | accuracy : 82.48% |
+
diff --git a/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/layoutXLM_change.patch b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/layoutXLM_change.patch
new file mode 100644
index 0000000000000000000000000000000000000000..8de0246be67a47b86e87f03b9c6dc953272aa3b0
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/layoutXLM_change.patch
@@ -0,0 +1,69 @@
+Subject: [PATCH] layoutXLM change
+---
+Index: layoutlmft/layoutlmft/data/datasets/xfun.py
+IDEA additional info:
+Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
+<+>UTF-8
+===================================================================
+diff --git a/layoutlmft/layoutlmft/data/datasets/xfun.py b/layoutlmft/layoutlmft/data/datasets/xfun.py
+--- a/layoutlmft/layoutlmft/data/datasets/xfun.py (revision f4695ed0244a275201fff00bee495f76670fbe70)
++++ b/layoutlmft/layoutlmft/data/datasets/xfun.py (date 1692760553752)
+@@ -9,7 +9,7 @@
+ from transformers import AutoTokenizer
+
+
+-_URL = "https://github.com/doc-analysis/XFUN/releases/download/v1.0/"
++_URL = "../../../xfun/"
+
+ _LANG = ["zh", "de", "es", "fr", "en", "it", "ja", "pt"]
+ logger = logging.getLogger(__name__)
+Index: layoutlmft/examples/run_xfun_ser.py
+IDEA additional info:
+Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
+<+>UTF-8
+===================================================================
+diff --git a/layoutlmft/examples/run_xfun_ser.py b/layoutlmft/examples/run_xfun_ser.py
+--- a/layoutlmft/examples/run_xfun_ser.py (revision f4695ed0244a275201fff00bee495f76670fbe70)
++++ b/layoutlmft/examples/run_xfun_ser.py (date 1692760657546)
+@@ -189,7 +189,7 @@
+ )
+
+ # Metrics
+- metric = load_metric("seqeval")
++ metric = load_metric("../xfun/seqeval.py")
+
+ def compute_metrics(p):
+ predictions, labels = p
+Index: layoutlmft/layoutlmft/__init__.py
+IDEA additional info:
+Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
+<+>UTF-8
+===================================================================
+diff --git a/layoutlmft/layoutlmft/__init__.py b/layoutlmft/layoutlmft/__init__.py
+--- a/layoutlmft/layoutlmft/__init__.py (revision f4695ed0244a275201fff00bee495f76670fbe70)
++++ b/layoutlmft/layoutlmft/__init__.py (date 1692759816428)
+@@ -2,7 +2,8 @@
+
+ from transformers import CONFIG_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_NAMES_MAPPING, TOKENIZER_MAPPING
+ from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, BertConverter, XLMRobertaConverter
+-from transformers.models.auto.modeling_auto import auto_class_factory
++from transformers.models.auto.modeling_auto import _BaseAutoModelClass, auto_class_update
++import types
+
+ from .models.layoutlmv2 import (
+ LayoutLMv2Config,
+@@ -37,10 +38,8 @@
+ [(LayoutLMv2Config, LayoutLMv2ForRelationExtraction), (LayoutXLMConfig, LayoutXLMForRelationExtraction)]
+ )
+
+-AutoModelForTokenClassification = auto_class_factory(
+- "AutoModelForTokenClassification", MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification"
+-)
++cls = types.new_class("AutoModelForTokenClassification", (_BaseAutoModelClass,))
++cls._model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
++cls.__name__ = "AutoModelForTokenClassification"
+
+-AutoModelForRelationExtraction = auto_class_factory(
+- "AutoModelForRelationExtraction", MODEL_FOR_RELATION_EXTRACTION_MAPPING, head_doc="relation extraction"
+-)
++AutoModelForTokenClassification = auto_class_update(cls, head_doc="token classification")
diff --git a/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/layout_xlm.py b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/layout_xlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecd1e82f8ec85309b6a9687fd6c09425c22e6b1c
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/layout_xlm.py
@@ -0,0 +1,156 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import numpy as np
+
+import torch
+from transformers import (
+ AutoConfig,
+ AutoModelForTokenClassification,
+ AutoTokenizer,
+ HfArgumentParser,
+ PreTrainedTokenizerFast,
+ TrainingArguments,
+ set_seed,
+)
+from datasets import ClassLabel, load_dataset, load_metric
+
+import layoutlmft.data.datasets.xfun
+from layoutlmft.data import DataCollatorForKeyValueExtraction
+from layoutlmft.data.data_args import XFUNDataTrainingArguments
+from layoutlmft.models.model_args import ModelArguments
+from torch_aie_trainer import XfunSerTorchAieTrainer
+import torch_npu
+
+label_list = []  # label names; filled in by load_model() and read by compute_metrics()
+
+def get_label_list(labels):  # Return the sorted list of distinct values found across all sequences in `labels`.
+    unique_labels = set()
+    for label in labels:  # each element of `labels` is itself an iterable of label values
+        unique_labels = unique_labels | set(label)
+    unique_label_list = list(unique_labels)
+    unique_label_list.sort()
+    return unique_label_list
+
+
+def load_model():  # Parse CLI arguments, load the XFUN dataset, and build the config, tokenizer and token-classification model.
+    parser = HfArgumentParser((ModelArguments, XFUNDataTrainingArguments, TrainingArguments))
+    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+    training_args.per_device_eval_batch_size = 5  # overrides any CLI value; 5 is also the batch size hard-coded in XfunSerTorchAieTrainer.prepare_aie_model
+    datasets = load_dataset(
+        os.path.abspath(layoutlmft.data.datasets.xfun.__file__),  # path of the layoutlmft xfun dataset module, used as the loading script
+        f"xfun.{data_args.lang}"  # dataset configuration name built from data_args.lang
+    )
+    features = datasets["validation"].features
+    label_column_name = "labels"
+
+    global label_list  # module-level list that compute_metrics reads
+    if isinstance(features[label_column_name].feature, ClassLabel):
+        label_list = features[label_column_name].feature.names
+    else:  # labels are not a ClassLabel feature: derive the label set from the train split
+        label_list = get_label_list(datasets["train"][label_column_name])
+    num_labels = len(label_list)
+
+    config = AutoConfig.from_pretrained(
+        model_args.config_name if model_args.config_name else model_args.model_name_or_path,  # fall back to the model path when no config name is given
+        num_labels=num_labels,
+        finetuning_task=data_args.task_name,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,  # same fallback for the tokenizer
+        cache_dir=model_args.cache_dir,
+        use_fast=True
+    )
+    model = AutoModelForTokenClassification.from_pretrained(
+        model_args.model_name_or_path,
+        config=config,
+    )
+    print("load model done.")
+    return model, tokenizer, model_args, data_args, training_args, datasets  # main() forwards all six values to trace_ts_model
+
+
+def compute_metrics(p):  # Turn a (predictions, labels) pair into a dict with seqeval precision, recall, f1 and accuracy.
+    metric = load_metric("xfun/seqeval.py")  # runs on every call; NOTE(review): relative path, presumably resolved from the working directory - confirm
+    predictions, labels = p
+    predictions = np.argmax(predictions, axis=2)  # assumes predictions are (batch, seq, num_labels) scores - TODO confirm
+
+    global label_list  # filled in by load_model()
+    true_predictions = [
+        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]  # positions whose label is -100 are skipped
+        for prediction, label in zip(predictions, labels)
+    ]
+    true_labels = [
+        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]  # same filter, so both lists stay aligned
+        for prediction, label in zip(predictions, labels)
+    ]
+
+    results = metric.compute(predictions=true_predictions, references=true_labels)
+    return {  # keep only the four "overall_*" entries of the metric result
+        "precision": results["overall_precision"],
+        "recall": results["overall_recall"],
+        "f1": results["overall_f1"],
+        "accuracy": results["overall_accuracy"],
+    }
+
+
+def trace_ts_model(model, tokenizer, model_args, data_args, training_args, datasets):  # Build the eval trainer, trace `model` on the first eval batch, save layoutXLM_base.ts and return the trainer (model_args is unused here).
+    print("start trace model")
+    padding = "max_length" if data_args.pad_to_max_length else False
+    data_collator = DataCollatorForKeyValueExtraction(
+        tokenizer,
+        pad_to_multiple_of=8 if training_args.fp16 else None,
+        padding=padding,
+        max_length=512,
+    )
+    eval_dataset = datasets["validation"]
+    if data_args.max_val_samples is not None:  # optionally keep only the first max_val_samples examples
+        eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+    trainer = XfunSerTorchAieTrainer(
+        model=model,
+        args=training_args,
+        eval_dataset=eval_dataset if training_args.do_eval else None,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        compute_metrics=compute_metrics,
+    )
+    eval_dataloader = trainer.get_eval_dataloader()
+    for step, inputs in enumerate(eval_dataloader):  # only the first batch is used: the loop breaks right after one trace
+        tmp_input = trainer.prepare_inputs(inputs)
+        trace_model = torch.jit.trace(model, (
+            tmp_input["input_ids"],
+            tmp_input["bbox"],
+            tmp_input["image"].tensor,  # "image" is a wrapper exposing .tensor - presumably a detectron2 ImageList; confirm
+            tmp_input["attention_mask"]),
+            strict=False, check_trace=False
+        )
+        trace_model.save("layoutXLM_base.ts")  # the same file name is loaded by XfunSerTorchAieTrainer.prepare_aie_model
+        break
+    print("trace model done")
+    return trainer
+
+
+def main():  # Load the model, trace it to TorchScript, compile it via the trainer, then print the evaluation metrics.
+    torch_npu.set_device(0)  # device index 0 is hard-coded
+    model, tokenizer, model_args, data_args, training_args, datasets = load_model()
+    trainer = trace_ts_model(model, tokenizer, model_args, data_args, training_args, datasets)
+    trainer.prepare_aie_model()  # compiles layoutXLM_base.ts and keeps the compiled model on the trainer
+    metrics = trainer.evaluate()
+    for key, value in metrics.items():
+        print(key, " : ", value)
+
+
+if __name__ == '__main__':  # run the pipeline only when executed as a script
+    main()
diff --git a/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/torch_aie_trainer.py b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/torch_aie_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a870eaddced45d32eea478a3df8a0f4f97ee9bd
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/nlp/multimodal/LayoutXLM/torch_aie_trainer.py
@@ -0,0 +1,104 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import collections
+from typing import Any, Dict, List, Optional, Tuple, Union
+from collections.abc import Mapping
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader, Dataset
+from layoutlmft.trainers.funsd_trainer import FunsdTrainer
+import torch_aie
+
+class XfunSerTorchAieTrainer(FunsdTrainer):  # FunsdTrainer whose prediction_step runs the torch_aie-compiled model held in self.result.
+    def __init__(self, **kwargs):  # Forward all keyword arguments to FunsdTrainer and start without a compiled model.
+        super().__init__(**kwargs)
+        self.result = None  # compiled model; assigned by prepare_aie_model() and required by prediction_step()
+
+    def prepare_aie_model(self):  # Load layoutXLM_base.ts, compile it with torch_aie, keep it in self.result and save it to aie_res.pt.
+        print("start compile")
+        trace_model = torch.jit.load("layoutXLM_base.ts")
+        compile_input = [  # fixed shapes with batch 5; presumably input_ids, bbox, image, attention_mask as in the forward call of prediction_step - confirm
+            torch_aie.Input((5, 512), dtype=torch_aie.dtype.INT64),
+            torch_aie.Input((5, 512, 4), dtype=torch_aie.dtype.INT64, tensor_domain=([1,2])),  # NOTE(review): meaning of tensor_domain is not visible here; check the torch_aie docs
+            torch_aie.Input((5, 3, 224, 224), dtype=torch_aie.dtype.INT64),
+            torch_aie.Input((5, 512), dtype=torch_aie.dtype.INT64),
+        ]
+        torch_aie.set_device(0)  # device index 0 is hard-coded
+        self.result = torch_aie.compile(trace_model, inputs=compile_input)
+        torch.jit.save(self.result, "aie_res.pt")
+        print("compile success")
+
+    def nested_detach(self, tensors):  # Recursively detach every leaf of a nested list/tuple/Mapping, rebuilding the same container types.
+        if isinstance(tensors, (list, tuple)):
+            return type(tensors)(self.nested_detach(t) for t in tensors)
+        elif isinstance(tensors, Mapping):
+            return type(tensors)({k: self.nested_detach(t) for k, t in tensors.items()})
+        return tensors.detach()  # leaf case: anything that is not a list, tuple or Mapping must provide .detach()
+
+    def prediction_step(  # Run one batch through self.result; always returns (None, logits, labels), so no loss is computed. `model` and `prediction_loss_only` are not used.
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)  # True only when every label name has a non-None entry
+        return_loss = inputs.get("return_loss", None)
+        if return_loss is None:
+            return_loss = self.can_return_loss
+        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
+
+        inputs = self.prepare_inputs(inputs)
+        if ignore_keys is None:  # default: take the ignore list from the model config when one exists
+            if hasattr(self.model, "config"):
+                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+            else:
+                ignore_keys = []
+
+        if has_labels or loss_without_labels:
+            labels = self.nested_detach(tuple(inputs.get(name) for name in self.label_names))
+            if len(labels) == 1:  # a single label entry is unwrapped from its tuple
+                labels = labels[0]
+        else:
+            labels = None
+
+        with torch.no_grad():
+            with self.compute_loss_context_manager():
+                outputs = self.result.forward(inputs["input_ids"].int().to("npu"),  # NOTE(review): .int() yields int32 while prepare_aie_model declares INT64 inputs - confirm torch_aie accepts this
+                                              inputs["bbox"].int().to("npu"),
+                                              inputs["image"].tensor.int().to("npu"),
+                                              inputs["attention_mask"].int().to("npu"))
+            if isinstance(outputs, dict):  # dict output: keep every value whose key is not ignored
+                logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+            else:
+                logits = outputs
+
+        logits = self.nested_detach(logits)
+        if len(logits) == 1:  # NOTE(review): if logits is a tensor, len() is its first dimension, so a batch of one would be unwrapped too - confirm intended
+            logits = logits[0]
+
+        return (None, logits, labels)
+
+    def prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:  # Move tensor-like values to self.args.device in place and return the same dict.
+        for k, v in inputs.items():
+            if hasattr(v, "to") and hasattr(v, "device"):  # duck-typed check: anything exposing both .to and .device is moved
+                inputs[k] = v.to(self.args.device)
+
+        if self.args.past_index >= 0 and self._past is not None:  # pass the stored past state along as "mems" when past_index is enabled
+            inputs["mems"] = self._past
+
+        return inputs
\ No newline at end of file