From 2aa3a8b3684bdbd41c2101668ef4a7349659951e Mon Sep 17 00:00:00 2001 From: tianxi-yi Date: Thu, 4 Jul 2024 12:59:53 +0800 Subject: [PATCH 1/2] Add DeBertA in IxRT link #IA9VI6 Add DeBertA in IxRT Signed-off-by: tianxi-yi --- .../nlp/language_model/deberta/ixrt/README.md | 100 +++++ .../deberta/ixrt/perf_engine.py | 349 ++++++++++++++++++ .../deberta/ixrt/remove_clip_and_cast.py | 95 +++++ .../scripts/infer_deberta_fp16_performance.sh | 41 ++ .../ixrt/scripts/prepare_model_and_dataset.sh | 34 ++ .../language_model/deberta/ixrt/torch2onnx.py | 73 ++++ 6 files changed, 692 insertions(+) create mode 100644 models/nlp/language_model/deberta/ixrt/README.md create mode 100644 models/nlp/language_model/deberta/ixrt/perf_engine.py create mode 100644 models/nlp/language_model/deberta/ixrt/remove_clip_and_cast.py create mode 100644 models/nlp/language_model/deberta/ixrt/scripts/infer_deberta_fp16_performance.sh create mode 100644 models/nlp/language_model/deberta/ixrt/scripts/prepare_model_and_dataset.sh create mode 100644 models/nlp/language_model/deberta/ixrt/torch2onnx.py diff --git a/models/nlp/language_model/deberta/ixrt/README.md b/models/nlp/language_model/deberta/ixrt/README.md new file mode 100644 index 00000000..26b4ea8f --- /dev/null +++ b/models/nlp/language_model/deberta/ixrt/README.md @@ -0,0 +1,100 @@ +# DeBerta + +## Description + +DeBERTa (Decoding-enhanced BERT with disentangled attention) is an enhanced version of the BERT (Bidirectional Encoder Representations from Transformers) model. It improves text representation learning by introducing disentangled attention mechanisms and decoding enhancement techniques.DeBERTa introduces disentangled attention mechanisms that decompose the self-attention matrix into different parts, focusing on different semantic information. This helps the model better capture relationships between texts.By incorporating decoding enhancement techniques, DeBERTa adjusts the decoder during fine-tuning to better suit specific downstream tasks, thereby improving the model’s performance on those tasks. + +## Setup + +### Install + +```bash +pip3 install onnxsim +pip3 install onnx_graphsurgeon +pip3 install scikit-learn +pip3 install tqdm +pip3 install pycuda +pip3 install onnx +pip3 install tabulate +pip3 install cv2 +pip3 install pycocotools +pip3 install opencv-python==4.6.0.66 +``` + +### Download + +Pretrained model: + +Dataset: to download the squad dataset. + +or you can : +```bash +bash /scripts/prepare_model_and_dataset.sh + +``` + +### Model Conversion +Please correct the paths in the following commands or files. +```bash +tar -xvf open_deberta.tar +wget < https://github.com/bytedance/ByteMLPerf/blob/main/byte_infer_perf/general_perf/model_zoo/deberta-torch-fp32.json > +python3 torch2onnx.py --model_path deberta-base-squad.pt --output_path deberta-torch-fp32.onnx +onnxsim deberta-torch-fp32.onnx deberta-torch-fp32-sim.onnx +python3 remove_clip_and_cast.py + +``` + +## Inference + + +```bash +export ORIGIN_ONNX_NAME=/Path/deberta-sim-drop-clip-drop-invaild-cast +export OPTIMIER_FILE=/Path/ixrt/oss/tools/optimizer/optimizer.py +export PROJ_PATH=./ +``` + +### Performance + +```bash + +bash scripts/infer_deberta_fp16_performance.sh +``` + +### Accuracy + +If you want to evaluate the accuracy of this model, please visit the website: < https://github.com/yudefu/ByteMLPerf/tree/iluvatar_general_infer >, which integrates inference and training of many models under this framework, supporting the ILUVATAR backend + +```bash + +git clone https://github.com/yudefu/ByteMLPerf.git -b iluvatar_general_infer +``` + +For detailed steps regarding this model, please refer to this document: < https://github.com/yudefu/ByteMLPerf/blob/iluvatar_general_infer/byte_infer_perf/general_perf/backends/ILUVATAR/README.zh_CN.md > Note: You need to modify the relevant paths in the code to your own correct paths. + +```bash + +pip3 install -r https://github.com/yudefu/ByteMLPerf/blob/iluvatar_general_infer/byte_infer_perf/general_perf/requirements.txt +mv /ixrt/perf_engine.py /ByteMLPerf/byte_infer_perf/general_perf/core/perf_engine.py +sftp -P 29880 vipzjtd@iftp.iluvatar.com.cn 密码:123..com +get /upload/3-app/byteperf/Palak.tar +exit +tar -zxvf Palak.tar + +接着修改代码:ByteMLPerf/byte_infer_perf/general_perf/datasets/open_squad/data_loader.py +AutoTokenizer.from_pretrained("Palak/microsoft_deberta-base_squad") => AutoTokenizer.from_pretrained("/Your/Path/Palak/microsoft_deberta-base_squad") + +mv deberta-sim-drop-clip-drop-invaild-cast.onnx general_perf/model_zoo/popular/open_deberta/ +cd /ByteMLPerf/byte_infer_perf/ +mv /general_perf/general_perf/model_zoo/popular/open_deberta /general_perf/model_zoo/popular/open_deberta +cd /ByteMLPerf/byte_infer_perf/general_perf +python3 core/perf_engine.py --hardware_type ILUVATAR --task deberta-torch-fp32 +``` + +If report ModuleNotFoundError: No module named 'tensorrt_legacy',Please fix /home/xinchi.tian/ByteMLPerf/byte_infer_perf/general_perf/backends/ILUVATAR/common.py "tensorrt_legacy" to "tensorrt" + + +## Results + +Model |BatchSize |Precision |QPS |Exact Match |F1 Score +--------|-----------|----------|----------|-------------|------------ +DeBerta | 16 | FP16 | 18.58 | 73.76 | 81.24 \ No newline at end of file diff --git a/models/nlp/language_model/deberta/ixrt/perf_engine.py b/models/nlp/language_model/deberta/ixrt/perf_engine.py new file mode 100644 index 00000000..089d9860 --- /dev/null +++ b/models/nlp/language_model/deberta/ixrt/perf_engine.py @@ -0,0 +1,349 @@ +# Copyright 2023 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import os +import logging +import importlib +import json +import subprocess +import time + +from typing import Any, Dict, Tuple +from prompt_toolkit.shortcuts import radiolist_dialog, input_dialog, yes_no_dialog +from prompt_toolkit.styles import Style + +BYTE_MLPERF_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +os.chdir(BYTE_MLPERF_ROOT) +sys.path.insert(0, BYTE_MLPERF_ROOT) + +import argparse +from general_perf.core.configs.workload_store import load_workload +from general_perf.core.configs.dataset_store import load_dataset +from general_perf.core.configs.backend_store import init_compile_backend, init_runtime_backend + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("PerfEngine") +os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3' + + +def get_args(): + """Parse commandline.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--task", + default="resnet50-tf-fp32", + help="The task going to be evaluted, refs to workloads/") + parser.add_argument( + "--hardware_type", + default="GPU", + help="The backend going to be evaluted, refs to backends/") + parser.add_argument("--compile_only", + action='store_true', + help="Run compilation only") + + args = parser.parse_args() + return args + + +class PerfEngine: + def __init__(self) -> None: + super().__init__() + self.args = get_args() + self.workload = load_workload(self.args.task) + self.backend_type = self.args.hardware_type + self.compile_backend = None + self.old_os_path = os.environ['PATH'] + self.prev_sys_path = list(sys.path) + self.real_prefix = sys.prefix + self.compile_only_mode = False + + def start_engine(self) -> None: + ''' + Byte MlPerf will create an virtual env for each backend to avoid dependance conflict + ''' + success, total = 0, len(self.workload) + if total == 0: + return + log.info("******************* Backend Env Initization *******************") + status = self.activate_venv(self.backend_type) + if not status: + log.warning("Activate virtualenv Failed, Please Check...") + + self.compile_backend = init_compile_backend(self.backend_type) + self.runtime_backend = init_runtime_backend(self.backend_type) + + output_dir = os.path.abspath('general_perf/reports/' + + self.backend_type) + os.makedirs(output_dir, exist_ok=True) + + status = self.single_workload_perf(self.workload) + + def single_workload_perf( + self, workload: Dict[str, Any]) -> bool: + log.info("******************************************* Start to test model: {}. *******************************************".format(workload['model'])) + + # Check Compile Only Mode + self.compile_only_mode = False + if self.args.compile_only or workload['compile_only']: + self.compile_only_mode = True + + base_report = { + "Model": workload['model'].upper(), + "Backend": self.backend_type, + "Host Info": self.get_cpu_name() + } + + # Initalize Model Config Info + model_info = self.get_model_info(workload['model']) + pre_compile_config = {"workload": workload, 'model_info': model_info} + interact_info = self.check_interact_info(pre_compile_config) + pre_compile_config['interact_info'] = interact_info + if not model_info['dataset_name']: + model_info['dataset_name'] = 'fake_dataset' + + + ''' + Compile Backend could do some optimization like convert model format here + ''' + log.info("******************************************* Running Backend Compilation... *******************************************") + log.info("Running Backend Preoptimization...") + pre_compile_config = self.compile_backend.pre_optimize(pre_compile_config) + + + # Initalize dataset + dataset = load_dataset(model_info) + dataset.preprocess() + base_report['Dataset'] = model_info['dataset_name'].upper( + ) if model_info['dataset_name'] else None + + #Placeholder Only + segment_info = self.compile_backend.segment(pre_compile_config) + + best_batch_sizes = self.compile_backend.get_best_batch_size() + if isinstance(best_batch_sizes, list): + pre_compile_config['workload'][ + 'batch_sizes'] = best_batch_sizes + + log.info("Start to compile the model...") + start = time.time() + compile_info = self.compile_backend.compile(pre_compile_config, + dataset) + end = time.time() + + graph_compile_report = {} + graph_compile_report["Compile Duration"] = round(end - start, 5) + graph_compile_report["Compile Precision"] = compile_info[ + 'compile_precision'] + graph_compile_report["Subgraph Coverage"] = compile_info['sg_percent'] + if 'optimizations' in compile_info: + graph_compile_report['Optimizations'] = compile_info['optimizations'] + if 'instance_count' in compile_info: + base_report['Instance Count'] = compile_info['instance_count'] + if 'device_count' in compile_info: + base_report['Device Count'] = compile_info['device_count'] + base_report['Graph Compile'] = graph_compile_report + + # Initalize Output Dir and Reports + output_dir = os.path.abspath('general_perf/reports/' + + self.backend_type + '/' + + workload['model']) + os.makedirs(output_dir, exist_ok=True) + + # Compile only mode will stop here + if self.compile_only_mode: + base_report.pop("Backend") + return compile_info["compile_status"], base_report + + # load runtime backend + """ + Start Here + """ + batch_sizes = pre_compile_config['workload']['batch_sizes'] + self.runtime_backend.configs = compile_info + self.runtime_backend.workload = workload + self.runtime_backend.model_info = model_info + + self.runtime_backend.load(workload['batch_sizes'][0]) + # test accuracy + accuracy_report = {} + AccuracyChecker = self.get_accuracy_checker( + model_info['dataset_name'] + if model_info['dataset_name'] else 'fake_dataset') + AccuracyChecker.runtime_backend = self.runtime_backend + AccuracyChecker.dataloader = dataset + AccuracyChecker.output_dir = output_dir + AccuracyChecker.configs = compile_info + + if workload['test_accuracy']: + log.info("******************************************* Running Accuracy Checker... *******************************************") + + dataset.rebatch(self.runtime_backend.get_loaded_batch_size()) + accuracy_results = AccuracyChecker.calculate_acc( + workload['data_percent']) + + accuracy_report['Data Percent'] = workload['data_percent'] + accuracy_report.update(accuracy_results) + + # test numeric + if workload['test_numeric']: + log.info("******************************************* Running Numeric Checker... *******************************************") + + dataset.rebatch(self.runtime_backend.get_loaded_batch_size()) + if not workload['test_accuracy']: + accuracy_results = AccuracyChecker.calculate_acc( + workload['data_percent']) + diff_results = AccuracyChecker.calculate_diff() + accuracy_report.update(diff_results) + # accuracy_report['Diff Dist'] = compile_info['model'] + '-to-' + compile_info['compile_precision'].lower() + ".png" + + if accuracy_report: + base_report['Accuracy'] = accuracy_report + + # function to test qps and latency + if workload['test_perf']: + log.info("******************************************* Runing QPS Checker... *******************************************") + performance_reports = [] + qs_status = self.runtime_backend.is_qs_mode_supported() + if qs_status: + qs_config = self.runtime_backend.generate_qs_config() + performance_reports = self.qs_benchmark(qs_config) + else: + for bs in batch_sizes: + self.runtime_backend.load(bs) + batch_reports = self.runtime_backend.benchmark(dataset) + performance_reports.append(batch_reports) + base_report['Performance'] = performance_reports + + if "Instance Count" not in base_report: + log.warning("Vendors need to Add # of instances") + if "Device Count" not in base_report: + log.warning("Vendors need to Add # of devices") + + # write output to json file + output_report_path = output_dir + "/result-" + compile_info['compile_precision'].lower() + ".json" + with open(output_report_path, 'w') as file: + json.dump(base_report, file, indent=4) + + base_report.pop("Backend") + log.info("Testing Finish. Report is saved in path: [ {}/{} ]". + format(output_dir[output_dir.rfind('general_perf'):], + os.path.basename(output_report_path))) + + return compile_info["compile_status"] + + #WIP + def qs_benchmark(self, qs_config: Dict[str, Any]) -> list: + return [] + + def get_accuracy_checker(self, dataset_name: str): + AccuracyChecker = importlib.import_module('general_perf.datasets.' + + dataset_name + + ".test_accuracy") + AccuracyChecker = getattr(AccuracyChecker, 'AccuracyChecker') + return AccuracyChecker() + + def get_model_info(self, model_name: str) -> Dict[str, Any]: + with open("general_perf/model_zoo/" + model_name + '.json', + 'r') as file: + model_info = json.load(file) + return model_info + + def get_cpu_name(self): + command = "lscpu | grep 'Model name' | awk -F: '{print $2}'" + cpu_name = subprocess.check_output(command, shell=True) + return cpu_name.decode().strip() + + def check_interact_info( + self, pre_compile_config: Dict[str, Dict]) -> Dict[str, Any]: + interact_info = self.compile_backend.get_interact_profile( + pre_compile_config) + + answer = {} + if len(interact_info) == 0: + return answer + + dialog_style = Style.from_dict({ + 'dialog': 'bg:#88b8ff', + 'dialog frame.label': 'bg:#ffffff #000000', + 'dialog.body': 'bg:#000000 #a0acde', + 'dialog shadow': 'bg:#004aaa', + }) + + input_style = Style.from_dict({ + 'dialog': 'bg:#88b8ff', + 'dialog frame.label': 'bg:#ffffff #000000', + 'dialog.body': 'bg:#000000 #a0acde', + 'dialog shadow': 'bg:#004aaa', + 'text-area.prompt': 'bg:#ffffff', + 'text-area': '#000000', + }) + + option = yes_no_dialog(title=self.backend_type + '编译配置', + text='[请选择]:是否进行编译后端配置:', + style=dialog_style).run() + if option: + sum_question = len(interact_info) + for i, question in enumerate(interact_info): + if question['depends']: + state = 0 + for title in question['depends'].split(','): + if not answer[title]: + state = 1 + if state: + continue + if question['dialog_type'] == 'Yes/No Dialog': + option = yes_no_dialog( + title=self.backend_type + '编译配置进度(' + str(i + 1) + + '/' + str(sum_question) + ')', + text="[Backend " + self.backend_type + "]: " + + question['note'], + style=dialog_style).run() + elif question['dialog_type'] == "Input Dialog": + option = input_dialog( + title=self.backend_type + '编译配置进度(' + str(i + 1) + + '/' + str(sum_question) + ')', + text="[Backend " + self.backend_type + "]: " + + question['note'], + style=input_style).run() + elif question['dialog_type'] == "Radiolist Dialog": + choice = [(i, text) + for i, text in enumerate(question['options'])] + num = radiolist_dialog( + title=self.backend_type + '编译配置进度(' + str(i + 1) + + '/' + str(sum_question) + ')', + text="[Backend " + self.backend_type + "]: " + + question['note'], + values=choice, + style=dialog_style).run() + option = question['options'][num] if num is not None else question[ + 'default'] + answer[question['name']] = option + + return answer + + def activate_venv(self, hardware_type: str) -> bool: + + return True + + def deactivate_venv(self): + sys.path[: + 0] = self.prev_sys_path #will also revert the added site-packages + sys.prefix = self.real_prefix + os.environ['PATH'] = self.old_os_path + + +if __name__ == "__main__": + engine = PerfEngine() + engine.start_engine() diff --git a/models/nlp/language_model/deberta/ixrt/remove_clip_and_cast.py b/models/nlp/language_model/deberta/ixrt/remove_clip_and_cast.py new file mode 100644 index 00000000..11c080ea --- /dev/null +++ b/models/nlp/language_model/deberta/ixrt/remove_clip_and_cast.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import onnx_graphsurgeon as gs +import onnx + +onnx_op_set_2_ir_version = { + 11:6, + 12:7, + 13:7, +} + +visited_add_tensor = {} +def replace_expand_values(graph, expand_node, clip_node, cast_node, sub_node, add_node): + if add_node.inputs[0].name not in visited_add_tensor: + print(add_node.inputs[0].name) + print(add_node.inputs[0].values) + add_node.inputs[0].values = add_node.inputs[0].values + 384 + add_node.inputs[0].values[add_node.inputs[0].values < 0] = 0 + add_node.inputs[0].values[add_node.inputs[0].values > 767] = 767 + print(add_node.inputs[0].values) + visited_add_tensor[add_node.inputs[0].name] = True + expand_node.inputs = [add_node.inputs[0]] + expand_node.inputs[1:] + +def replace_clip_related_nodes(graph): + node_name_to_index_map = {} + expand_node_names = [] + output_name_to_node_name_map = {} + for i, node in enumerate(graph.nodes): + node_name_to_index_map[node.name] = i + if node.op == "Expand": + expand_node_names.append(node.name) + for j in node.outputs: + output_name_to_node_name_map[j.name] = node.name + + for name in expand_node_names: + expand_node = graph.nodes[node_name_to_index_map[name]] + expand_producer_name = output_name_to_node_name_map[expand_node.inputs[0].name] + expand_producer = graph.nodes[node_name_to_index_map[expand_producer_name]] + if expand_producer.op == "Clip": + clip_node = expand_producer + clip_producer_name = output_name_to_node_name_map[clip_node.inputs[-1].name] + clip_producer = graph.nodes[node_name_to_index_map[clip_producer_name]] + if clip_producer.op == "Cast": + cast_producer_name = output_name_to_node_name_map[clip_producer.inputs[0].name] + cast_producer = graph.nodes[node_name_to_index_map[cast_producer_name]] + if cast_producer.op == "Sub": + add_node_name = output_name_to_node_name_map[clip_node.inputs[0].name] + add_node = graph.nodes[node_name_to_index_map[add_node_name]] + replace_expand_values(graph, expand_node, clip_node, clip_producer, cast_producer, add_node) + +def drop_cast_nodes(graph): + node_name_to_index_map = {} + cast_node_names = [] + output_name_to_node_name_map = {} + for i, node in enumerate(graph.nodes): + node_name_to_index_map[node.name] = i + if node.op == "Cast": + cast_node_names.append(node.name) + for j in node.outputs: + output_name_to_node_name_map[j.name] = node.name + + for name in cast_node_names: + cast_node = graph.nodes[node_name_to_index_map[name]] + cast_producer_name = output_name_to_node_name_map[cast_node.inputs[0].name] + cast_producer = graph.nodes[node_name_to_index_map[cast_producer_name]] + if cast_producer.op == "Cast": + cast_node.inputs = cast_producer.inputs + + +input_path = r"/ixrt/deberta-torch-fp32-sim.onnx" +save_path = r"/ixrt/deberta-sim-drop-clip-drop-invaild-cast.onnx" +graph = gs.import_onnx(onnx.load(input_path)) + +replace_clip_related_nodes(graph) +drop_cast_nodes(graph) + +graph.cleanup().toposort() +onnx.save(gs.export_onnx(graph), save_path) + +model = onnx.load(save_path) +model.ir_version = onnx_op_set_2_ir_version[model.opset_import[0].version] +onnx.save(model, save_path) \ No newline at end of file diff --git a/models/nlp/language_model/deberta/ixrt/scripts/infer_deberta_fp16_performance.sh b/models/nlp/language_model/deberta/ixrt/scripts/infer_deberta_fp16_performance.sh new file mode 100644 index 00000000..c9ced241 --- /dev/null +++ b/models/nlp/language_model/deberta/ixrt/scripts/infer_deberta_fp16_performance.sh @@ -0,0 +1,41 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +# Start to test +set -x +ORIGIN_ONNX=${ORIGIN_ONNX_NAME}.onnx +cd ${PROJ_PATH} + +run(){ + BS=${1:-1} + TARGET_ONNX=${ORIGIN_ONNX_NAME}_end.onnx + TARGET_ENGINE=${ORIGIN_ONNX_NAME}_bs_${BS}_end.engine + if [[ ! -f "${ORIGIN_ONNX}" ]];then + echo "${ORIGIN_ONNX} not exists!" + exit 1 + fi + + # Graph optimize + python3 ${OPTIMIER_FILE} --onnx ${ORIGIN_ONNX} --dump_onnx + + # Build Engine + ixrtexec --onnx ${TARGET_ONNX} --save_engine ${TARGET_ENGINE} --log_level error --plugins ixrt_plugin --shapes input_ids.1:${BS}x384,attention_mask.1:${BS}x384\ + --min_shape input_ids.1:${BS}x384,attention_mask.1:${BS}x384 --opt_shape input_ids.1:${BS}x384,attention_mask.1:${BS}x384 --max_shape input_ids.1:${BS}x384,attention_mask.1:${BS}x384 + + # Test Performance + ixrtexec --load_engine ${TARGET_ENGINE} --shapes input_ids.1:${BS}x384,attention_mask.1:${BS}x384 --plugins ixrt_plugin + +} +run 1 \ No newline at end of file diff --git a/models/nlp/language_model/deberta/ixrt/scripts/prepare_model_and_dataset.sh b/models/nlp/language_model/deberta/ixrt/scripts/prepare_model_and_dataset.sh new file mode 100644 index 00000000..575ab8f7 --- /dev/null +++ b/models/nlp/language_model/deberta/ixrt/scripts/prepare_model_and_dataset.sh @@ -0,0 +1,34 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +# #!/bin/bash +echo "******************* Downloading Model.... *******************" + +mkdir -p general_perf/model_zoo/regular +mkdir -p general_perf/model_zoo/popular +mkdir -p general_perf/model_zoo/sota +mkdir -p general_perf/download +mkdir -p datasets/open_squad/ + +wget -O general_perf/download/open_deberta.tar https://lf-bytemlperf.17mh.cn/obj/bytemlperf-zoo/open_deberta.tar +tar xf general_perf/download/open_deberta.tar -C general_perf/model_zoo/popular/ + + +# # Download Datasets +wget -O general_perf/download/open_squad.tar https://lf-bytemlperf.17mh.cn/obj/bytemlperf-zoo/open_squad.tar +tar xf general_perf/download/open_squad.tar -C datasets/open_squad/ + + +echo "Extract Done." diff --git a/models/nlp/language_model/deberta/ixrt/torch2onnx.py b/models/nlp/language_model/deberta/ixrt/torch2onnx.py new file mode 100644 index 00000000..3ef6081d --- /dev/null +++ b/models/nlp/language_model/deberta/ixrt/torch2onnx.py @@ -0,0 +1,73 @@ +# Copyright 2023 ByteDance and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json + +import numpy as np +import torch + + +def torch_to_onnx(model_path, output_path): + model_name = output_path.split("/")[-1][:-4] + with open("/ixrt/" + model_name + "json", "r") as f: + model_info = json.load(f) + model_inputs = model_info["inputs"].split(",") + input_shapes = model_info["input_shape"] + input_type = model_info["input_type"].split(",") + example_inputs = _get_fake_samples(input_shapes, input_type) + + model = torch.jit.load(model_path, map_location=torch.device("cpu")) + model.eval() + + names = model_inputs + dynamic_inputs = {} + for i in range(len(names)): + dynamic_inputs[names[i]] = {0: "batch_size"} + outputs = model_info["outputs"].split(",") + for output in outputs: + dynamic_inputs[output] = {0: "batch_size"} + torch.onnx.export( + model, + example_inputs, + output_path, + opset_version=11, + input_names=names, + output_names=outputs, + dynamic_axes=dynamic_inputs, + ) + + +def _get_fake_samples(shape, type): + data = [] + idx = 0 + for key, val in shape.items(): + val = [val[0] * 1] + val[1:] + data.append(torch.from_numpy(np.random.random(val).astype(type[idx].lower()))) + idx += 1 + return data + + +def get_args(): + """Parse commandline.""" + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", default="") + parser.add_argument("--output_path", default="") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + torch_to_onnx(args.model_path, args.output_path) \ No newline at end of file -- Gitee From 10aee84186f74c00ec57d3589b60890c7d532117 Mon Sep 17 00:00:00 2001 From: tianxi-yi Date: Thu, 18 Jul 2024 09:37:03 +0800 Subject: [PATCH 2/2] update model.json path --- models/nlp/language_model/deberta/ixrt/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/nlp/language_model/deberta/ixrt/README.md b/models/nlp/language_model/deberta/ixrt/README.md index 26b4ea8f..7ee7636b 100644 --- a/models/nlp/language_model/deberta/ixrt/README.md +++ b/models/nlp/language_model/deberta/ixrt/README.md @@ -37,7 +37,7 @@ bash /scripts/prepare_model_and_dataset.sh Please correct the paths in the following commands or files. ```bash tar -xvf open_deberta.tar -wget < https://github.com/bytedance/ByteMLPerf/blob/main/byte_infer_perf/general_perf/model_zoo/deberta-torch-fp32.json > +wget python3 torch2onnx.py --model_path deberta-base-squad.pt --output_path deberta-torch-fp32.onnx onnxsim deberta-torch-fp32.onnx deberta-torch-fp32-sim.onnx python3 remove_clip_and_cast.py -- Gitee