From 4ea04e2e00a96947d9bbdee60f39294efb556a89 Mon Sep 17 00:00:00 2001 From: simson <526422051@qq.com> Date: Thu, 16 Jul 2020 14:53:12 +0800 Subject: [PATCH] modify the tutorial example: GradOperation only support tensor --- .../source_en/advanced_use/debugging_in_pynative_mode.md | 4 ++-- .../source_zh_cn/advanced_use/debugging_in_pynative_mode.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tutorials/source_en/advanced_use/debugging_in_pynative_mode.md b/tutorials/source_en/advanced_use/debugging_in_pynative_mode.md index 1142684a4e..e1a0ca07a2 100644 --- a/tutorials/source_en/advanced_use/debugging_in_pynative_mode.md +++ b/tutorials/source_en/advanced_use/debugging_in_pynative_mode.md @@ -243,7 +243,7 @@ print(z.asnumpy()) ## Debugging Network Train Model -In PyNative mode, the gradient can be calculated separately. As shown in the following example, `GradOperation` is used to calculate all input gradients of the function or the network. +In PyNative mode, the gradient can be calculated separately. As shown in the following example, `GradOperation` is used to calculate all input gradients of the function or the network. Note that the inputs have to be Tensor. **Example Code** @@ -259,7 +259,7 @@ def mul(x, y): def mainf(x, y): return C.GradOperation('get_all', get_all=True)(mul)(x, y) -print(mainf(1,2)) +print(mainf(Tensor(1, mstype.int32), Tensor(2, mstype.int32))) ``` **Output** diff --git a/tutorials/source_zh_cn/advanced_use/debugging_in_pynative_mode.md b/tutorials/source_zh_cn/advanced_use/debugging_in_pynative_mode.md index 4ab6293123..ba3ea2089a 100644 --- a/tutorials/source_zh_cn/advanced_use/debugging_in_pynative_mode.md +++ b/tutorials/source_zh_cn/advanced_use/debugging_in_pynative_mode.md @@ -245,7 +245,7 @@ print(z.asnumpy()) ## 调试网络训练模型 -PyNative模式下,还可以支持单独求梯度的操作。如下例所示,可通过`GradOperation`求该函数或者网络所有的输入梯度。 +PyNative模式下,还可以支持单独求梯度的操作。如下例所示,可通过`GradOperation`求该函数或者网络所有的输入梯度。需要注意,输入类型仅支持Tensor。 **示例代码** @@ -261,7 +261,7 @@ def mul(x, y): def mainf(x, y): return C.GradOperation('get_all', get_all=True)(mul)(x, y) -print(mainf(1,2)) +print(mainf(Tensor(1, mstype.int32), Tensor(2, mstype.int32))) ``` **输出** -- Gitee