diff --git a/TensorFlow2/built-in/nlp/TextToTextTransferTransformer_ID4150_for_TensorFlow2.X/.keep b/TensorFlow2/built-in/nlp/TextToTextTransferTransformer_ID4150_for_TensorFlow2.X/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TensorFlow2/built-in/nlp/TextToTextTransferTransformer_ID4150_for_TensorFlow2.X/main.py b/TensorFlow2/built-in/nlp/TextToTextTransferTransformer_ID4150_for_TensorFlow2.X/main.py new file mode 100644 index 0000000000000000000000000000000000000000..cc49fdea53c699e62154b172eb777e1b9d3d49f1 --- /dev/null +++ b/TensorFlow2/built-in/nlp/TextToTextTransferTransformer_ID4150_for_TensorFlow2.X/main.py @@ -0,0 +1,117 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +# +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def get_config():
    """Build the session config that enables the NPU graph optimizer.

    Returns:
        A ``tf1.ConfigProto`` with the ``NpuOptimizer`` custom optimizer
        registered and the remapping / memory-optimization rewrites disabled.
    """
    # Imported inside the function, so ``npu_device`` is only required when a
    # config is actually requested.
    import npu_device
    from npu_device.compat.v1.npu_init import RewriterConfig

    npu_device.compat.enable_v1()
    config = tf1.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    custom_op = rewrite_options.custom_optimizers.add()
    custom_op.name = 'NpuOptimizer'
    # Both rewrites must be switched off explicitly.
    rewrite_options.remapping = RewriterConfig.OFF
    rewrite_options.memory_optimization = RewriterConfig.OFF
    return config


class T5:
    """Wrapper around a frozen T5 graph (``.pb`` file) for online inference."""

    # Tensor names looked up in the imported graph.
    INPUT_TENSOR = 'inputs:0'
    OUTPUT_TENSORS = ('strided_slice_1:0', 'strided_slice_2:0')

    def __init__(self, pb_file, config):
        """Load the serialized ``GraphDef`` from *pb_file* into a private graph.

        Args:
            pb_file: Path of the frozen model, opened through ``tf.io.gfile``.
            config: Session config handed to every ``tf1.Session`` created by
                :meth:`predict`.
        """
        graph = tf.Graph()
        with graph.as_default():
            with tf.io.gfile.GFile(pb_file, 'rb') as f:
                graph_def = tf1.GraphDef()
                graph_def.ParseFromString(f.read())
            # name='' keeps the node names exactly as stored in the file.
            tf.import_graph_def(graph_def, name='')
        self.graph = graph
        self.config = config

    def predict(self, questions):
        """Run the graph once per question inside a single session.

        Args:
            questions: Iterable of question strings.

        Returns:
            ``OrderedDict`` mapping each question to the list of values fetched
            for ``OUTPUT_TENSORS``. A question that occurs more than once keeps
            a single entry (the last result wins).
        """
        results = OrderedDict()
        out_nodes = list(self.OUTPUT_TENSORS)
        input_node = self.graph.get_tensor_by_name(self.INPUT_TENSOR)
        with tf1.Session(graph=self.graph, config=self.config) as sess:
            for question in questions:
                # The graph is fed a batch containing exactly one question.
                results[question] = sess.run(out_nodes, feed_dict={input_node: [question]})
        return results

    def answer(self, questions):
        """Return an ``OrderedDict`` of question -> decoded answer text.

        The answer is element 0 of the second fetched tensor, decoded as UTF-8.
        """
        return OrderedDict(
            (question, outputs[1][0].decode('utf8'))
            for question, outputs in self.predict(questions).items()
        )


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        Namespace with ``model_path``, ``question_file`` and ``questions``.
    """
    parser = argparse.ArgumentParser(description='t5 runner')
    parser.add_argument('-m', '--model', '--model-path', dest='model_path', type=str, default='model/t5.pb',
                        help='model路径')
    parser.add_argument('-f', '--file', dest='question_file', type=str, default='',
                        help='问题所在文本文件,以\\n分隔,如果此参数有值,则优先使用此参数')
    parser.add_argument('questions', nargs='*', default=[], help='问题列表')
    return parser.parse_args(argv)


def load_questions(question_file, questions):
    """Pick the questions to answer: file contents first, positionals otherwise.

    Args:
        question_file: Path of a UTF-8 text file with one question per line, or
            an empty string when no file was given.
        questions: Questions passed directly on the command line.

    Returns:
        List of question strings. Blank and whitespace-only lines of the file
        are skipped and surrounding whitespace is removed.
    """
    if question_file:
        if os.path.isfile(question_file):
            with open(question_file, encoding='utf8') as f:
                return [line.strip() for line in f if line.strip()]
        # Keep the original fallback to the positional questions, but do not
        # hide the fact that the requested file could not be used.
        import sys
        print(f'警告: 问题文件不存在: {question_file}', file=sys.stderr)
    return list(questions)


def main():
    """Script entry point: answer every question and print the results."""
    args = parse_args()
    questions = load_questions(args.question_file, args.questions)

    # Explicit exit instead of ``assert`` so the check survives ``python -O``.
    if not questions:
        raise SystemExit('请提供问题')

    t5 = T5(args.model_path, config=get_config())
    result_dict = t5.answer(questions)
    print('-' * 50)
    for question, answer in result_dict.items():
        print(f"{question}:\n {answer}")
    print('=' * 50)


if __name__ == '__main__':
    main()
0000000000000000000000000000000000000000..8f735baaa3aabb753923dcc3c03ddcc935578e8f --- /dev/null +++ b/TensorFlow2/built-in/nlp/TextToTextTransferTransformer_ID4150_for_TensorFlow2.X/readme.md @@ -0,0 +1,74 @@ +中文|[英文](readme_en.md) + +# TextToTextTransferTransformer TensorFlow 2.x 在线推理 +> 此链接提供 TextToTextTransferTransformer TensorFlow 2.x pb模型在NPU上在线推理的脚本和方法 + +# 注意 +> 此案例仅为您学习Ascend软件栈提供学习参考,不用于商业目的。 + +在开始之前,请注意以下适配条件。如果不匹配,可能导致运行失败 + +|依赖|要求| +|---|---| +|CANN 版本|>=6.0.0| +|芯片平台|Ascend310/Ascend310P3| +|第三方依赖|请参考 requirements.txt| + +# 使用步骤 + +## 1. 拷贝代码 + +```shell +git clone https://gitee.com/ascend/ModelZoo-TensorFlow.git +cd ModelZoo-TensorFlow/TensorFlow2/built-in/nlp/TextToTextTransferTransformer_ID4150_for_TensorFlow2.X +``` + +## 2. 下载pb模型 +1. [下载pb模型](https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/model/2022-12-12_tf/t5_tf2_online_inference/t5.pb) +2. 将pb模型放置在任意位置,建议路径为 ./model/t5.pb +3. 目录结构如下 +``` +TextToTextTransferTransformer_ID4150_for_TensorFlow2.X +|---model +|---|---t5.pb +|---main.py +|---questions.txt +|---readme.md +|---readme_en.md +|---requirements.txt +``` + +## 3. 在线推理 + +> 请准备好输入数据,即文本问题,可以是txt文件,也可以是字符串,以下两种方法皆可 + +```shell +python3 main.py -f questions.txt +python3 main.py 'question1' 'question2' +``` + +## 推理结果 +本结果是通过运行上边适配的推理脚本获得的。 + +1. gpu 结果 + +|问题|结果| +|---|---| +|nq question: where is google's headquarters|in Columbus, Ohio| +|nq question: what is the most populous country in the world|China| +|nq question: name a member of the beatles|Harrison| +|nq question: how many teeth do humans have|20 primary| + +2. 
npu 结果 + +|问题|结果| +|---|---| +|nq question: where is google's headquarters|in Columbus, Ohio| +|nq question: what is the most populous country in the world|China| +|nq question: name a member of the beatles|Harrison| +|nq question: how many teeth do humans have|20 primary| + + +## 参考 +[谷歌t5网络 github地址](https://github.com/google-research/text-to-text-transfer-transformer) + diff --git a/TensorFlow2/built-in/nlp/TextToTextTransferTransformer_ID4150_for_TensorFlow2.X/readme_en.md b/TensorFlow2/built-in/nlp/TextToTextTransferTransformer_ID4150_for_TensorFlow2.X/readme_en.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TensorFlow2/built-in/nlp/TextToTextTransferTransformer_ID4150_for_TensorFlow2.X/requirements.txt b/TensorFlow2/built-in/nlp/TextToTextTransferTransformer_ID4150_for_TensorFlow2.X/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..81743e7f5bdb567b6feaa8640775117335e696c7 --- /dev/null +++ b/TensorFlow2/built-in/nlp/TextToTextTransferTransformer_ID4150_for_TensorFlow2.X/requirements.txt @@ -0,0 +1,3 @@ +tensorflow>=2.6.0 +tensorflow-text>=2.6.0 +# npu-device: CANN平台提供 \ No newline at end of file