From 8832cfd29272f80d22c32f1c754770a43c3c4969 Mon Sep 17 00:00:00 2001 From: guohuanliang Date: Wed, 22 Nov 2023 12:04:42 +0800 Subject: [PATCH 1/2] torch aie add gpt2 model --- .../built-in/nlp/GPT2_Chinese/LICENSE | 29 ++++ .../built-in/nlp/GPT2_Chinese/bin2pth.py | 70 ++++++++ .../built-in/nlp/GPT2_Chinese/compare_loss.py | 129 ++++++++++++++ .../built-in/nlp/GPT2_Chinese/pre_data.py | 112 ++++++++++++ .../built-in/nlp/GPT2_Chinese/readme.md | 161 ++++++++++++++++++ .../built-in/nlp/GPT2_Chinese/requirement.txt | 3 + 6 files changed, 504 insertions(+) create mode 100644 AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/LICENSE create mode 100644 AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/bin2pth.py create mode 100644 AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/compare_loss.py create mode 100644 AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/pre_data.py create mode 100644 AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/readme.md create mode 100644 AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/requirement.txt diff --git a/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/LICENSE b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/LICENSE new file mode 100644 index 0000000000..09d493bf1f --- /dev/null +++ b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2017, +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/bin2pth.py b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/bin2pth.py new file mode 100644 index 0000000000..3ac37d0796 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/bin2pth.py @@ -0,0 +1,70 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import transformers +import torch +import torch_aie +from torch_aie import _enums +import numpy as np +import argparse + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_config', default='./config.json', type=str, required=False, + help='model config json path') + parser.add_argument('--pretrained_model', default='./model/pytorch_model.bin', + type=str, required=False, help='model checkpoint path') + parser.add_argument('--batch_size', default=1, type=int, + required=False, help='batch size') + parser.add_argument('--device', default=0, type=int, + required=False, help='npu device') + parser.add_argument('--optimization_level', default=0, type=int, + required=False, help='optimization_level') + + args = parser.parse_args() + device = args.device + batch_size = args.batch_size + optimization_level = args.optimization_level + model_config = transformers.modeling_gpt2.GPT2Config.from_json_file( + args.model_config) + model = transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained( + args.pretrained_model, config=model_config) + model.eval() + aie_model_path = "gpt2_bs" + str(batch_size) + ".pth" + + torch_aie.set_device(device) + + accept_size = [batch_size, 512] + dummy_input = torch.ones(accept_size).long() + with torch.inference_mode(): + jit_model = torch.jit.trace(model, dummy_input) + aie_input_spec = [torch_aie.Input( + accept_size, dtype=torch_aie.dtype.INT64),] + aie_model = torch_aie.compile( + jit_model, + inputs=aie_input_spec, + precision_policy=_enums.PrecisionPolicy.FP16, + truncate_long_and_double=True, + require_full_compilation=False, + allow_tensor_replace_int=False, + min_block_size=3, + torch_executed_ops=[], + soc_version="Ascend310P3", + optimization_level=optimization_level) + aie_model.save(aie_model_path) + +if __name__ == '__main__': + main() diff --git a/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/compare_loss.py b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/compare_loss.py new file mode 100644 index 0000000000..c8d61ebc00 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/compare_loss.py @@ -0,0 +1,129 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import random +import transformers +import torch +import torch_npu +import torch_aie +import numpy as np +import argparse +import time +from torch.nn import CrossEntropyLoss + + +random.seed(0) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--device', default=0, type=int, + required=False, help='npu device id') + parser.add_argument('--tokenized_data_path', default='data/tokenized_eval/', type=str, required=False, + help='tokenized语料存放位置') + parser.add_argument('--batch_size', default=1, type=int, + required=False, help='batch size') + parser.add_argument('--log_step', default=100, type=int, + required=False, help='多少步汇报一次') + parser.add_argument('--n_ctx', default=512, type=int, + required=False, help='文字长度') + parser.add_argument('--stride', default=768, type=int, + required=False, help='取数据的窗口步长') + parser.add_argument('--num_pieces', default=100, + type=int, required=False, help='将训练语料分成多少份') + + args = parser.parse_args() + device_id = args.device + tokenized_data_path = args.tokenized_data_path + batch_size = args.batch_size + log_step = args.log_step + stride = args.stride + num_pieces = args.num_pieces + n_ctx = args.n_ctx + + torch.npu.set_device(device_id) + aie_model_path = "gpt2_bs" + str(batch_size) + ".pth" + if not os.path.exists(aie_model_path): + print('aie model path not exist!') + exit() + aie_model = torch.jit.load(aie_model_path).eval() + + total_loss = 0 + total_steps = 0 + # eval + piece_num = 0 + modeltime = [] + for i in range(num_pieces): + with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f: + line = f.read().strip() + tokens = line.split() + tokens = [int(token) for token in tokens] + start_point = 0 + samples = [] + while start_point < len(tokens) - n_ctx: + samples.append(tokens[start_point: start_point + n_ctx]) + start_point += stride + start_point -= stride + random.shuffle(samples) + for step in range(len(samples) // batch_size): # drop last + # prepare data + batch = samples[step * batch_size: (step + 1) * batch_size] + batch_labels = [] + batch_inputs = [] + for ids in batch: + int_ids_for_labels = [int(x) for x in ids] + int_ids_for_inputs = [int(x) for x in ids] + batch_labels.append(int_ids_for_labels) + batch_inputs.append(int_ids_for_inputs) + batch_labels = np.array(batch_labels).astype(np.int64) + batch_inputs = np.array(batch_labels).astype(np.int64) + # forward pass + with torch.inference_mode(): + inputs_npu = torch.from_numpy(batch_inputs).npu() + torch.npu.synchronize() + start = time.time() + output = aie_model(inputs_npu) + torch.npu.synchronize() + end = time.time() + modeltime.append(end - start) + lm_logits = output.cpu() + # get loss + shift_logits = lm_logits[..., :-1, :].contiguous().float() + labels = torch.from_numpy(batch_labels) + shift_labels = labels[..., 1:].contiguous() + + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + total_loss += loss + total_steps += 1 + + if total_steps % log_step == 0: + print('[INFO] Step {} of piece {}, ppl {}, step time {}.'.format( + (step + 1), + piece_num, + torch.exp(loss), + end - start)) + piece_num += 1 + + print("BatchSize = {}, QPS = {}.".format(batch_size, + batch_size * len(modeltime) / sum(modeltime))) + print("PPL = {}.".format(np.exp(total_loss.detach().numpy() / total_steps))) + + +if __name__ == '__main__': + main() diff --git a/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/pre_data.py b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/pre_data.py new file mode 100644 index 0000000000..e8b01620a4 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/pre_data.py @@ -0,0 +1,112 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json + +from tqdm import tqdm +import argparse +import transformers + +from tokenizations import tokenization_bert + + +def build_files(data_path, tokenized_data_path, num_pieces, full_tokenizer, min_length): + if not os.path.exists(tokenized_data_path): + os.mkdir(tokenized_data_path) + with open(data_path, 'r', encoding='utf8') as f: + print('reading lines') + lines = json.load(f) + lines = [line.replace('\n', ' [SEP] ') + for line in lines] # 用[SEP]表示换行, 段落之间使用SEP表示段落结束 + all_len = len(lines) + for i in tqdm(range(num_pieces)): + sublines = lines[all_len // num_pieces * + i: all_len // num_pieces * (i + 1)] + if i == num_pieces - 1: + # 把尾部例子添加到最后一个piece + sublines.extend(lines[all_len // num_pieces * (i + 1):]) + sublines = [full_tokenizer.tokenize(line) for line in sublines if + len(line) > min_length] # 只考虑长度超过min_length的句子 + sublines = [full_tokenizer.convert_tokens_to_ids( + line) for line in sublines] + full_line = [] + for subline in sublines: + full_line.append(full_tokenizer.convert_tokens_to_ids( + '[MASK]')) # 文章开头添加MASK表示文章开始 + full_line.extend(subline) + full_line.append(full_tokenizer.convert_tokens_to_ids( + '[CLS]')) # 文章之间添加CLS表示文章结束 + with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'w') as f: + for id in full_line: + f.write(str(id) + ' ') + print('finish') + + +def prepare_data(data_path, save_dir): + + data = [] + pre_path = os.listdir(data_path) + for mid_path in pre_path: + path_ = os.path.join(data_path, mid_path) + re_path = os.listdir(path_) + for pp in re_path: + p_ = os.path.join(path_, pp) + with open(p_, 'r', encoding='utf8') as f: + lines = f.readlines() + for line in lines: + data.append(json.loads(line)['text']) + break + with open(save_dir, 'w') as f: + json.dump(data, f) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_config', default='./config.json', type=str, required=False, + help='选择模型参数') + parser.add_argument('--raw_data_path', default='./data/wiki_zh', + type=str, required=False, help='原始语料') + parser.add_argument('--data_json_path', default='./eval.json', + type=str, required=False, help='原始语料') + parser.add_argument('--tokenized_data_path', default='data/tokenized_eval/', type=str, required=False, + help='tokenized语料存放位置') + parser.add_argument('--num_pieces', default=100, + type=int, required=False, help='将训练语料分成多少份') + parser.add_argument('--tokenizer_path', default='./vocab.txt', + type=str, required=False, help='选择词库') + parser.add_argument('--min_length', default=128, + type=int, required=False, help='最短收录文章长度') + + args = parser.parse_args() + + model_config = transformers.modeling_gpt2.GPT2Config.from_json_file( + args.model_config) + n_ctx = model_config.n_ctx + full_tokenizer = tokenization_bert.BertTokenizer( + vocab_file=args.tokenizer_path) + full_tokenizer.max_len = n_ctx + raw_data_path = args.raw_data_path + data_json_path = args.data_json_path + tokenized_data_path = args.tokenized_data_path + num_pieces = args.num_pieces + min_length = args.min_length + + prepare_data(raw_data_path, data_json_path) + build_files(data_path=data_json_path, tokenized_data_path=tokenized_data_path, num_pieces=num_pieces, + full_tokenizer=full_tokenizer, min_length=min_length) + + +if __name__ == '__main__': + main() diff --git a/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/readme.md b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/readme.md new file mode 100644 index 0000000000..ae3f61a9dc --- /dev/null +++ b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/readme.md @@ -0,0 +1,161 @@ +# GPT2 Chinese模型-推理指导 + + +- [概述](#ZH-CN_TOPIC_0000001172161501) + + - [输入输出数据](#section540883920406) + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [获取源码](#section4622531142816) + - [准备数据集](#section183221994411) + - [模型推理](#section741711594517) + +- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573) + + ****** + + + + + +# 概述 + +GPT-2 模型只使用了多个Masked Self-Attention和Feed Forward Neural Network,并且由多层单向Transformer的解码器构成,本质上是一个自回归模型。其中自回归的意思是指,每次产生新单词后,将新单词加到原输入句后面,作为新的输入句。而单向是指只会考虑在待预测词位置左侧的词对待预测词的影响。 + + +- 参考实现: + + ``` + url=https://github.com/Morizeyao/GPT2-Chinese + commit_id=bbb44651be8361faef35d2a857451d231b5ebe14 + ``` + +> 说明:所有脚本都在GPT2的仓下运行 + +## 输入输出数据 + +- 输入数据 + + | 输入数据 | 数据类型 | 大小 | 数据排布格式 | + | -------- | -------- | ------------------------- | ------------ | + | input_ids | int64 | batchsize x 512 | ND | + + +- 输出数据 + + | 输出数据 | 数据类型 | 大小 | 数据排布格式 | + | -------- | -------- | -------- | ------------ | + | output | FLOAT16 | batchsize x 512 x 21128 | ND | + + + + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ------------------------------------------------------------ | ------- | ------------------------------------------------------------ | + | 固件与驱动 | 23.0.RC3 | - | + | CANN | 7.0.RC1 | - | + | Python | 3.9.0 | - | + | PyTorch | 2.0.1 | - | + | AscendIE | 6.3.RC2 | - | + | Torch AIE | 6.3.RC2 | - | + + +# 快速上手 + +## 获取源码 + +1. 获取源码。 + + ``` + git clone https://github.com/Morizeyao/GPT2-Chinese + cd GPT2-Chinese + git reset --hard bbb44651be8361faef35d2a857451d231b5ebe14 + ``` + +2. 获取模型checkpoint文件和配置文件 + + 在模型根目录GPT2-Chinese下创建model文件夹。 + + 从[这里](https://pan.baidu.com/s/16x0hfBCekWju75xPeyyRfA#list/path=%2F)下载配置文件,提取码`n3s8`,并把`pytorch_model.bin`放到`model`文件夹下,vocab.txt和config.json文件放到模型根目录GPT2-Chinese下。 + +3. 将bin2pth.py、pre_data.py、compare_loss.py、requirement.txt拷贝到源码根目录下,并安装依赖 + + ``` + pip3 install -r requirement.txt + ``` + +4. 修改源码使模型仅返回lm_logits + 查看transformers的安装路径: + ``` + pip3 show transformers + ``` + 根据Location位置,修改源码第549行: + ``` + vim ${Location}/transformers/modeling_gpt2.py + ``` + + 改为: + ``` + return lm_logits + ``` + +## 准备数据集 + +1. 获取原始数据集。(解压命令参考tar –xvf \*.tar与 unzip \*.zip) +> 提示:请遵循数据集提供方要求使用。 + 本模型支持wiki_zh_2019验证集。用户需自行获取[数据集](https://pan.baidu.com/share/init?surl=22sax9QujO8SUdV3jH5mTQ),提取码`xv7e`。将解压后的数据放在data下,其目录结构如下: + + ``` + data + └── wiki_zh + ``` + +2. 数据预处理,将原始数据集转换为模型输入的数据。 + + ``` + python3 pre_data.py + ``` + 结果保存在`data/tokenized_eval` + +## 模型推理 + +1. 模型编译。 + + 使用torch aie将模型权重文件pytorch_model.bin转换为.pt文件。 + ``` + python3 bin2pth.py --batch_size=1 + ``` + 如果环境为第一次运行,可尝试使用aoe进行调优,参考如下: + ``` + python3 bin2pth.py --batch_size=1 --optimization_level=1 + python3 bin2pth.py --batch_size=1 --optimization_level=2 + ``` + +2. 开始推理验证。 + ``` + python3 compare_loss.py --batch_size=1 + ``` + +# 模型推理性能&精度 + +调用ACL接口推理计算,性能参考下列数据。 + +| 芯片型号 | Batch Size | 数据集 | 精度指标(Loss)| 性能 | +| :------: | :--------: | :----: | :--: | :--: | +| 310P3 | 1 | wiki_zh_2019 | 15.44 | 110 | +| 310P3 | 4 | wiki_zh_2019 | 15.44 | 99 | +| 310P3 | 8 | wiki_zh_2019 | 15.47 | 107 | +| 310P3 | 16 | wiki_zh_2019 | 15.50 | 89 | +| 310P3 | 32 | wiki_zh_2019 | | | +| 310P3 | 64 | wiki_zh_2019 | | | + +> 注:衡量精度的指标为验证集平均交叉熵损失(Cross-Entropy Loss),数值越低越好。 \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/requirement.txt b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/requirement.txt new file mode 100644 index 0000000000..71b0eb33f4 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/requirement.txt @@ -0,0 +1,3 @@ +transformers==2.1.1 +numpy==1.21.6 +tqdm \ No newline at end of file -- Gitee From 18afa2471a825367ddbf73a6446fd4c1bb9d0cf4 Mon Sep 17 00:00:00 2001 From: guohuanliang Date: Mon, 11 Dec 2023 22:26:18 +0800 Subject: [PATCH 2/2] prove gpt2 performance --- .../built-in/nlp/GPT2_Chinese/readme.md | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/readme.md b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/readme.md index ae3f61a9dc..8a83fa73b9 100644 --- a/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/readme.md +++ b/AscendIE/TorchAIE/built-in/nlp/GPT2_Chinese/readme.md @@ -105,7 +105,7 @@ GPT-2 模型只使用了多个Masked Self-Attention和Feed Forward Neural Networ 改为: ``` - return lm_logits + return lm_logits.to(torch.half) ``` ## 准备数据集 @@ -132,30 +132,33 @@ GPT-2 模型只使用了多个Masked Self-Attention和Feed Forward Neural Networ 使用torch aie将模型权重文件pytorch_model.bin转换为.pt文件。 ``` + export ASCENDIE_FASTER_MODE=1 python3 bin2pth.py --batch_size=1 ``` - 如果环境为第一次运行,可尝试使用aoe进行调优,参考如下: + + 其中设置ASCENDIE_FASTER_MODE=1,是为了使用FastGelu算子,提升性能。 + + 如果环境为第一次运行该模型或者实测性能与下方表格数据差距较大,可尝试使用aoe进行调优,参考如下: ``` python3 bin2pth.py --batch_size=1 --optimization_level=1 python3 bin2pth.py --batch_size=1 --optimization_level=2 - ``` + ``` 2. 开始推理验证。 ``` + export TORCH_AIE_NPU_CACHE_MAX_SIZE=8 python3 compare_loss.py --batch_size=1 ``` # 模型推理性能&精度 -调用ACL接口推理计算,性能参考下列数据。 - | 芯片型号 | Batch Size | 数据集 | 精度指标(Loss)| 性能 | | :------: | :--------: | :----: | :--: | :--: | -| 310P3 | 1 | wiki_zh_2019 | 15.44 | 110 | -| 310P3 | 4 | wiki_zh_2019 | 15.44 | 99 | -| 310P3 | 8 | wiki_zh_2019 | 15.47 | 107 | -| 310P3 | 16 | wiki_zh_2019 | 15.50 | 89 | -| 310P3 | 32 | wiki_zh_2019 | | | -| 310P3 | 64 | wiki_zh_2019 | | | +| 310P3 | 1 | wiki_zh_2019 | 15.6 | 125 | +| 310P3 | 4 | wiki_zh_2019 | 15.6 | 128 | +| 310P3 | 8 | wiki_zh_2019 | 15.7 | 116 | +| 310P3 | 16 | wiki_zh_2019 | 15.7 | 123 | +| 310P3 | 32 | wiki_zh_2019 | 15.7 | 119 | +| 310P3 | 64 | wiki_zh_2019 | 15.8 | 115 | > 注:衡量精度的指标为验证集平均交叉熵损失(Cross-Entropy Loss),数值越低越好。 \ No newline at end of file -- Gitee