From f048946f2d4f1622188eb19d240d9ef30d26b70f Mon Sep 17 00:00:00 2001 From: jackeyGG <1209567334@qq.com> Date: Wed, 26 Oct 2022 07:14:32 +0000 Subject: [PATCH] =?UTF-8?q?update=20fastNLP/core/tester.py.=20=E5=BC=80?= =?UTF-8?q?=E5=8F=91=E8=80=85=E5=A5=BD=EF=BC=8C=E6=88=91=E5=9C=A8=E5=9C=A8?= =?UTF-8?q?0.7.0=E4=BD=BF=E7=94=A8=E8=BF=87=E7=A8=8B=E4=B8=AD=E5=8F=91?= =?UTF-8?q?=E7=8E=B0=E4=BA=86=E4=B8=80=E4=BA=9B=E8=87=AA=E5=B7=B1=E5=86=99?= =?UTF-8?q?=E7=9A=84=E7=BD=91=E7=BB=9C=EF=BC=8C=E5=9C=A8=E6=A1=86=E6=9E=B6?= =?UTF-8?q?=E4=B8=AD=E6=98=AF=E6=B2=A1=E6=9C=89=E6=89=BE=E5=88=B0=E9=A2=84?= =?UTF-8?q?=E6=B5=8B=E5=87=BD=E6=95=B0=E7=9A=84=EF=BC=8C=E5=8F=AA=E5=8D=95?= =?UTF-8?q?=E7=8B=AC=E8=BF=94=E5=9B=9E=E9=A2=84=E6=B5=8B=E7=BB=93=E6=9E=9C?= =?UTF-8?q?=E3=80=82=E5=9B=A0=E6=AD=A4=E5=9C=A8tester=E5=88=86=E6=94=AF?= =?UTF-8?q?=E4=B8=AD=E5=8A=A0=E5=85=A5=E4=BA=86=E4=B8=80=E4=B8=AA=E5=8F=AA?= =?UTF-8?q?=E8=BF=94=E5=9B=9E=E9=A2=84=E6=B5=8B=E7=BB=93=E6=9E=9C=E7=9A=84?= =?UTF-8?q?=E5=87=BD=E6=95=B0=EF=BC=8Cflp=5Ftopredict=EF=BC=8C=E5=B9=B6?= =?UTF-8?q?=E8=80=83=E8=99=91=E5=88=B0=E4=BA=86=E4=B8=8D=E5=90=8Cdivice?= =?UTF-8?q?=E8=BD=AC=E6=8D=A2=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/tester.py | 49 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index cb05f82d..d7284dba 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -248,3 +248,52 @@ class Tester(object): _str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()]) _str += '\n' return _str[:-1] + + def flp_topredict(self): + r"""开始进行预测,并返回预测结果。 + + :return 本次的预测结果,为一个字典,其中只有{predict}一个key,而key的值类型为tensor。 + """ + # turn on the testing mode; clean up the history + self._model_device = _get_model_device(self._model) + network = self._model + self._mode(network, is_test=True) + data_iterator = self.data_iterator + eval_results = [] + try: + with torch.no_grad(): + if not self.use_tqdm: + from .utils import _pseudo_tqdm as inner_tqdm + else: + inner_tqdm = tqdm + with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar: + pbar.set_description_str(desc="Pred") + + start_time = time.time() + + for batch_x, batch_y in data_iterator: + _move_dict_value_to_device(batch_x, batch_y, device=self._model_device, + non_blocking=self.pin_memory) + with self.auto_cast(): + pred_dict = self._data_forward(self._predict_func, batch_x) + + eval_results.extend(pred_dict['predict'].detach().cpu().numpy()) + + if self.use_tqdm: + pbar.update() + + pbar.close() + end_time = time.time() + test_str = f'Predict data in {round(end_time - start_time, 2)} seconds!' + if self.verbose >= 0: + self.logger.info(test_str) + except _CheckError as e: + prev_func_signature = _get_func_signature(self._predict_func) + _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, + check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, + dataset=self.data, check_level=0) + finally: + self._mode(network, is_test=False) + print(f'预测完成') + + return eval_results \ No newline at end of file -- Gitee