diff --git a/docs/lite/api/source_en/api_cpp/mindspore.rst b/docs/lite/api/source_en/api_cpp/mindspore.rst index e760aac4ae98ddd16a62a22fdd76ea2f38af9951..039763dfe51e12b05875ad99a4cfe8f26cd8e672 100644 --- a/docs/lite/api/source_en/api_cpp/mindspore.rst +++ b/docs/lite/api/source_en/api_cpp/mindspore.rst @@ -131,6 +131,10 @@ Classes - :doc:`../generate/classmindspore_DelegateModel` +- :doc:`../generate/classmindspore_AbstractDelegate` + +- :doc:`../generate/classmindspore_IDelegate` + - :doc:`../generate/classmindspore_DepComputer` - :doc:`../generate/classmindspore_DeviceEvent` diff --git a/docs/lite/api/source_zh_cn/api_cpp/mindspore.md b/docs/lite/api/source_zh_cn/api_cpp/mindspore.md index 3ee054c892d91437fbf025620cbe2f8d1227cbc5..4ddfd3e47d96b4c8a66adaeddf14c8a653350873 100644 --- a/docs/lite/api/source_zh_cn/api_cpp/mindspore.md +++ b/docs/lite/api/source_zh_cn/api_cpp/mindspore.md @@ -107,6 +107,8 @@ | [Cell](#cell) | 容器类。 | ✕ | √ | | [GraphCell](#graphcell) | 图容器类。 | ✕ | √ | | [Graph](#graph) | 图类。 | ✕ | √ | +| [AbstractDelegate](#abstractdelegate) | MindSpore Lite接入代理(抽象类)。 | √ | ✕ | +| [IDelegate](#idelegate) | MindSpore Lite接入代理(模板类)。 | √ | ✕ | ## Context @@ -915,6 +917,7 @@ Model() | [bool GetTrainMode() const](#gettrainmode) | ✕ | √ | | [Status Train(int epochs, std::shared_ptr< dataset::Dataset> ds, std::vector cbs)](#train) | ✕ | √ | | [Status Evaluate(std::shared_ptr< dataset::Dataset> ds, std::vector cbs)](#evaluate) | ✕ | √ | +| [Status Finalize()](#finalize) | √ | √ | #### Build @@ -1589,6 +1592,18 @@ Status UpdateWeights(const std::vector &new_weights) 状态码。 +#### Finalize + +```cpp +Status Finalize(); +``` + +模型终止。 + +- 返回值 + + 状态码。 + ## MSTensor \#include <[types.h](https://gitee.com/mindspore/mindspore/blob/v2.7.0-rc1/include/api/types.h)> @@ -2233,6 +2248,38 @@ Delegate在线构图。 状态码类`Status`对象,可以使用其公有函数`StatusCode`或`ToString`函数来获取具体错误码及错误信息。 +#### CreateKernel + +```cpp +std::shared_ptr CreateKernel(const std::shared_ptr &node) override; +``` + +创建Kernel。 + +- 参数 + + - `node`: 指向Kernel[Kernel]实例的共享指针。 + +- 返回值 + + Kernel类共享指针。 + +#### IsDelegateNode + +```cpp +bool IsDelegateNode(const std::shared_ptr &node) override { return false; } +``` + +是否是Delegate节点。 + +- 参数 + + - `node`: 指向Kernel[Kernel]实例的共享指针。 + +- 返回值 + + bool值。 + ## CoreMLDelegate \#include <[delegate.h](https://gitee.com/mindspore/mindspore/blob/v2.7.0-rc1/include/api/delegate.h)> @@ -2457,6 +2504,104 @@ const SchemaVersion GetVersion() { return version_; } **enum**值,0: r1.2及r1.2之后的版本,1: r1.1及r1.1之前的版本,-1: 无效版本。 +## AbstractDelegate + +\#include <[delegate.h](https://gitee.com/mindspore/mindspore/blob/v2.7.0-rc1/include/api/delegate.h)> + +`AbstractDelegate`定义了MindSpore Lite 创建Delegate(抽象类)。 + +### 构造函数 + +```cpp +AbstractDelegate(); +AbstractDelegate(const std::vector &inputs, const std::vector &outputs) + : inputs_(inputs), outputs_(outputs) {} +``` + +### 析构函数 + +```cpp +virtual ~AbstractDelegate() = default; +``` + +### 公有成员函数 + +#### inputs + +```cpp +std::vector &inputs() { return this->inputs_; } +``` + +返回AbstractDelegate的inputTensor。 + +#### outputs + +```cpp +const std::vector &outputs() { return this->outputs_; } +``` + +返回AbstractDelegate的outputTensor。 + +### 保护成员变量 + +#### inputs_ + +```cpp +std::vector inputs_; +``` + +#### outputs_ + +## IDelegate + +```cpp +std::vector outputs_; +``` + +\#include <[delegate.h](https://gitee.com/mindspore/mindspore/blob/v2.7.0-rc1/include/api/delegate.h)> + +`IDelegate`定义了MindSpore Lite 创建Delegate(模板类)。 + +### 构造函数 + +```cpp +IDelegate(); +IDelegate(const std::vector &inputs, const std::vector &outputs) + : AbstractDelegate(inputs, outputs) {} +``` + +### 析构函数 + +```cpp +virtual ~IDelegate() = default; +``` + +### 公有成员函数 + +#### ReplaceNodes + +```cpp +virtual void ReplaceNodes(const std::shared_ptr &graph) = 0; +``` + +替换Delegate节点。 + +#### IsDelegateNode + +```cpp +virtual bool IsDelegateNode(const std::shared_ptr &node) = 0; +``` + +判断节点是否属于Delegate。 + +#### CreateKernel + +```cpp +virtual std::shared_ptr CreateKernel(const std::shared_ptr &node) = 0; +``` + +创建Kernel。 + ## TrainCfg \#include <[cfg.h](https://gitee.com/mindspore/mindspore/blob/v2.7.0-rc1/include/api/cfg.h)> @@ -3509,6 +3654,54 @@ static inline std::string CodeAsString(enum StatusCode c) 指向副本的指针。 +#### Construct + +```cpp +virtual std::vector Construct(const std::vector &inputs) { return {}; } +``` + +构造一份CellBase。 + +- 参数 + + Input组成的vector。 + +- 返回值 + + Output组成的vector。 + +#### Run + +```cpp +virtual Status Run(const std::vector &inputs, std::vector *outputs) { return kSuccess; } +``` + +运行CellBase。 + +- 参数 + + Input组成的vector,Output组成的vector。 + +- 返回值 + + 状态码。 + +#### operator() + +```cpp +std::vector operator()(const std::vector &inputs) const; +``` + +括号运行符。 + +- 参数 + + Input组成的vector。 + +- 返回值 + + Output组成的vector。 + ## GraphCell \#include <[cell.h](https://gitee.com/mindspore/mindspore/blob/v2.7.0-rc1/include/api/cell.h)> @@ -3537,6 +3730,74 @@ static inline std::string CodeAsString(enum StatusCode c) 指向Graph的指针。 +#### SetContext + +```cpp +void SetContext(const std::shared_ptr &context); +``` + +设置Context。 + +- 参数 + + 指向Context[Context]实例的共享指针。 + +#### Run + +```cpp +Status Run(const std::vector &inputs, std::vector *outputs) override; +``` + +运行。 + +- 参数 + + [MSTensor]构成的inputs, outputs vector。 + +- 返回值 + + 状态码。 + +#### GetInputs + +```cpp +std::vector GetInputs(); +``` + +获取输入。 + +- 返回值 + + [MSTensor]构成的vector。 + +#### GetOutputs + +```cpp +std::vector GetOutputs(); +``` + +获取输出。 + +- 返回值 + + [MSTensor]构成的vector。 + +#### Load + +```cpp +Status Load(uint32_t device_id); +``` + +加载。 + +- 参数 + + device_id(芯片编号) + +- 输出 + + 状态码。 + ## RunnerConfig \#include <[model_parallel_runner.h](https://gitee.com/mindspore/mindspore/blob/v2.7.0-rc1/include/api/model_parallel_runner.h)>