diff --git a/CONTRIBUTING_DOC.md b/CONTRIBUTING_DOC.md
index 8dd5a1a04af234257266de4f32a9d615909ca7d7..9b5d81614c654701a45af80966010ff140efbfb8 100644
--- a/CONTRIBUTING_DOC.md
+++ b/CONTRIBUTING_DOC.md
@@ -91,7 +91,7 @@ If you want to update an existing API, find the source file of the A
 
 If you do not know the file link, click **source** and find the file link by referring to the content following `_modules` in the link.
 
-Take Tensor as an example. After clicking **source**, you can obtain the link . Then, the source file link is .
+Take Tensor as an example. After clicking **source**, you can obtain the source file link .
 
 ![API Source](./resource/_static/api_source.png)
 
diff --git a/CONTRIBUTING_DOC_CN.md b/CONTRIBUTING_DOC_CN.md
index 018d6cd80f33386c9689beb1ad8cb24b068aefaf..cd0ae0e36a229027b8effd6725ed340a3ece42a2 100644
--- a/CONTRIBUTING_DOC_CN.md
+++ b/CONTRIBUTING_DOC_CN.md
@@ -91,7 +91,7 @@ MindSpore docs仓提供了[API注释写作要求](https://gitee.com/mindspore/do
 
 如果不清楚所在文件，可点击“source”，并参考跳转的链接地址中`_modules`后的内容，找到该文件。
 
-以Tensor为例，点击“source”后得到地址，源文件地址即为。
+以Tensor为例，点击“source”后得到源文件地址为。
 
 ![API Source](./resource/_static/api_source.png)
 
diff --git a/LICENSE-CC-BY-4.0 b/LICENSE-CC-BY-4.0
deleted file mode 100644
index 2f244ac814036ecd9ba9f69782e89ce6b1dca9eb..0000000000000000000000000000000000000000
--- a/LICENSE-CC-BY-4.0
+++ /dev/null
@@ -1,395 +0,0 @@
-Attribution 4.0 International
-
-=======================================================================
-
-Creative Commons Corporation ("Creative Commons") is not a law firm and
-does not provide legal services or legal advice. Distribution of
-Creative Commons public licenses does not create a lawyer-client or
-other relationship. Creative Commons makes its licenses and related
-information available on an "as-is" basis. Creative Commons gives no
-warranties regarding its licenses, any material licensed under their
-terms and conditions, or any related information. Creative Commons
-disclaims all liability for damages resulting from their use to the
-fullest extent possible.
-
-Using Creative Commons Public Licenses
-
-Creative Commons public licenses provide a standard set of terms and
-conditions that creators and other rights holders may use to share
-original works of authorship and other material subject to copyright
-and certain other rights specified in the public license below. The
-following considerations are for informational purposes only, are not
-exhaustive, and do not form part of our licenses.
-
-     Considerations for licensors: Our public licenses are
-     intended for use by those authorized to give the public
-     permission to use material in ways otherwise restricted by
-     copyright and certain other rights. Our licenses are
-     irrevocable. Licensors should read and understand the terms
-     and conditions of the license they choose before applying it.
-     Licensors should also secure all rights necessary before
-     applying our licenses so that the public can reuse the
-     material as expected. Licensors should clearly mark any
-     material not subject to the license. This includes other CC-
-     licensed material, or material used under an exception or
-     limitation to copyright. More considerations for licensors:
-     wiki.creativecommons.org/Considerations_for_licensors
-
-     Considerations for the public: By using one of our public
-     licenses, a licensor grants the public permission to use the
-     licensed material under specified terms and conditions. If
-     the licensor's permission is not necessary for any reason--for
-     example, because of any applicable exception or limitation to
-     copyright--then that use is not regulated by the license. Our
-     licenses grant only permissions under copyright and certain
-     other rights that a licensor has authority to grant. Use of
-     the licensed material may still be restricted for other
-     reasons, including because others have copyright or other
-     rights in the material. A licensor may make special requests,
-     such as asking that all changes be marked or described.
-     Although not required by our licenses, you are encouraged to
-     respect those requests where reasonable. More_considerations
-     for the public:
-     wiki.creativecommons.org/Considerations_for_licensees
-
-=======================================================================
-
-Creative Commons Attribution 4.0 International Public License
-
-By exercising the Licensed Rights (defined below), You accept and agree
-to be bound by the terms and conditions of this Creative Commons
-Attribution 4.0 International Public License ("Public License"). To the
-extent this Public License may be interpreted as a contract, You are
-granted the Licensed Rights in consideration of Your acceptance of
-these terms and conditions, and the Licensor grants You such rights in
-consideration of benefits the Licensor receives from making the
-Licensed Material available under these terms and conditions.
-
-
-Section 1 -- Definitions.
-
-  a. Adapted Material means material subject to Copyright and Similar
-     Rights that is derived from or based upon the Licensed Material
-     and in which the Licensed Material is translated, altered,
-     arranged, transformed, or otherwise modified in a manner requiring
-     permission under the Copyright and Similar Rights held by the
-     Licensor. For purposes of this Public License, where the Licensed
-     Material is a musical work, performance, or sound recording,
-     Adapted Material is always produced where the Licensed Material is
-     synched in timed relation with a moving image.
-
-  b. Adapter's License means the license You apply to Your Copyright
-     and Similar Rights in Your contributions to Adapted Material in
-     accordance with the terms and conditions of this Public License.
-
-  c. Copyright and Similar Rights means copyright and/or similar rights
-     closely related to copyright including, without limitation,
-     performance, broadcast, sound recording, and Sui Generis Database
-     Rights, without regard to how the rights are labeled or
-     categorized. For purposes of this Public License, the rights
-     specified in Section 2(b)(1)-(2) are not Copyright and Similar
-     Rights.
-
-  d. Effective Technological Measures means those measures that, in the
-     absence of proper authority, may not be circumvented under laws
-     fulfilling obligations under Article 11 of the WIPO Copyright
-     Treaty adopted on December 20, 1996, and/or similar international
-     agreements.
-
-  e. Exceptions and Limitations means fair use, fair dealing, and/or
-     any other exception or limitation to Copyright and Similar Rights
-     that applies to Your use of the Licensed Material.
-
-  f. Licensed Material means the artistic or literary work, database,
-     or other material to which the Licensor applied this Public
-     License.
-
-  g. Licensed Rights means the rights granted to You subject to the
-     terms and conditions of this Public License, which are limited to
-     all Copyright and Similar Rights that apply to Your use of the
-     Licensed Material and that the Licensor has authority to license.
-
-  h. Licensor means the individual(s) or entity(ies) granting rights
-     under this Public License.
-
-  i. Share means to provide material to the public by any means or
-     process that requires permission under the Licensed Rights, such
-     as reproduction, public display, public performance, distribution,
-     dissemination, communication, or importation, and to make material
-     available to the public including in ways that members of the
-     public may access the material from a place and at a time
-     individually chosen by them.
-
-  j. Sui Generis Database Rights means rights other than copyright
-     resulting from Directive 96/9/EC of the European Parliament and of
-     the Council of 11 March 1996 on the legal protection of databases,
-     as amended and/or succeeded, as well as other essentially
-     equivalent rights anywhere in the world.
-
-  k. You means the individual or entity exercising the Licensed Rights
-     under this Public License. Your has a corresponding meaning.
-
-
-Section 2 -- Scope.
-
-  a. License grant.
-
-       1. Subject to the terms and conditions of this Public License,
-          the Licensor hereby grants You a worldwide, royalty-free,
-          non-sublicensable, non-exclusive, irrevocable license to
-          exercise the Licensed Rights in the Licensed Material to:
-
-            a. reproduce and Share the Licensed Material, in whole or
-               in part; and
-
-            b. produce, reproduce, and Share Adapted Material.
-
-       2. Exceptions and Limitations. For the avoidance of doubt, where
-          Exceptions and Limitations apply to Your use, this Public
-          License does not apply, and You do not need to comply with
-          its terms and conditions.
-
-       3. Term. The term of this Public License is specified in Section
-          6(a).
-
-       4. Media and formats; technical modifications allowed. The
-          Licensor authorizes You to exercise the Licensed Rights in
-          all media and formats whether now known or hereafter created,
-          and to make technical modifications necessary to do so. The
-          Licensor waives and/or agrees not to assert any right or
-          authority to forbid You from making technical modifications
-          necessary to exercise the Licensed Rights, including
-          technical modifications necessary to circumvent Effective
-          Technological Measures. For purposes of this Public License,
-          simply making modifications authorized by this Section 2(a)
-          (4) never produces Adapted Material.
-
-       5. Downstream recipients.
-
-            a. Offer from the Licensor -- Licensed Material. Every
-               recipient of the Licensed Material automatically
-               receives an offer from the Licensor to exercise the
-               Licensed Rights under the terms and conditions of this
-               Public License.
-
-            b. No downstream restrictions. You may not offer or impose
-               any additional or different terms or conditions on, or
-               apply any Effective Technological Measures to, the
-               Licensed Material if doing so restricts exercise of the
-               Licensed Rights by any recipient of the Licensed
-               Material.
-
-       6. No endorsement. Nothing in this Public License constitutes or
-          may be construed as permission to assert or imply that You
-          are, or that Your use of the Licensed Material is, connected
-          with, or sponsored, endorsed, or granted official status by,
-          the Licensor or others designated to receive attribution as
-          provided in Section 3(a)(1)(A)(i).
-
-  b. Other rights.
-
-       1. Moral rights, such as the right of integrity, are not
-          licensed under this Public License, nor are publicity,
-          privacy, and/or other similar personality rights; however, to
-          the extent possible, the Licensor waives and/or agrees not to
-          assert any such rights held by the Licensor to the limited
-          extent necessary to allow You to exercise the Licensed
-          Rights, but not otherwise.
-
-       2. Patent and trademark rights are not licensed under this
-          Public License.
-
-       3. To the extent possible, the Licensor waives any right to
-          collect royalties from You for the exercise of the Licensed
-          Rights, whether directly or through a collecting society
-          under any voluntary or waivable statutory or compulsory
-          licensing scheme. In all other cases the Licensor expressly
-          reserves any right to collect such royalties.
-
-
-Section 3 -- License Conditions.
-
-Your exercise of the Licensed Rights is expressly made subject to the
-following conditions.
-
-  a. Attribution.
-
-       1. If You Share the Licensed Material (including in modified
-          form), You must:
-
-            a. retain the following if it is supplied by the Licensor
-               with the Licensed Material:
-
-                 i. identification of the creator(s) of the Licensed
-                    Material and any others designated to receive
-                    attribution, in any reasonable manner requested by
-                    the Licensor (including by pseudonym if
-                    designated);
-
-                ii. a copyright notice;
-
-               iii. a notice that refers to this Public License;
-
-                iv. a notice that refers to the disclaimer of
-                    warranties;
-
-                 v. a URI or hyperlink to the Licensed Material to the
-                    extent reasonably practicable;
-
-            b. indicate if You modified the Licensed Material and
-               retain an indication of any previous modifications; and
-
-            c. indicate the Licensed Material is licensed under this
-               Public License, and include the text of, or the URI or
-               hyperlink to, this Public License.
-
-       2. You may satisfy the conditions in Section 3(a)(1) in any
-          reasonable manner based on the medium, means, and context in
-          which You Share the Licensed Material. For example, it may be
-          reasonable to satisfy the conditions by providing a URI or
-          hyperlink to a resource that includes the required
-          information.
-
-       3. If requested by the Licensor, You must remove any of the
-          information required by Section 3(a)(1)(A) to the extent
-          reasonably practicable.
-
-       4. If You Share Adapted Material You produce, the Adapter's
-          License You apply must not prevent recipients of the Adapted
-          Material from complying with this Public License.
-
-
-Section 4 -- Sui Generis Database Rights.
-
-Where the Licensed Rights include Sui Generis Database Rights that
-apply to Your use of the Licensed Material:
-
-  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
-     to extract, reuse, reproduce, and Share all or a substantial
-     portion of the contents of the database;
-
-  b. if You include all or a substantial portion of the database
-     contents in a database in which You have Sui Generis Database
-     Rights, then the database in which You have Sui Generis Database
-     Rights (but not its individual contents) is Adapted Material; and
-
-  c. You must comply with the conditions in Section 3(a) if You Share
-     all or a substantial portion of the contents of the database.
-
-For the avoidance of doubt, this Section 4 supplements and does not
-replace Your obligations under this Public License where the Licensed
-Rights include other Copyright and Similar Rights.
-
-
-Section 5 -- Disclaimer of Warranties and Limitation of Liability.
-
-  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
-     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
-     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
-     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
-     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
-     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
-     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
-     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
-     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
-     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
-
-  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
-     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
-     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
-     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
-     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
-     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
-     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
-     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
-     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
-
-  c. The disclaimer of warranties and limitation of liability provided
-     above shall be interpreted in a manner that, to the extent
-     possible, most closely approximates an absolute disclaimer and
-     waiver of all liability.
-
-
-Section 6 -- Term and Termination.
-
-  a. This Public License applies for the term of the Copyright and
-     Similar Rights licensed here. However, if You fail to comply with
-     this Public License, then Your rights under this Public License
-     terminate automatically.
-
-  b. Where Your right to use the Licensed Material has terminated under
-     Section 6(a), it reinstates:
-
-       1. automatically as of the date the violation is cured, provided
-          it is cured within 30 days of Your discovery of the
-          violation; or
-
-       2. upon express reinstatement by the Licensor.
-
-     For the avoidance of doubt, this Section 6(b) does not affect any
-     right the Licensor may have to seek remedies for Your violations
-     of this Public License.
-
-  c. For the avoidance of doubt, the Licensor may also offer the
-     Licensed Material under separate terms or conditions or stop
-     distributing the Licensed Material at any time; however, doing so
-     will not terminate this Public License.
-
-  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
-     License.
-
-
-Section 7 -- Other Terms and Conditions.
-
-  a. The Licensor shall not be bound by any additional or different
-     terms or conditions communicated by You unless expressly agreed.
-
-  b. Any arrangements, understandings, or agreements regarding the
-     Licensed Material not stated herein are separate from and
-     independent of the terms and conditions of this Public License.
-
-
-Section 8 -- Interpretation.
-
-  a. For the avoidance of doubt, this Public License does not, and
-     shall not be interpreted to, reduce, limit, restrict, or impose
-     conditions on any use of the Licensed Material that could lawfully
-     be made without permission under this Public License.
-
-  b. To the extent possible, if any provision of this Public License is
-     deemed unenforceable, it shall be automatically reformed to the
-     minimum extent necessary to make it enforceable. If the provision
-     cannot be reformed, it shall be severed from this Public License
-     without affecting the enforceability of the remaining terms and
-     conditions.
-
-  c. No term or condition of this Public License will be waived and no
-     failure to comply consented to unless expressly agreed to by the
-     Licensor.
-
-  d. Nothing in this Public License constitutes or may be interpreted
-     as a limitation upon, or waiver of, any privileges and immunities
-     that apply to the Licensor or You, including from the legal
-     processes of any jurisdiction or authority.
-
-
-=======================================================================
-
-Creative Commons is not a party to its public
-licenses. Notwithstanding, Creative Commons may elect to apply one of
-its public licenses to material it publishes and in those instances
-will be considered the “Licensor.” The text of the Creative Commons
-public licenses is dedicated to the public domain under the CC0 Public
-Domain Dedication. Except for the limited purpose of indicating that
-material is shared under a Creative Commons public license or as
-otherwise permitted by the Creative Commons policies published at
-creativecommons.org/policies, Creative Commons does not authorize the
-use of the trademark "Creative Commons" or any other trademark or logo
-of Creative Commons without its prior written consent including,
-without limitation, in connection with any unauthorized modifications
-to any of its public licenses or any other arrangements,
-understandings, or agreements concerning use of licensed material. For
-the avoidance of doubt, this paragraph does not form part of the
-public licenses.
-
-Creative Commons may be contacted at creativecommons.org.
diff --git a/docs/devtoolkit/docs/Makefile b/docs/devtoolkit/docs/Makefile
deleted file mode 100644
index 1eff8952707bdfa503c8d60c1e9a903053170ba2..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/Makefile
+++ /dev/null
@@ -1,20 +0,0 @@
-# Minimal makefile for Sphinx documentation
-#
-
-# You can set these variables from the command line, and also
-# from the environment for the first two.
-SPHINXOPTS    ?=
-SPHINXBUILD   ?= sphinx-build
-SOURCEDIR     = source_zh_cn
-BUILDDIR      = build_zh_cn
-
-# Put it first so that "make" without argument is like "make help".
-help:
-	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-
-.PHONY: help Makefile
-
-# Catch-all target: route all unknown targets to Sphinx using the new
-# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
-%: Makefile
-	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/devtoolkit/docs/requirements.txt b/docs/devtoolkit/docs/requirements.txt
deleted file mode 100644
index a1b6a69f6dbd9c6f78710f56889e14f0e85b27f4..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/requirements.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-sphinx == 4.4.0
-docutils == 0.17.1
-myst-parser == 0.18.1
-sphinx_rtd_theme == 1.0.0
-numpy
-IPython
-jieba
diff --git a/docs/devtoolkit/docs/source_en/PyCharm_change_version.md b/docs/devtoolkit/docs/source_en/PyCharm_change_version.md
deleted file mode 100644
index 0098ecec9f755781bbef0a1500bf5e11409d9efc..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_en/PyCharm_change_version.md
+++ /dev/null
@@ -1,38 +0,0 @@
-# API Mapping - API Version Switching
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_en/PyCharm_change_version.md)
-
-## Overview
-
-API mapping refers to the mapping relationship between PyTorch APIs and MindSpore APIs.
-In MindSpore Dev Toolkit, it provides two functions, API mapping search and API mapping scan, and users can freely switch the version of the API mapping data.
-
-## API Mapping Data Version Switching
-
-1. When the plug-in starts, the API mapping data version defaults to the plug-in's own version, as shown in the lower right corner. This version number affects only the API mapping features described in this section; it does not change the version of MindSpore in the environment.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image137.jpg)
-
-2. Click the API mapping data version to bring up the selection list. You can switch to another version by clicking one of the preset versions, or choose "other version" and enter a version number of your own.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image138.jpg)
-
-3. Click any version number to start switching versions. An animation below indicates the switching status.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image139.jpg)
-
-4. To customize the version number, select "other version" in the selection list, enter the version number in the pop-up box, and click OK to start switching. Note: enter the version number in the 2.1 or 2.1.0 format; otherwise, nothing happens when you click OK.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image140.jpg)
-
-5. If the switch succeeds, the status bar in the lower right displays the API mapping data version after the switch.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image141.jpg)
-
-6. If the switch fails, the status bar in the lower right keeps showing the API mapping data version from before the switch. If the failure is caused by a non-existent version number or a network error, check and try again. To see the latest documentation, switch to the master version.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image142.jpg)
-
-7. When a customized version number is switched to successfully, it is added to the version list.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image143.jpg)
\ No newline at end of file
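Editor's note: the format rule in step 4 of the deleted document above amounts to a strict version-string check. As a rough illustration only (this is not the plug-in's actual code, and `is_valid_mapping_version` is a hypothetical helper name), such a check can be a single regular expression:

```python
import re

def is_valid_mapping_version(version: str) -> bool:
    """Accept version numbers written like '2.1' or '2.1.0'; reject the rest."""
    return re.fullmatch(r"\d+\.\d+(\.\d+)?", version) is not None

assert is_valid_mapping_version("2.1")
assert is_valid_mapping_version("2.1.0")
assert not is_valid_mapping_version("v2.1")   # prefixes are not accepted
assert not is_valid_mapping_version("2.1.")   # trailing dot is not accepted
```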
diff --git a/docs/devtoolkit/docs/source_en/PyCharm_plugin_install.md b/docs/devtoolkit/docs/source_en/PyCharm_plugin_install.md
deleted file mode 100644
index 2850c6f4bd53ee005b6be78f15ccb3d7719a0dc3..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_en/PyCharm_plugin_install.md
+++ /dev/null
@@ -1,13 +0,0 @@
-# PyCharm Plug-in Installation
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_en/PyCharm_plugin_install.md)
-
-## Installation Steps
-
-1. Obtain the [plug-in ZIP package](https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.1.0/IdePlugin/any/MindSpore_Dev_ToolKit-2.1.0.zip).
-2. Start PyCharm, open the upper left menu bar, and choose File->Settings->Plugins->Install Plugin from Disk,
-   as shown in the figure:
-
-    ![image-20211223175637989](./images/clip_image050.jpg)
-
-3. Select the plug-in ZIP package.
\ No newline at end of file
- * "Probably the result of torch.Tensor API" means APIs with the same name as torch.Tensor, which may be torch.Tensor APIs and can be converted to MindSpore APIs. - * "PyTorch API that does not provide a direct mapping relationship at this time" means APIs that are PyTorch APIs or possibly torch.Tensor APIs, but don't directly correspond to MindSpore APIs. - - ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image117.jpg) - -## Project-level API Mapping Scanning - -1. Click the MindSpore API Mapping Scan icon on the left sidebar of Visual Studio Code. - - ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image118.jpg) - -2. Generate a project tree view of the current IDE project containing only Python files on the left sidebar. - - ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image119.jpg) - -3. If you select a single Python file in the view, you can get a list of operator scan results for that file. - - ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image120.jpg) - -4. If you select a file directory in the view, you can get a list of operator scan results for all Python files in that directory. - - ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image121.jpg) - -5. The blue font parts are all clickable and will automatically open the page in the user-default browser. - - ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image122.jpg) - - ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image123.jpg) diff --git a/docs/devtoolkit/docs/source_en/VSCode_api_search.md b/docs/devtoolkit/docs/source_en/VSCode_api_search.md deleted file mode 100644 index 83b4fb4bcc26ce8d07de01074930a7efca4a4b38..0000000000000000000000000000000000000000 --- a/docs/devtoolkit/docs/source_en/VSCode_api_search.md +++ /dev/null @@ -1,29 +0,0 @@ -# API Mapping - API Search - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_en/VSCode_api_search.md) - -## Function Introduction - -* Quickly search for MindSpore APIs and display API details directly in the sidebar. -* For the convenience of users of other machine learning frameworks, the association matches the corresponding MindSpore API by searching for other mainstream framework APIs. -* The data version of API mapping supports switching. Please refer to the section [API Mapping - Version Switching](https://www.mindspore.cn/devtoolkit/docs/en/master/VSCode_change_version.html) for details. - -## Usage Steps - -1. Click the MindSpore API Mapping Search icon on the left sidebar of Visual Studio Code. - - ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image124.jpg) - -2. An input box is generated in the left sidebar. 
diff --git a/docs/devtoolkit/docs/source_en/VSCode_api_search.md b/docs/devtoolkit/docs/source_en/VSCode_api_search.md
deleted file mode 100644
index 83b4fb4bcc26ce8d07de01074930a7efca4a4b38..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_en/VSCode_api_search.md
+++ /dev/null
@@ -1,29 +0,0 @@
-# API Mapping - API Search
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_en/VSCode_api_search.md)
-
-## Function Introduction
-
-* Quickly search for MindSpore APIs and display API details directly in the sidebar.
-* For the convenience of users of other machine learning frameworks, you can search for APIs of other mainstream frameworks and be matched by association to the corresponding MindSpore APIs.
-* The API mapping data version can be switched. Please refer to the section [API Mapping - Version Switching](https://www.mindspore.cn/devtoolkit/docs/en/master/VSCode_change_version.html) for details.
-
-## Usage Steps
-
-1. Click the MindSpore API Mapping Search icon on the left sidebar of Visual Studio Code.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image124.jpg)
-
-2. An input box appears in the left sidebar.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image125.jpg)
-
-3. Enter any word in the input box. Search results for the current keyword are displayed below and updated in real time as you type.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image126.jpg)
-
-4. Click any search result to open its page in the user's default browser.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image127.jpg)
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image128.jpg)
\ No newline at end of file
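Editor's note: the real-time matching in step 3 above behaves like a substring filter over the API mapping table, re-run on each keystroke. A toy sketch of that behavior with an invented three-row table follows (the plug-in's actual data source and ranking rules are not documented here):

```python
# Hypothetical mapping rows: (other-framework API, MindSpore API).
MAPPING_ROWS = [
    ("torch.nn.Linear", "mindspore.nn.Dense"),
    ("torch.abs", "mindspore.ops.abs"),
    ("tf.math.add", "mindspore.ops.add"),
]

def search(keyword: str) -> list:
    """Case-insensitive substring match on both columns, like the live search box."""
    kw = keyword.lower()
    return [row for row in MAPPING_ROWS
            if kw in row[0].lower() or kw in row[1].lower()]

print(search("linear"))  # [('torch.nn.Linear', 'mindspore.nn.Dense')]
print(search("add"))     # [('tf.math.add', 'mindspore.ops.add')]
```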
diff --git a/docs/devtoolkit/docs/source_en/VSCode_change_version.md b/docs/devtoolkit/docs/source_en/VSCode_change_version.md
deleted file mode 100644
index 6469fa1f78a21af96e7f0fff45f7412af03fca45..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_en/VSCode_change_version.md
+++ /dev/null
@@ -1,39 +0,0 @@
-# API Mapping - Version Switching
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_en/VSCode_change_version.md)
-
-## Overview
-
-API mapping refers to the mapping relationship between PyTorch APIs and MindSpore APIs. In MindSpore Dev Toolkit, it provides two functions, API mapping search and API mapping scan, and users can freely switch the version of the API mapping data.
-
-## API Mapping Data Version Switching
-
-1. Different versions of the API mapping data produce different API mapping scan and search results, but do not affect the version of MindSpore in the environment. The default version is the same as the plug-in version, and the version information is displayed in the bottom left status bar.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image129.jpg)
-
-2. Clicking this status bar brings up a drop-down box at the top of the page with the preset version numbers that can be switched to. Click any version number to switch to it, or click the "Customize Input" option and enter another version number in the pop-up input box.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image130.jpg)
-
-3. Click any version number to start switching versions; the status bar in the lower left indicates the switching status.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image131.jpg)
-
-4. To customize the version number, click the "Customize Input" option in the drop-down box. The drop-down box changes into an input box; enter the version number in the 2.1 or 2.1.0 format and press Enter to start switching. The status bar in the lower-left corner indicates the switching status.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image132.jpg)
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image133.jpg)
-
-5. If the switch succeeds, a message in the lower right says so, and the status bar in the lower left displays the API mapping data version after the switch.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image134.jpg)
-
-6. If the switch fails, a message in the lower right says so, and the status bar in the lower left keeps showing the API mapping data version from before the switch. If the failure is caused by a non-existent version number or a network error, check and try again. To see the latest documentation, switch to the master version.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image135.jpg)
-
-7. When a customized version number is switched to successfully, it is added to the drop-down box.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image136.jpg)
\ No newline at end of file
diff --git a/docs/devtoolkit/docs/source_en/VSCode_plugin_install.md b/docs/devtoolkit/docs/source_en/VSCode_plugin_install.md
deleted file mode 100644
index 7f192ae88ac085657900b582550961f6592344eb..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_en/VSCode_plugin_install.md
+++ /dev/null
@@ -1,18 +0,0 @@
-# Visual Studio Code Plug-in Installation
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_en/VSCode_plugin_install.md)
-
-## Installation Steps
-
-1. Obtain the [plug-in VSIX package](https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.1.0/IdePlugin/any/mindspore-dev-toolkit-2.1.0.vsix).
-2. Click the fifth button on the left, "Extensions". Click the three dots in the upper right corner, and then click "Install from VSIX...".
-
-    ![image-20211223175637989](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image112.jpg)
-
-3. Select the downloaded VSIX file, and the plug-in is installed automatically. When "Completed installing MindSpore Dev Toolkit extension from VSIX" appears in the bottom right corner, the plug-in has been installed successfully.
-
-    ![image-20211223175637989](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image113.jpg)
-
-4. Click the refresh button in the left column; the "MindSpore Dev Toolkit" plug-in now appears under "INSTALLED", confirming that the installation succeeded.
-
-    ![image-20211223175637989](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image114.jpg)
diff --git a/docs/devtoolkit/docs/source_en/VSCode_smart_completion.md b/docs/devtoolkit/docs/source_en/VSCode_smart_completion.md
deleted file mode 100644
index ad4f7ec61e109e178c86d6a25a3d74509054369d..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_en/VSCode_smart_completion.md
+++ /dev/null
@@ -1,22 +0,0 @@
-# Code Completion
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_en/VSCode_smart_completion.md)
-
-## Functions Description
-
-* Provides AI code completion based on the MindSpore project.
-* Lets you develop MindSpore code easily, without installing the MindSpore environment.
-
-## Usage Steps
-
-1. When you install or use the plug-in for the first time, the model is downloaded automatically: "Start Downloading Model" appears in the lower right corner, and "Download Model Successfully" indicates that the model has been downloaded and started. If the network is slow, the download can take more than ten minutes, and the success message appears only after it completes. On later uses, these messages do not appear.
-
-    ![image-20211223175637989](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image115.jpg)
-
-2. Open a Python file and start writing code.
-
-    ![image-20211223175637989](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image097.jpg)
-
-3. Completion takes effect automatically as you code. Suggestions tagged with the MindSpore Dev Toolkit suffix are provided by the plug-in's smart completion.
-
-    ![image-20211223175637989](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image111.jpg)
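Editor's note: the first-use download in step 1 above follows a common cache-on-first-use pattern: fetch the model once, then reuse the local copy. A rough sketch of the pattern only; the path and URL below are placeholders, not the plug-in's real ones:

```python
import urllib.request
from pathlib import Path

# Placeholder cache location and URL; the plug-in's real values are not documented here.
CACHE = Path.home() / ".mindspore_dev_toolkit" / "completion_model.bin"
MODEL_URL = "https://example.com/completion_model.bin"

def ensure_model() -> Path:
    """Download the model on first use only; later calls reuse the cached copy."""
    if not CACHE.exists():
        CACHE.parent.mkdir(parents=True, exist_ok=True)
        print("Start Downloading Model")       # first-use message
        urllib.request.urlretrieve(MODEL_URL, CACHE)
        print("Download Model Successfully")   # shown once the file is complete
    return CACHE
```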
diff --git a/docs/devtoolkit/docs/source_en/api_scanning.md b/docs/devtoolkit/docs/source_en/api_scanning.md
deleted file mode 100644
index 687c6f09bc1a5ac8a3f734bc2391e5a7e2ee46a6..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_en/api_scanning.md
+++ /dev/null
@@ -1,62 +0,0 @@
-# API Mapping - API Scanning
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_en/api_scanning.md)
-
-## Functions Introduction
-
-* Quickly scan the APIs in the code and display the API details directly in the sidebar.
-* For the convenience of users of other machine learning frameworks, the mainstream framework APIs that appear in the code are scanned and matched by association to the corresponding MindSpore APIs.
-* The API mapping data version can be switched. Please refer to the section [API Mapping - Version Switching](https://www.mindspore.cn/devtoolkit/docs/en/master/PyCharm_change_version.html) for details.
-
-## Usage Steps
-
-### Document-level API Scanning
-
-1. Right-click anywhere in the current file to open the menu, and click "API scan" at the top of the menu.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image100.jpg)
-
-2. The right sidebar pops up automatically to show the scanned operators in a detailed list containing names, URLs, and other information. If no operator is found in the file, no pop-up window appears.
-
-    where:
-
-    * "PyTorch/TensorFlow APIs that can be converted to MindSpore APIs" means PyTorch or TensorFlow APIs used in the file that can be converted to MindSpore APIs.
-    * "APIs that cannot be converted at this time" means APIs that are PyTorch or TensorFlow APIs but do not have a direct MindSpore equivalent.
-    * "Possible PyTorch/TensorFlow API" refers to chained calls that may resolve to a PyTorch or TensorFlow API and could therefore be convertible.
-    * TensorFlow API scanning is an experimental feature.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image101.jpg)
-
-3. Click the blue text; a pane opens automatically at the top to show the corresponding page.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image102.jpg)
-
-4. Click the "export" button in the upper right corner to export the content to a CSV table.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image103.jpg)
-
-### Project-level API Scanning
-
-1. Right-click anywhere in the current file to open the menu and click the second option at the top, "API scan project-level", or choose "Tools" in the toolbar above and then select "API scan project-level".
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image104.jpg)
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image105.jpg)
-
-2. The right sidebar pops up with the operators scanned from the entire project, displayed in a detailed list containing names, URLs, and other information.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image106.jpg)
-
-3. In the upper box you can select a single file; the lower box then shows the operators in that file separately. The file selection can be switched at will.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image107.jpg)
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image108.jpg)
-
-4. Click the blue text; a pane opens automatically at the top to show the corresponding page.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image109.jpg)
-
-5. Click the "export" button in the upper right corner to export the content to a CSV table.
-
-    ![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/devtoolkit/docs/source_zh_cn/images/clip_image110.jpg)
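Editor's note: the "export" step above writes the scan results shown in the sidebar to a CSV table. The plug-in's exact column layout is not specified in the document, so the sketch below assumes three illustrative columns (original API, MindSpore API, documentation URL) and uses only the Python standard library:

```python
import csv

# Assumed result rows; the column layout is illustrative, not the plug-in's exact format.
results = [
    ("torch.nn.Linear", "mindspore.nn.Dense",
     "https://www.mindspore.cn/docs/en/master/api_python/nn/mindspore.nn.Dense.html"),
    ("torch.rand", "", ""),  # no direct mapping yet
]

with open("api_scan_results.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["Original API", "MindSpore API", "Docs URL"])
    writer.writerows(results)
```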
diff --git a/docs/devtoolkit/docs/source_en/api_search.md b/docs/devtoolkit/docs/source_en/api_search.md
deleted file mode 100644
index fdc498109f16503d20ee17297fe71cb3491de439..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_en/api_search.md
+++ /dev/null
@@ -1,29 +0,0 @@
-# API Mapping - API Search
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_en/api_search.md)
-
-## Functions
-
-* You can quickly search for MindSpore APIs and view API details in the sidebar.
-* If you use other machine learning frameworks, you can search for APIs of other mainstream frameworks to match MindSpore APIs.
-* The API mapping data version can be switched. Please refer to the section [API Mapping - Version Switching](https://www.mindspore.cn/devtoolkit/docs/en/master/PyCharm_change_version.html) for details.
-
-## Procedure
-
-1. Press **Shift** twice. The global search page is displayed.
-
-    ![img](images/clip_image060.jpg)
-
-2. Click **MindSpore**.
-
-    ![img](images/clip_image062.jpg)
-
-3. Search for a PyTorch or TensorFlow API to obtain its mapping to the MindSpore API.
-
-    ![img](images/clip_image064.jpg)
-
-    ![img](images/clip_image066.jpg)
-
-4. Click an item in the list to view its official documentation in the sidebar.
-
-    ![img](images/clip_image068.jpg)
diff --git a/docs/devtoolkit/docs/source_en/compiling.md b/docs/devtoolkit/docs/source_en/compiling.md
deleted file mode 100644
index b7b531c1ea2459f340cc18c62317014f6cf2faa5..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_en/compiling.md
+++ /dev/null
@@ -1,87 +0,0 @@
-# Source Code Compilation Guide
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_en/compiling.md)
-
-The following describes how to compile the MindSpore Dev ToolKit project from source using IntelliJ IDEA.
-
-## Background
-
-* MindSpore Dev ToolKit is a PyCharm plug-in developed using IntelliJ IDEA. [IntelliJ IDEA](https://www.jetbrains.com/idea/download) and PyCharm are IDEs developed by JetBrains.
-* MindSpore Dev ToolKit is developed based on JDK 11. To learn JDK- and Java-related knowledge, visit [https://jdk.java.net/](https://jdk.java.net/).
-* [Gradle 6.6.1](https://gradle.org) is used to build MindSpore Dev Toolkit; it does not need to be installed in advance, because IntelliJ IDEA automatically uses the "gradle wrapper" mechanism to configure the required Gradle based on the code.
-
-## Required Software
-
-* [IntelliJ IDEA](https://www.jetbrains.com/idea/download)
-
-* JDK 11
-
-  Note: IntelliJ IDEA 2021.3 contains a built-in JDK named **jbr-11 JetBrains Runtime version 11.0.10**, which can be used directly.
-
-  ![img](images/clip_image031.jpg)
-
-## Compilation
-
-1. Verify that the required software has been successfully configured.
-
-2. Download the [project](https://gitee.com/mindspore/ide-plugin) source code from the code repository.
-
-    * Download the ZIP package.
-
-      ![img](images/clip_image032.jpg)
-
-    * Run the git command to download the package.
-
-      ```
-      git clone https://gitee.com/mindspore/ide-plugin.git
-      ```
-
-3. Use IntelliJ IDEA to open the project.
-
-    3.1 Choose **File** > **Open**.
-
-    ![img](images/clip_image033.jpg)
-
-    3.2 Go to the directory for storing the project.
-
-    ![img](images/clip_image034.jpg)
-
-    3.3 Click **Load** in the dialog box that is displayed in the lower right corner. Alternatively, click **pycharm**, right-click the **settings.gradle** file, and choose **Link Gradle Project** from the shortcut menu.
-
-    ![img](images/clip_image035.jpg)
-
-    ![img](images/clip_image036.jpg)
-
-4. If the system displays a message indicating that no JDK is available, select a JDK. ***Skip this step if the JDK is available.***
-
-    4.1 The following figure shows the message displayed when no JDK is available.
-
-    ![img](images/clip_image037.jpg)
-
-    4.2 Choose **File** > **Project Structure**.
-
-    ![img](images/clip_image038.jpg)
-
-    4.3 Select JDK 11.
-
-    ![img](images/clip_image039.jpg)
-
-5. Wait until the synchronization is complete.
-
-    ![img](images/clip_image040.jpg)
-
-6. Build a project.
-
-    ![img](images/clip_image042.jpg)
-
-7. Wait until the build is complete.
-
-    ![img](images/clip_image044.jpg)
-
-8. Obtain the plug-in installation package from the **/pycharm/build/distributions** directory in the project directory.
-
-    ![img](images/clip_image046.jpg)
-
-## References
-
-* This project is built based on section [Building Plugins with Gradle](https://plugins.jetbrains.com/docs/intellij/gradle-build-system.html) in *IntelliJ Platform Plugin SDK*. For details about advanced functions such as debugging, see the official documentation.
diff --git a/docs/devtoolkit/docs/source_en/conf.py b/docs/devtoolkit/docs/source_en/conf.py
deleted file mode 100644
index 06e4e63c9fe2caa7ae185c9784c961b9dff9a4ac..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_en/conf.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Configuration file for the Sphinx documentation builder.
-#
-# This file only contains a selection of the most common options. For a full
-# list see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
-
-# -- Path setup --------------------------------------------------------------
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-#
-import os
-import re
-
-# -- Project information -----------------------------------------------------
-
-project = 'MindSpore'
-copyright = 'MindSpore'
-author = 'MindSpore'
-
-# The full version, including alpha/beta/rc tags
-release = 'master'
-
-
-# -- General configuration ---------------------------------------------------
-
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
-# ones.
-myst_enable_extensions = ["dollarmath", "amsmath"]
-
-
-myst_heading_anchors = 5
-extensions = [
-    'myst_parser',
-    'sphinx.ext.autodoc'
-]
-
-source_suffix = {
-    '.rst': 'restructuredtext',
-    '.md': 'markdown',
-}
-
-# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-# This pattern also affects html_static_path and html_extra_path.
-mathjax_path = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/mathjax/MathJax-3.2.2/es5/tex-mml-chtml.js'
-
-mathjax_options = {
-    'async':'async'
-}
-
-smartquotes_action = 'De'
-
-exclude_patterns = []
-
-pygments_style = 'sphinx'
-
-# -- Options for HTML output -------------------------------------------------
-
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-#
-html_theme = 'sphinx_rtd_theme'
-
-html_search_options = {'dict': '../../../resource/jieba.txt'}
-
-html_static_path = ['_static']
-
-src_release = os.path.join(os.getenv("DT_PATH"), 'RELEASE.md')
-des_release = "./RELEASE.md"
-with open(src_release, "r", encoding="utf-8") as f:
-    data = f.read()
-if len(re.findall("\n## (.*?)\n",data)) > 1:
-    content = re.findall("(## [\s\S\n]*?)\n## ", data)
-else:
-    content = re.findall("(## [\s\S\n]*)", data)
-#result = content[0].replace('# MindSpore', '#', 1)
-with open(des_release, "w", encoding="utf-8") as p:
-    p.write("# Release Notes"+"\n\n")
-    p.write(content[0])
\ No newline at end of file
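Editor's note: the RELEASE.md post-processing at the end of the deleted conf.py is easy to miss. The standalone sketch below replays the same two regular expressions on an invented release string to show what they extract: the newest "## " section when several exist, otherwise everything from the first "## " onwards.

```python
import re

# Invented RELEASE.md content with two release sections.
data = "# MindSpore\n\n## 2.1.0\n\n- feature A\n\n## 2.0.0\n\n- feature B\n"

# Same branching as the deleted conf.py: if more than one "## " heading is
# present, capture only the first section (non-greedy, up to the next heading);
# otherwise capture everything from the first "## " onwards.
if len(re.findall(r"\n## (.*?)\n", data)) > 1:
    content = re.findall(r"(## [\s\S\n]*?)\n## ", data)
else:
    content = re.findall(r"(## [\s\S\n]*)", data)

print("# Release Notes\n\n" + content[0])  # prints only the "## 2.1.0" section
```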
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image002.jpg b/docs/devtoolkit/docs/source_en/images/clip_image002.jpg
deleted file mode 100644
index 24132302f1552bed6be56b7dd660625448774680..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image002.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image004.jpg b/docs/devtoolkit/docs/source_en/images/clip_image004.jpg
deleted file mode 100644
index 7ed0e3729940f514a7bfd61c1a7be22166c0bb02..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image004.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image006.jpg b/docs/devtoolkit/docs/source_en/images/clip_image006.jpg
deleted file mode 100644
index e0c323eec249024fe19126ce4c931133564cf7b7..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image006.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image008.jpg b/docs/devtoolkit/docs/source_en/images/clip_image008.jpg
deleted file mode 100644
index a071ad67222931372d4b62f7b0cf334a4015e70d..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image008.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image010.jpg b/docs/devtoolkit/docs/source_en/images/clip_image010.jpg
deleted file mode 100644
index 43ca88d40bc56d5a5113bc29b97a7b559ef659af..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image010.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image012.jpg b/docs/devtoolkit/docs/source_en/images/clip_image012.jpg
deleted file mode 100644
index 0e35c9f219292913a51f1f0d5b7a5e154008620f..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image012.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image014.jpg b/docs/devtoolkit/docs/source_en/images/clip_image014.jpg
deleted file mode 100644
index 794de60c7d7e76a58e8d7212e449a1bd8e194b21..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image014.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image015.jpg b/docs/devtoolkit/docs/source_en/images/clip_image015.jpg
deleted file mode 100644
index 8172e21f871bed6866b3a91b83252838228d257a..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image015.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image016.jpg b/docs/devtoolkit/docs/source_en/images/clip_image016.jpg
deleted file mode 100644
index c836c0cf7898e4757ddb3410dde18e754894dd25..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image016.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image018.jpg b/docs/devtoolkit/docs/source_en/images/clip_image018.jpg
deleted file mode 100644
index 777738f7cf60454b7b5f26e6c5a29ba5d55750ff..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image018.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image019.jpg b/docs/devtoolkit/docs/source_en/images/clip_image019.jpg
deleted file mode 100644
index ab02b702bfd1c0986adb4a15c1f455b56df0a4a1..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image019.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image020.jpg b/docs/devtoolkit/docs/source_en/images/clip_image020.jpg
deleted file mode 100644
index d946c3cb3a851f690a5643b5afe119597aed5b22..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image020.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image021.jpg b/docs/devtoolkit/docs/source_en/images/clip_image021.jpg
deleted file mode 100644
index 74672d9513f4a60f77450ae5516cee2060215241..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image021.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image022.jpg b/docs/devtoolkit/docs/source_en/images/clip_image022.jpg
deleted file mode 100644
index 6b26f18c7d8bb43db0beb8a8d2bd386489192922..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image022.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image023.jpg b/docs/devtoolkit/docs/source_en/images/clip_image023.jpg
deleted file mode 100644
index 5981a0fb25c681417a0bdf24d2392411b41f0faf..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image023.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image024.jpg b/docs/devtoolkit/docs/source_en/images/clip_image024.jpg
deleted file mode 100644
index 505e8b6c5c4d81dfd67d91b1dad08f938444368a..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image024.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image025.jpg b/docs/devtoolkit/docs/source_en/images/clip_image025.jpg
deleted file mode 100644
index 946e276b6a982303b6b146266037a739e6c2639a..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image025.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image026.jpg b/docs/devtoolkit/docs/source_en/images/clip_image026.jpg
deleted file mode 100644
index 8b787215af5cf9ae223c2b33121c3171923d2de2..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image026.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image027.jpg b/docs/devtoolkit/docs/source_en/images/clip_image027.jpg
deleted file mode 100644
index aa4d7d4a8a6b503fe29885368547daa535e34796..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image027.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image028.jpg b/docs/devtoolkit/docs/source_en/images/clip_image028.jpg
deleted file mode 100644
index 3126f80aecac28e8beaa54dc393122c60dbe1357..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image028.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image029.jpg b/docs/devtoolkit/docs/source_en/images/clip_image029.jpg
deleted file mode 100644
index 6587240e4a456f3792fece52bcdcbbed077ca67b..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image029.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image031.jpg b/docs/devtoolkit/docs/source_en/images/clip_image031.jpg
deleted file mode 100644
index 2f829b48e72e62525860cfe599e0a4ada82010ca..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image031.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image032.jpg b/docs/devtoolkit/docs/source_en/images/clip_image032.jpg
deleted file mode 100644
index 37589efbe0f57c442f665824831a2685d81c8713..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image032.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image033.jpg b/docs/devtoolkit/docs/source_en/images/clip_image033.jpg
deleted file mode 100644
index bdca68324cf7ee8f4e9bd18817a82954910e52c9..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image033.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image034.jpg b/docs/devtoolkit/docs/source_en/images/clip_image034.jpg
deleted file mode 100644
index 874b10d4b2ca476da16a4d1e749cdb6b31ecb59e..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image034.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image035.jpg b/docs/devtoolkit/docs/source_en/images/clip_image035.jpg
deleted file mode 100644
index 0b0465169553e57795320255295b8fa789950522..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image035.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image036.jpg b/docs/devtoolkit/docs/source_en/images/clip_image036.jpg
deleted file mode 100644
index c7c6c72819b655884d97637b696d1814e5a7fdbf..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image036.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image037.jpg b/docs/devtoolkit/docs/source_en/images/clip_image037.jpg
deleted file mode 100644
index 531e8184e02c43aa177a51c3cc32355cc3df9d42..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image037.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image038.jpg b/docs/devtoolkit/docs/source_en/images/clip_image038.jpg
deleted file mode 100644
index a8b4d88190c626139bad49cd42a9f7e908b4d0e4..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image038.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image039.jpg b/docs/devtoolkit/docs/source_en/images/clip_image039.jpg
deleted file mode 100644
index 2eab0ceac9c1bd5d8b6ade3d65a6a3ce8b1f8fd4..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image039.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image040.jpg b/docs/devtoolkit/docs/source_en/images/clip_image040.jpg
deleted file mode 100644
index a879fb1f12d8b6c4bd02332abf9b3bd734207763..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image040.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image042.jpg b/docs/devtoolkit/docs/source_en/images/clip_image042.jpg
deleted file mode 100644
index 2454ade258da6d428c9e23ece2adf7f0291d1a12..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image042.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image044.jpg b/docs/devtoolkit/docs/source_en/images/clip_image044.jpg
deleted file mode 100644
index cbff652015c36a5856afc909518f3c0fd22f23ff..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image044.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image046.jpg b/docs/devtoolkit/docs/source_en/images/clip_image046.jpg
deleted file mode 100644
index 58a493ea4f69b264fc69cfd3e34f32d5a171c303..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image046.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image050.jpg b/docs/devtoolkit/docs/source_en/images/clip_image050.jpg
deleted file mode 100644
index 35cc26d483358550c9a53ce855c2ae483eddb7e1..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image050.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image060.jpg b/docs/devtoolkit/docs/source_en/images/clip_image060.jpg
deleted file mode 100644
index 7723975694f7f56d88187a69626343af11efbd23..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image060.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image062.jpg b/docs/devtoolkit/docs/source_en/images/clip_image062.jpg
deleted file mode 100644
index 838bc48ab8d77f7dbba9ca02925838a49b19ce53..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image062.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image064.jpg b/docs/devtoolkit/docs/source_en/images/clip_image064.jpg
deleted file mode 100644
index fb39e70b78b45af301973ea802a219c482a21590..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image064.jpg and /dev/null differ
diff --git a/docs/devtoolkit/docs/source_en/images/clip_image066.jpg b/docs/devtoolkit/docs/source_en/images/clip_image066.jpg
deleted file mode 100644
index 0a596cfb3ef7a79674ff33a7be5c97859cc2b9c4..0000000000000000000000000000000000000000
Binary files a/docs/devtoolkit/docs/source_en/images/clip_image066.jpg and /dev/null differ
diff --git
a/docs/devtoolkit/docs/source_en/images/clip_image068.jpg b/docs/devtoolkit/docs/source_en/images/clip_image068.jpg deleted file mode 100644 index 0023ba9236a768001e462d6a10434719f3f733fd..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_en/images/clip_image068.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_en/images/clip_image072.jpg b/docs/devtoolkit/docs/source_en/images/clip_image072.jpg deleted file mode 100644 index d1e5fad4192d4cb5821cafe4031dbc4ff599eccd..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_en/images/clip_image072.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_en/images/clip_image074.jpg b/docs/devtoolkit/docs/source_en/images/clip_image074.jpg deleted file mode 100644 index 97fa2b21b4029ff75156893f8abdff2e77aa38bd..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_en/images/clip_image074.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_en/images/clip_image076.jpg b/docs/devtoolkit/docs/source_en/images/clip_image076.jpg deleted file mode 100644 index e754c7dcd30ee6fa82ab20bde4d66f69aabe2fa7..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_en/images/clip_image076.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_en/images/clip_image088.jpg b/docs/devtoolkit/docs/source_en/images/clip_image088.jpg deleted file mode 100644 index 8b85f0727893c4cf6cd258550466ca4f4a340e6e..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_en/images/clip_image088.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_en/images/clip_image090.jpg b/docs/devtoolkit/docs/source_en/images/clip_image090.jpg deleted file mode 100644 index a3f405388fd75b23b652bc86475be5fd5e1f48ac..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_en/images/clip_image090.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_en/images/clip_image092.jpg b/docs/devtoolkit/docs/source_en/images/clip_image092.jpg deleted file mode 100644 index 68ca9c66fc3f03760873075af20c8a9e28aaab48..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_en/images/clip_image092.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_en/images/clip_image093.jpg b/docs/devtoolkit/docs/source_en/images/clip_image093.jpg deleted file mode 100644 index 594b2ceadb2c27290e0339e14b298fa2feffe6a9..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_en/images/clip_image093.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_en/images/clip_image094.jpg b/docs/devtoolkit/docs/source_en/images/clip_image094.jpg deleted file mode 100644 index e931a95180d27d55590e73948ebe80a1f81bede1..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_en/images/clip_image094.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_en/images/clip_image096.jpg b/docs/devtoolkit/docs/source_en/images/clip_image096.jpg deleted file mode 100644 index 3ed0c88500bb4caffccea4d08aaa3a6310e177bd..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_en/images/clip_image096.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_en/index.rst b/docs/devtoolkit/docs/source_en/index.rst deleted file mode 100644 index a5723cfae54780d9911757eadbe14d0991952572..0000000000000000000000000000000000000000 --- 
a/docs/devtoolkit/docs/source_en/index.rst +++ /dev/null
@@ -1,61 +0,0 @@
-MindSpore Dev Toolkit
-============================
-
-MindSpore Dev Toolkit is a `PyCharm `_ (cross-platform Python IDE) plug-in developed by MindSpore, providing functions such as `project creation `_ , `smart completion `_ , `API search `_ , and `document search `_ .
-
-MindSpore Dev Toolkit aims to create the best intelligent computing experience, improving the usability of the MindSpore framework and facilitating the growth of the MindSpore ecosystem through deep learning, intelligent search, and intelligent recommendation.
-
-Code repository address: 
-
-System Requirements
-------------------------------
-
-- Operating systems supported by the plug-in:
-
-    - Windows 10
-
-    - Linux
-
-    - macOS (only the x86 architecture is supported; the code completion function is currently unavailable)
-
-- PyCharm versions supported by the plug-in:
-
-    - 2020.3
-
-    - 2021.X
-
-    - 2022.X
-
-.. toctree::
-   :glob:
-   :maxdepth: 1
-   :caption: PyCharm Plugin Usage Guide
-   :hidden:
-
-   PyCharm_plugin_install
-   compiling
-   smart_completion
-   PyCharm_change_version
-   api_search
-   api_scanning
-   knowledge_search
-   mindspore_project_wizard
-
-.. toctree::
-   :glob:
-   :maxdepth: 1
-   :caption: VSCode Plugin Usage Guide
-   :hidden:
-
-   VSCode_plugin_install
-   VSCode_smart_completion
-   VSCode_change_version
-   VSCode_api_search
-   VSCode_api_scan
-
-.. toctree::
-   :glob:
-   :maxdepth: 1
-   :caption: RELEASE NOTES
-
-   RELEASE
diff --git a/docs/devtoolkit/docs/source_en/knowledge_search.md b/docs/devtoolkit/docs/source_en/knowledge_search.md
deleted file mode 100644
index 91a0606a577644b0a18684e55f0197b05bb72cfb..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_en/knowledge_search.md
+++ /dev/null
@@ -1,22 +0,0 @@
-# Document Search
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_en/knowledge_search.md)
-
-## Functions
-
-* Recommendation: provides accurate search results based on user habits.
-* It provides an immersive document search experience, avoiding switches between the IDE and the browser. It resides in the sidebar and adapts to the page layout.
-
-## Procedure
-
-1. Click the sidebar to display the search page.
-
-   ![img](images/clip_image072.jpg)
-
-2. Enter **API Mapping** and click the search icon to view the result.
-
-   ![img](images/clip_image074.jpg)
-
-3. Click the home icon to return to the search page.
-
-   ![img](images/clip_image076.jpg)
diff --git a/docs/devtoolkit/docs/source_en/mindspore_project_wizard.md b/docs/devtoolkit/docs/source_en/mindspore_project_wizard.md
deleted file mode 100644
index b26c4cc49d24905c229a2a1c32d71f05a7f8fd30..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_en/mindspore_project_wizard.md
+++ /dev/null
@@ -1,103 +0,0 @@
-# Creating a Project
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_en/mindspore_project_wizard.md)
-
-## Background
-
-This function is implemented based on [conda](https://conda.io). Conda is a package and environment management system, and one of the recommended installation modes for MindSpore.
-
-## Functions
-
-* It creates a conda environment or selects an existing conda environment, and installs the MindSpore binary package in that environment.
-* It deploys a best practice template. In addition to testing whether the environment is installed successfully, the template also serves as a tutorial for MindSpore beginners.
-* When the network condition is good, the environment can be installed within 10 minutes so that you can experience MindSpore immediately, reducing the environment configuration time for beginners by up to 80%.
-
-## Procedure
-
-1. Choose **File** > **New Project**.
-
-   ![img](images/clip_image002.jpg)
-
-2. Select **MindSpore**.
-
-   ![img](images/clip_image004.jpg)
-
-3. Download and install Miniconda. ***If conda has been installed, skip this step.***
-
-   3.1 Click **Install Miniconda Automatically**.
-
-   ![img](images/clip_image006.jpg)
-
-   3.2 Select an installation folder. **You are advised to use the default path to install conda.**
-
-   ![img](images/clip_image008.jpg)
-
-   3.3 Click **Install**.
-
-   ![img](images/clip_image010.jpg)
-
-   ![img](images/clip_image012.jpg)
-
-   3.4 Wait for the installation to complete.
-
-   ![img](images/clip_image014.jpg)
-
-   3.5 Restart PyCharm as prompted, or restart it later. ***Note: The following steps can be performed only after PyCharm is restarted.***
-
-   ![img](images/clip_image015.jpg)
-
-4. If **Conda executable** is not automatically filled, select the path of the installed conda.
-
-   ![img](images/clip_image016.jpg)
-
-5. Create a conda environment or select an existing one.
-
-   * Create a conda environment. **You are advised to use the default path to create the conda environment. Due to PyCharm restrictions on Linux, you can only select the default directory.**
-
-     ![img](images/clip_image018.jpg)
-
-   * Select an existing conda environment in PyCharm.
-
-     ![img](images/clip_image019.jpg)
-
-6. Select a hardware environment and a MindSpore best practice template.
-
-   6.1 Select a hardware environment.
-
-   ![img](images/clip_image020.jpg)
-
-   6.2 Select a best practice template. The best practice template provides sample projects for beginners to get familiar with MindSpore, and it can be run directly.
-
-   ![img](images/clip_image021.jpg)
-
-7. Click **Create** to create a project, and wait until MindSpore is successfully downloaded and installed.
-
-   7.1 Click **Create** to create a MindSpore project.
-
-   ![img](images/clip_image022.jpg)
-
-   7.2 The conda environment is being created.
-
-   ![img](images/clip_image023.jpg)
-
-   7.3 MindSpore is being configured through conda.
-
-   ![img](images/clip_image024.jpg)
-
-8. Wait until the MindSpore project is created.
-
-   ![img](images/clip_image025.jpg)
-
-9. Check whether the MindSpore project is successfully created.
-
-   * Click **Terminal**, enter **python -c "import mindspore;mindspore.run_check()"**, and check the output. If the version number shown in the following figure is displayed, the MindSpore environment is available. (A script version of this check is sketched after this procedure.)
-
-     ![img](images/clip_image026.jpg)
-
-   * If you selected a best practice template, you can run it to test the MindSpore environment.
-
-     ![img](images/clip_image027.jpg)
-
-     ![img](images/clip_image028.jpg)
-
-     ![img](images/clip_image029.jpg)
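-
-As a minimal sketch, the same check can also be run as a short script (assuming the conda environment created above is active):
-
-```python
-# Sanity-check the installation: prints the installed MindSpore
-# version and runs a small test computation.
-import mindspore
-
-mindspore.run_check()
-```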
diff --git a/docs/devtoolkit/docs/source_en/smart_completion.md b/docs/devtoolkit/docs/source_en/smart_completion.md
deleted file mode 100644
index a3082c79785732ff85a829618bce1c5c8ca72f11..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_en/smart_completion.md
+++ /dev/null
@@ -1,36 +0,0 @@
-# Code Completion
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_en/smart_completion.md)
-
-## Functions
-
-* It provides AI-based code completion for MindSpore projects.
-* You can easily develop MindSpore code without installing the MindSpore environment.
-
-## Procedure
-
-1. Open a Python file and write code.
-
-   ![img](images/clip_image088.jpg)
-
-2. During coding, the code completion function is enabled automatically. Code lines with the "MindSpore" identifier are automatically completed by MindSpore Dev Toolkit. (A hypothetical sketch of such a completion appears at the end of this section.)
-
-   ![img](images/clip_image090.jpg)
-
-   ![img](images/clip_image092.jpg)
-
-## Description
-
-1. In PyCharm versions later than 2021, completion results are reordered based on machine learning. This may cause the plug-in's completions to be displayed with lower priority. You can disable this behavior in **Settings** and let MindSpore Dev Toolkit sort the completions.
-
-   ![img](images/clip_image093.jpg)
-
-2. Comparison before and after the function is disabled.
-
-   * Function disabled
-
-     ![img](images/clip_image094.jpg)
-
-   * Function enabled
-
-     ![img](images/clip_image096.jpg)
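-
-For illustration only (this snippet is not from the plug-in documentation), a completion might finish a partially typed MindSpore line like this:
-
-```python
-import mindspore.nn as nn
-
-# Hypothetical example: after typing "net = nn.Con", the completion
-# list may offer a full construction such as the line below.
-net = nn.Conv2d(3, 64, 3)
-```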
diff --git a/docs/devtoolkit/docs/source_zh_cn/PyCharm_change_version.md b/docs/devtoolkit/docs/source_zh_cn/PyCharm_change_version.md
deleted file mode 100644
index d0d553574bedc10d411ff45310718a05903b981d..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_zh_cn/PyCharm_change_version.md
+++ /dev/null
@@ -1,39 +0,0 @@
-# API Mapping - API Version Switching
-
-[![View Source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_zh_cn/PyCharm_change_version.md)
-
-## Introduction
-
-API mapping refers to the mapping between PyTorch APIs and MindSpore APIs.
-MindSpore Dev Toolkit provides two major functions, API mapping search and API mapping scanning, and lets you switch the version of the API mapping data freely.
-
-## Switching the API Mapping Data Version
-
-1. When the plug-in starts, it uses the API mapping data version matching the current plug-in version by default. The API mapping data version is displayed in the lower right corner. This version number affects only the API mapping functions in this chapter and does not change the MindSpore version in your environment.
-
-   ![img](./images/clip_image137.jpg)
-
-2. Click the API mapping data version to open a selection list. You can click a preset version to switch to it, or choose "other version" to enter another version number.
-
-   ![img](./images/clip_image138.jpg)
-
-3. Click any version number to start switching. An animation below indicates that the switch is in progress.
-
-   ![img](./images/clip_image139.jpg)
-
-4. To enter a custom version number, choose the "other version" option in the list, enter the version number in the dialog box, and click OK to start switching. Note: enter the version number in the format 2.1 or 2.1.0; otherwise, clicking OK has no effect.
-
-   ![img](./images/clip_image140.jpg)
-
-5. If the switch succeeds, the status bar in the lower right shows the API mapping data version after the switch.
-
-   ![img](./images/clip_image141.jpg)
-
-6. If the switch fails, the status bar in the lower right shows the version before the switch. A nonexistent version number or a network error causes the switch to fail; check for these and try again. To view the latest documents, switch to the master version.
-
-   ![img](./images/clip_image142.jpg)
-
-7. After a custom version number is switched to successfully, it is added to the version list.
-
-   ![img](./images/clip_image143.jpg)
-
diff --git a/docs/devtoolkit/docs/source_zh_cn/PyCharm_plugin_install.md b/docs/devtoolkit/docs/source_zh_cn/PyCharm_plugin_install.md
deleted file mode 100644
index 9622c7c3159ed08594dc123af9bffaf56fafc354..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_zh_cn/PyCharm_plugin_install.md
+++ /dev/null
@@ -1,13 +0,0 @@
-# PyCharm Plug-in Installation
-
-[![View Source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_zh_cn/PyCharm_plugin_install.md)
-
-## Installation Steps
-
-1. Obtain the [plug-in ZIP package](https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.1.0/IdePlugin/any/MindSpore_Dev_ToolKit-2.1.0.zip).
-2. Start PyCharm, open the upper-left menu bar, and choose File->Settings->Plugins->Install Plugin from Disk.
-   As shown:
-
-   ![image-20211223175637989](./images/clip_image050.jpg)
-
-3. Select the plug-in ZIP package.
\ No newline at end of file
diff --git a/docs/devtoolkit/docs/source_zh_cn/VSCode_api_scan.md b/docs/devtoolkit/docs/source_zh_cn/VSCode_api_scan.md
deleted file mode 100644
index b639ac7c0951b3011232d5723c9cb27c65e12e35..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_zh_cn/VSCode_api_scan.md
+++ /dev/null
@@ -1,49 +0,0 @@
-# API Mapping - API Scanning
-
-[![View Source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_zh_cn/VSCode_api_scan.md)
-
-## Functions
-
-* Quickly scans the APIs used in your code and shows the API details directly in the sidebar.
-* For the convenience of users of other machine learning frameworks, scans the mainstream-framework APIs used in the code and suggests the matching MindSpore APIs.
-* The version of the API mapping data can be switched; see [API Mapping - Version Switching](https://www.mindspore.cn/devtoolkit/docs/zh-CN/master/VSCode_change_version.html) for details.
-
-## File-level API Mapping Scanning
-
-1. Right-click anywhere in the current file to open the menu, and choose "scan local file".
-
-   ![img](./images/clip_image116.jpg)
-
-2. The right pane shows the operators scanned from the current file, grouped into three result lists: "PyTorch APIs that can be converted", "results that may be torch.Tensor APIs", and "PyTorch APIs with no direct mapping yet". (A hypothetical example of such results is sketched at the end of this section.)
-
-   Here:
-
-   * "PyTorch APIs that can be converted" are PyTorch APIs used in the file that can be converted to MindSpore APIs
-   * "results that may be torch.Tensor APIs" are APIs whose names match torch.Tensor API names; they may be torch.Tensor APIs and may be convertible to MindSpore APIs
-   * "PyTorch APIs with no direct mapping yet" are PyTorch APIs, or possible torch.Tensor APIs, for which no direct MindSpore counterpart is available yet
-
-   ![img](./images/clip_image117.jpg)
-
-## Project-level API Mapping Scanning
-
-1. Click the MindSpore API mapping scanning icon in the left sidebar of Visual Studio Code.
-
-   ![img](./images/clip_image118.jpg)
-
-2. The left pane shows a project tree view containing only the Python files of the current IDE project.
-
-   ![img](./images/clip_image119.jpg)
-
-3. Select a single Python file in the view to get the operator scan result list for that file.
-
-   ![img](./images/clip_image120.jpg)
-
-4. Select a directory in the view to get the operator scan result lists for all Python files under that directory.
-
-   ![img](./images/clip_image121.jpg)
-
-5. All text in blue is clickable and opens the corresponding web page in your default browser.
-
-   ![img](./images/clip_image122.jpg)
-
-   ![img](./images/clip_image123.jpg)
\ No newline at end of file
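-
-For illustration only, here is a hypothetical file such a scan might analyze, annotated with the result list each call would land in (actual results depend on your code and on the API mapping data version):
-
-```python
-import torch
-
-# Listed under "PyTorch APIs that can be converted":
-# torch.nn.Linear maps to mindspore.nn.Dense (assumed mapping).
-layer = torch.nn.Linear(128, 10)
-
-x = torch.ones(32, 128)
-
-# The chained .abs() call can only be matched by name, so it is
-# listed under "results that may be torch.Tensor APIs".
-y = layer(x).abs()
-```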
diff --git a/docs/devtoolkit/docs/source_zh_cn/VSCode_api_search.md b/docs/devtoolkit/docs/source_zh_cn/VSCode_api_search.md
deleted file mode 100644
index e57ed777648e995b432f9a96ec58bc82a64502c3..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_zh_cn/VSCode_api_search.md
+++ /dev/null
@@ -1,29 +0,0 @@
-# API Mapping - API Search
-
-[![View Source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_zh_cn/VSCode_api_search.md)
-
-## Functions
-
-* Quickly searches MindSpore APIs and shows the API details directly in the sidebar.
-* For the convenience of users of other machine learning frameworks, searches the APIs of other mainstream frameworks and suggests the matching MindSpore APIs.
-* The version of the API mapping data can be switched; see [API Mapping - Version Switching](https://www.mindspore.cn/devtoolkit/docs/zh-CN/master/VSCode_change_version.html) for details.
-
-## Procedure
-
-1. Click the MindSpore API mapping search icon in the left sidebar of Visual Studio Code.
-
-   ![img](./images/clip_image124.jpg)
-
-2. An input box appears in the left pane.
-
-   ![img](./images/clip_image125.jpg)
-
-3. Type any word in the input box; the search results for the current keyword are shown below and are updated in real time as you type.
-
-   ![img](./images/clip_image126.jpg)
-
-4. Click any search result to open the corresponding web page in your default browser. (A usage sketch for one such result follows this procedure.)
-
-   ![img](./images/clip_image127.jpg)
-
-   ![img](./images/clip_image128.jpg)
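-
-As a hypothetical example, searching for "torch.nn.MaxPool2d" would surface its MindSpore counterpart, which can then be used as follows (a sketch assuming MindSpore is installed; the exact match depends on the mapping data version):
-
-```python
-import numpy as np
-import mindspore
-from mindspore import Tensor, nn
-
-# mindspore.nn.MaxPool2d is the counterpart suggested for the
-# PyTorch query "torch.nn.MaxPool2d" (assumed mapping).
-pool = nn.MaxPool2d(kernel_size=2, stride=2)
-x = Tensor(np.ones((1, 3, 4, 4)), mindspore.float32)
-print(pool(x).shape)  # (1, 3, 2, 2)
-```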
diff --git a/docs/devtoolkit/docs/source_zh_cn/VSCode_change_version.md b/docs/devtoolkit/docs/source_zh_cn/VSCode_change_version.md
deleted file mode 100644
index 2822c877f55f06ea1c2927f850fb32582cc9dc2c..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_zh_cn/VSCode_change_version.md
+++ /dev/null
@@ -1,39 +0,0 @@
-# API Mapping - Version Switching
-
-[![View Source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_zh_cn/VSCode_change_version.md)
-
-## Introduction
-
-API mapping refers to the mapping between PyTorch APIs and MindSpore APIs. MindSpore Dev Toolkit provides two major functions, API mapping search and API mapping scanning, and lets you switch the version of the API mapping data freely.
-
-## Switching the API Mapping Data Version
-
-1. Different versions of the API mapping data produce different API mapping scanning and search results, but do not affect the MindSpore version in your environment. The default version matches the plug-in version, and the version information is shown in the status bar at the lower left.
-
-   ![img](./images/clip_image129.jpg)
-
-2. Click this status bar; a drop-down list appears at the top of the page with the preset version numbers. Click any version number to switch to it, or click the "custom input" option and enter another version number in the input box that then appears.
-
-   ![img](./images/clip_image130.jpg)
-
-3. Click any version number to start switching; the status bar at the lower left shows that the switch is in progress.
-
-   ![img](./images/clip_image131.jpg)
-
-4. To enter a custom version number, click the "custom input" option; the drop-down list turns into an input box. Enter the version number in the format 2.1 or 2.1.0 and press Enter to start switching; the status bar at the lower left shows that the switch is in progress.
-
-   ![img](./images/clip_image132.jpg)
-
-   ![img](./images/clip_image133.jpg)
-
-5. If the switch succeeds, a message at the lower right reports success, and the status bar at the lower left shows the API mapping data version after the switch.
-
-   ![img](./images/clip_image134.jpg)
-
-6. If the switch fails, a message at the lower right reports failure, and the status bar at the lower left shows the version before the switch. A nonexistent version number or a network error causes the switch to fail; check for these and try again. To view the latest documents, switch to the master version.
-
-   ![img](./images/clip_image135.jpg)
-
-7. After a custom version number is switched to successfully, it is added to the drop-down list.
-
-   ![img](./images/clip_image136.jpg)
\ No newline at end of file
diff --git a/docs/devtoolkit/docs/source_zh_cn/VSCode_plugin_install.md b/docs/devtoolkit/docs/source_zh_cn/VSCode_plugin_install.md
deleted file mode 100644
index 14ee27f84e661910b7db7959533b337050d0adf3..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_zh_cn/VSCode_plugin_install.md
+++ /dev/null
@@ -1,18 +0,0 @@
-# Visual Studio Code Plug-in Installation
-
-[![View Source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_zh_cn/VSCode_plugin_install.md)
-
-## Installation Steps
-
-1. Obtain the [plug-in VSIX package](https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.1.0/IdePlugin/any/mindspore-dev-toolkit-2.1.0.vsix).
-2. Click the fifth button "Extensions" on the left, click the three dots in the upper right corner, and then click "Install from VSIX...".
-
-   ![img](./images/clip_image112.jpg)
-
-3. Select the downloaded VSIX file from the folder; the plug-in starts installing automatically. When "Completed installing MindSpore Dev Toolkit extension from VSIX" appears at the lower right, the plug-in has been installed successfully.
-
-   ![img](./images/clip_image113.jpg)
-
-4. Click the refresh button in the left pane; the "INSTALLED" list now contains the "MindSpore Dev Toolkit" plug-in, which completes the installation.
-
-   ![img](./images/clip_image114.jpg)
\ No newline at end of file
diff --git a/docs/devtoolkit/docs/source_zh_cn/VSCode_smart_completion.md b/docs/devtoolkit/docs/source_zh_cn/VSCode_smart_completion.md
deleted file mode 100644
index 80634d7f3c87c03e78e3f85f4b3e5bac03ca3b24..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_zh_cn/VSCode_smart_completion.md
+++ /dev/null
@@ -1,22 +0,0 @@
-# Code Completion
-
-[![View Source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_zh_cn/VSCode_smart_completion.md)
-
-## Functions
-
-* Provides AI-based code completion for MindSpore projects.
-* You can easily develop MindSpore code without installing the MindSpore environment.
-
-## Procedure
-
-1. When the plug-in is installed or used for the first time, the model is downloaded automatically: "start downloading Model" appears at the lower right, and "Model downloaded successfully" indicates that the model has been downloaded and started successfully. On a slow network the download can take more than ten minutes, and the success message appears only after it completes. The messages are not shown on subsequent uses.
-
-   ![img](./images/clip_image115.jpg)
-
-2. Open a Python file and write code.
-
-   ![img](./images/clip_image097.jpg)
-
-3. Completion takes effect automatically while you code. Entries labeled with the "MindSpore Dev Toolkit" suffix are provided by this plug-in's smart completion.
-
-   ![img](./images/clip_image111.jpg)
\ No newline at end of file
diff --git a/docs/devtoolkit/docs/source_zh_cn/api_scanning.md b/docs/devtoolkit/docs/source_zh_cn/api_scanning.md
deleted file mode 100644
index f425da32cad074e934b0bf199184ac67d20368ce..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_zh_cn/api_scanning.md
+++ /dev/null
@@ -1,62 +0,0 @@
-# API Mapping - API Scanning
-
-[![View Source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_zh_cn/api_scanning.md)
-
-## Functions
-
-* Quickly scans the APIs used in your code and shows the API details directly in the sidebar.
-* For the convenience of users of other machine learning frameworks, scans the mainstream-framework APIs used in the code and suggests the matching MindSpore APIs.
-* The version of the API mapping data can be switched; see [API Mapping - Version Switching](https://www.mindspore.cn/devtoolkit/docs/zh-CN/master/PyCharm_change_version.html) for details.
-
-## Procedure
-
-### File-level API Scanning
-
-1. Right-click anywhere in the current file to open the menu, and click "API scan" at the top of the menu.
-
-   ![img](./images/clip_image100.jpg)
-
-2. The right pane opens automatically and shows the scanned APIs in a detailed list with names, URLs, and other information. If no API is found in the file, no window pops up.
-
-   Here:
-
-   * "PyTorch/TensorFlow APIs that can be converted to MindSpore APIs" are PyTorch or TensorFlow APIs used in the file that can be converted to MindSpore APIs
-   * "APIs that cannot be converted yet" are PyTorch or TensorFlow APIs that have no direct MindSpore counterpart yet
-   * "possible PyTorch/TensorFlow APIs" are calls that, because of chained calls, may be convertible PyTorch or TensorFlow APIs
-   * TensorFlow API scanning is an experimental feature (see the sketch at the end of this section)
-
-   ![img](./images/clip_image101.jpg)
-
-3. All text in blue is clickable and opens a pane above that shows the web page.
-
-   ![img](./images/clip_image102.jpg)
-
-4. Click the "Export" button in the upper right corner to export the content to a CSV file.
-
-   ![img](./images/clip_image103.jpg)
-
-### Project-level API Scanning
-
-1. Right-click anywhere in the current file to open the menu and click the second item "API scan project-level", or choose "Tools" in the toolbar above and then "API scan project-level".
-
-   ![img](./images/clip_image104.jpg)
-
-   ![img](./images/clip_image105.jpg)
-
-2. The right pane shows the APIs scanned from the whole project in a detailed list with names, URLs, and other information.
-
-   ![img](./images/clip_image106.jpg)
-
-3. In the upper box you can select a single file, and the lower box then lists the APIs of that file separately; you can switch between files freely.
-
-   ![img](./images/clip_image107.jpg)
-
-   ![img](./images/clip_image108.jpg)
-
-4. All text in blue is clickable and opens a pane above that shows the web page.
-
-   ![img](./images/clip_image109.jpg)
-
-5. Click the "Export" button to export the content to a CSV file.
-
-   ![img](./images/clip_image110.jpg)
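-
-Because TensorFlow API scanning is experimental, results for TensorFlow code may be less complete. For illustration only, a TensorFlow snippet the scan might analyze could look like this (the counterpart named in the comment is an assumed mapping):
-
-```python
-import tensorflow as tf
-
-# tf.abs is the kind of call the scan may report as convertible,
-# with mindspore.ops.abs as a candidate MindSpore counterpart.
-x = tf.constant([-1.0, 2.0])
-y = tf.abs(x)
-```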
diff --git a/docs/devtoolkit/docs/source_zh_cn/api_search.md b/docs/devtoolkit/docs/source_zh_cn/api_search.md
deleted file mode 100644
index 948e7d70cc9131ce71b3571232d025bce5c70b09..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_zh_cn/api_search.md
+++ /dev/null
@@ -1,29 +0,0 @@
-# API Mapping - API Search
-
-[![View Source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_zh_cn/api_search.md)
-
-## Functions
-
-* Quickly searches MindSpore APIs and shows the API details directly in the sidebar.
-* For the convenience of users of other machine learning frameworks, searches the APIs of other mainstream frameworks and suggests the matching MindSpore APIs.
-* The version of the API mapping data can be switched; see [API Mapping - Version Switching](https://www.mindspore.cn/devtoolkit/docs/zh-CN/master/PyCharm_change_version.html) for details.
-
-## Procedure
-
-1. Press Shift twice to open the global search page.
-
-   ![img](images/clip_image060.jpg)
-
-2. Select MindSpore.
-
-   ![img](images/clip_image062.jpg)
-
-3. Enter the PyTorch or TensorFlow API to search for, and get the list of correspondences with MindSpore APIs.
-
-   ![img](images/clip_image064.jpg)
-
-   ![img](images/clip_image066.jpg)
-
-4. Click an entry in the list to browse its official documentation in the right sidebar.
-
-   ![img](images/clip_image068.jpg)
diff --git a/docs/devtoolkit/docs/source_zh_cn/compiling.md b/docs/devtoolkit/docs/source_zh_cn/compiling.md
deleted file mode 100644
index ce30f02e48dbe63d091a3d4d99ee64dddad2b5e7..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_zh_cn/compiling.md
+++ /dev/null
@@ -1,86 +0,0 @@
-# Source Code Compilation Guide
-
-[![View Source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_zh_cn/compiling.md)
-
-This document describes how to compile the MindSpore Dev ToolKit project from source code with IntelliJ IDEA.
-
-## Background
-
-* MindSpore Dev ToolKit is a PyCharm plug-in and needs to be developed with IntelliJ IDEA. [IntelliJ IDEA](https://www.jetbrains.com/idea/download) and PyCharm are both IDEs developed by JetBrains.
-* MindSpore Dev ToolKit is developed on JDK 11. If you are not familiar with the JDK, visit [https://jdk.java.net/](https://jdk.java.net/) to learn about the JDK and Java.
-* MindSpore Dev ToolKit is built with [Gradle](https://gradle.org) 6.6.1, which does not need to be installed in advance: IntelliJ IDEA automatically configures the required Gradle through the "gradle wrapper" mechanism.
-
-## Required Software
-
-* Make sure [IntelliJ IDEA](https://www.jetbrains.com/idea/download) is installed.
-
-* Make sure JDK 11 is installed.
-  Note: IntelliJ IDEA 2021.3 ships with a JDK named jbr-11 (JetBrains Runtime version 11.0.10), which can be used directly.
-
-  ![img](images/clip_image031.jpg)
-
-## Compilation
-
-1. Make sure the software dependencies above are configured successfully.
-
-2. Download the source code of [this project](https://gitee.com/mindspore/ide-plugin) from the code repository.
-
-   * Download the ZIP package of the code directly.
-
-     ![img](images/clip_image032.jpg)
-
-   * Or download it with git.
-
-     ```
-     git clone https://gitee.com/mindspore/ide-plugin.git
-     ```
-
-3. Open the project in IntelliJ IDEA.
-
-   3.1 Choose the Open option under the File tab. ***File -> Open***
-
-   ![img](images/clip_image033.jpg)
-
-   3.2 Open the location of the downloaded project files.
-
-   ![img](images/clip_image034.jpg)
-
-   3.3 Click load in the pop-up at the lower right, or right-click the pycharm/settings.gradle file and select Link Gradle Project.
-
-   ![img](images/clip_image035.jpg)
-
-   ![img](images/clip_image036.jpg)
-
-4. If you are prompted that there is no JDK, select one. ***Skip this step if a JDK is available.***
-
-   4.1 Without a JDK, the page looks as follows.
-
-   ![img](images/clip_image037.jpg)
-
-   4.2 Choose File->Project Structure.
-
-   ![img](images/clip_image038.jpg)
-
-   4.3 Select JDK 11.
-
-   ![img](images/clip_image039.jpg)
-
-5. Wait for the synchronization to complete.
-
-   ![img](images/clip_image040.jpg)
-
-6. Build the project.
-
-   ![img](images/clip_image042.jpg)
-
-7. The build is complete.
-
-   ![img](images/clip_image044.jpg)
-
-8. After the build completes, obtain the plug-in installation package from the /pycharm/build/distributions directory under the project directory.
-
-   ![img](images/clip_image046.jpg)
-
-## References
-
-* This project is built following the [Building Plugins with Gradle](https://plugins.jetbrains.com/docs/intellij/gradle-build-system.html) section of the IntelliJ Platform Plugin SDK documentation. For advanced features such as debugging, read the official documentation.
\ No newline at end of file
diff --git a/docs/devtoolkit/docs/source_zh_cn/conf.py b/docs/devtoolkit/docs/source_zh_cn/conf.py
deleted file mode 100644
index edcae7146aa7df7da99445745b5b2d269f55f9d6..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_zh_cn/conf.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# Configuration file for the Sphinx documentation builder.
-#
-# This file only contains a selection of the most common options. For a full
-# list see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
-
-# -- Path setup --------------------------------------------------------------
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-#
-import os
-import re
-
-# -- Project information -----------------------------------------------------
-
-project = 'MindSpore'
-copyright = 'MindSpore'
-author = 'MindSpore'
-
-# The full version, including alpha/beta/rc tags
-release = 'master'
-
-
-# -- General configuration ---------------------------------------------------
-
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
-# ones.
-myst_enable_extensions = ["dollarmath", "amsmath"]
-
-
-myst_heading_anchors = 5
-extensions = [
-    'myst_parser',
-    'sphinx.ext.autodoc'
-]
-
-source_suffix = {
-    '.rst': 'restructuredtext',
-    '.md': 'markdown',
-}
-
-# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-# This pattern also affects html_static_path and html_extra_path.
-mathjax_path = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/mathjax/MathJax-3.2.2/es5/tex-mml-chtml.js'
-
-mathjax_options = {
-    'async':'async'
-}
-
-smartquotes_action = 'De'
-
-exclude_patterns = []
-
-pygments_style = 'sphinx'
-
-# -- Options for HTML output -------------------------------------------------
-
-language = 'zh_CN'
-locale_dirs = ['../../../../resource/locale/']
-gettext_compact = False
-
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-# -html_theme = 'sphinx_rtd_theme' - -html_search_options = {'dict': '../../../resource/jieba.txt'} - -html_static_path = ['_static'] - -src_release = os.path.join(os.getenv("DT_PATH"), 'RELEASE_CN.md') -des_release = "./RELEASE.md" -with open(src_release, "r", encoding="utf-8") as f: - data = f.read() -if len(re.findall("\n## (.*?)\n",data)) > 1: - content = re.findall("(## [\s\S\n]*?)\n## ", data) -else: - content = re.findall("(## [\s\S\n]*)", data) -#result = content[0].replace('# MindSpore', '#', 1) -with open(des_release, "w", encoding="utf-8") as p: - p.write("# Release Notes"+"\n\n") - p.write(content[0]) \ No newline at end of file diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image002.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image002.jpg deleted file mode 100644 index 24132302f1552bed6be56b7dd660625448774680..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image002.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image004.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image004.jpg deleted file mode 100644 index 7ed0e3729940f514a7bfd61c1a7be22166c0bb02..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image004.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image006.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image006.jpg deleted file mode 100644 index e0c323eec249024fe19126ce4c931133564cf7b7..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image006.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image008.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image008.jpg deleted file mode 100644 index a071ad67222931372d4b62f7b0cf334a4015e70d..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image008.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image010.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image010.jpg deleted file mode 100644 index 43ca88d40bc56d5a5113bc29b97a7b559ef659af..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image010.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image012.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image012.jpg deleted file mode 100644 index 0e35c9f219292913a51f1f0d5b7a5e154008620f..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image012.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image014.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image014.jpg deleted file mode 100644 index 794de60c7d7e76a58e8d7212e449a1bd8e194b21..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image014.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image015.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image015.jpg deleted file mode 100644 index 8172e21f871bed6866b3a91b83252838228d257a..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image015.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image016.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image016.jpg deleted file mode 100644 index 
c836c0cf7898e4757ddb3410dde18e754894dd25..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image016.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image018.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image018.jpg deleted file mode 100644 index 777738f7cf60454b7b5f26e6c5a29ba5d55750ff..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image018.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image019.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image019.jpg deleted file mode 100644 index ab02b702bfd1c0986adb4a15c1f455b56df0a4a1..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image019.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image020.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image020.jpg deleted file mode 100644 index d946c3cb3a851f690a5643b5afe119597aed5b22..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image020.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image021.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image021.jpg deleted file mode 100644 index 74672d9513f4a60f77450ae5516cee2060215241..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image021.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image022.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image022.jpg deleted file mode 100644 index 6b26f18c7d8bb43db0beb8a8d2bd386489192922..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image022.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image023.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image023.jpg deleted file mode 100644 index 5981a0fb25c681417a0bdf24d2392411b41f0faf..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image023.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image024.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image024.jpg deleted file mode 100644 index 505e8b6c5c4d81dfd67d91b1dad08f938444368a..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image024.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image025.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image025.jpg deleted file mode 100644 index 946e276b6a982303b6b146266037a739e6c2639a..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image025.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image026.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image026.jpg deleted file mode 100644 index 8b787215af5cf9ae223c2b33121c3171923d2de2..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image026.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image027.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image027.jpg deleted file mode 100644 index aa4d7d4a8a6b503fe29885368547daa535e34796..0000000000000000000000000000000000000000 Binary files 
a/docs/devtoolkit/docs/source_zh_cn/images/clip_image027.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image028.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image028.jpg deleted file mode 100644 index 3126f80aecac28e8beaa54dc393122c60dbe1357..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image028.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image029.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image029.jpg deleted file mode 100644 index 6587240e4a456f3792fece52bcdcbbed077ca67b..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image029.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image031.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image031.jpg deleted file mode 100644 index 2f829b48e72e62525860cfe599e0a4ada82010ca..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image031.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image032.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image032.jpg deleted file mode 100644 index 37589efbe0f57c442f665824831a2685d81c8713..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image032.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image033.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image033.jpg deleted file mode 100644 index bdca68324cf7ee8f4e9bd18817a82954910e52c9..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image033.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image034.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image034.jpg deleted file mode 100644 index 874b10d4b2ca476da16a4d1e749cdb6b31ecb59e..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image034.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image035.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image035.jpg deleted file mode 100644 index 0b0465169553e57795320255295b8fa789950522..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image035.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image036.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image036.jpg deleted file mode 100644 index c7c6c72819b655884d97637b696d1814e5a7fdbf..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image036.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image037.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image037.jpg deleted file mode 100644 index 531e8184e02c43aa177a51c3cc32355cc3df9d42..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image037.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image038.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image038.jpg deleted file mode 100644 index a8b4d88190c626139bad49cd42a9f7e908b4d0e4..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image038.jpg and /dev/null differ diff --git 
a/docs/devtoolkit/docs/source_zh_cn/images/clip_image039.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image039.jpg deleted file mode 100644 index 2eab0ceac9c1bd5d8b6ade3d65a6a3ce8b1f8fd4..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image039.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image040.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image040.jpg deleted file mode 100644 index a879fb1f12d8b6c4bd02332abf9b3bd734207763..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image040.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image042.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image042.jpg deleted file mode 100644 index 2454ade258da6d428c9e23ece2adf7f0291d1a12..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image042.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image044.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image044.jpg deleted file mode 100644 index cbff652015c36a5856afc909518f3c0fd22f23ff..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image044.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image046.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image046.jpg deleted file mode 100644 index 58a493ea4f69b264fc69cfd3e34f32d5a171c303..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image046.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image050.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image050.jpg deleted file mode 100644 index 35cc26d483358550c9a53ce855c2ae483eddb7e1..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image050.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image060.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image060.jpg deleted file mode 100644 index 7723975694f7f56d88187a69626343af11efbd23..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image060.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image062.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image062.jpg deleted file mode 100644 index 838bc48ab8d77f7dbba9ca02925838a49b19ce53..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image062.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image064.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image064.jpg deleted file mode 100644 index fb39e70b78b45af301973ea802a219c482a21590..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image064.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image066.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image066.jpg deleted file mode 100644 index 0a596cfb3ef7a79674ff33a7be5c97859cc2b9c4..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image066.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image068.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image068.jpg deleted file 
mode 100644 index 0023ba9236a768001e462d6a10434719f3f733fd..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image068.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image072.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image072.jpg deleted file mode 100644 index d1e5fad4192d4cb5821cafe4031dbc4ff599eccd..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image072.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image074.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image074.jpg deleted file mode 100644 index 97fa2b21b4029ff75156893f8abdff2e77aa38bd..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image074.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image076.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image076.jpg deleted file mode 100644 index e754c7dcd30ee6fa82ab20bde4d66f69aabe2fa7..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image076.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image088.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image088.jpg deleted file mode 100644 index 8b85f0727893c4cf6cd258550466ca4f4a340e6e..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image088.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image090.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image090.jpg deleted file mode 100644 index a3f405388fd75b23b652bc86475be5fd5e1f48ac..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image090.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image092.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image092.jpg deleted file mode 100644 index 68ca9c66fc3f03760873075af20c8a9e28aaab48..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image092.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image093.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image093.jpg deleted file mode 100644 index 594b2ceadb2c27290e0339e14b298fa2feffe6a9..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image093.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image094.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image094.jpg deleted file mode 100644 index e931a95180d27d55590e73948ebe80a1f81bede1..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image094.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image096.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image096.jpg deleted file mode 100644 index 3ed0c88500bb4caffccea4d08aaa3a6310e177bd..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image096.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image097.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image097.jpg deleted file mode 100644 index 0cb303bea0e9e88bf56fe22806c91605f1822606..0000000000000000000000000000000000000000 Binary files 
a/docs/devtoolkit/docs/source_zh_cn/images/clip_image097.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image100.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image100.jpg deleted file mode 100644 index 7dd66fa814e1dcc67b30e41e05ff2d36a2cae2a8..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image100.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image101.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image101.jpg deleted file mode 100644 index 656ec720e30e09c72ea2c61c9caad41282a7f923..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image101.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image102.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image102.jpg deleted file mode 100644 index 973c9940bc8e72ed2355026c98f9885f30096ba4..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image102.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image103.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image103.jpg deleted file mode 100644 index e945dc73adf0b3755bbed6f54b8f6255d7d8f3fc..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image103.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image104.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image104.jpg deleted file mode 100644 index 6a20059ff34b48f657c7d7d998597c7f1332d220..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image104.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image105.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image105.jpg deleted file mode 100644 index 62606f6b5b4cc9eabea35e79f2da9dae45b91c29..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image105.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image106.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image106.jpg deleted file mode 100644 index 60a0c13f748cf1265ee43a766a86041fc1abe7c6..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image106.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image107.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image107.jpg deleted file mode 100644 index 19e59bb533d230c64fd12b2a0f4abb5156428695..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image107.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image108.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image108.jpg deleted file mode 100644 index 14bfeaaf6bc7416d4b53aeb85dda5883dfb22aa9..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image108.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image109.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image109.jpg deleted file mode 100644 index 3a155082deeab5246182f3fceb28cc1b65e709e1..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image109.jpg and /dev/null differ diff --git 
a/docs/devtoolkit/docs/source_zh_cn/images/clip_image110.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image110.jpg deleted file mode 100644 index dee5a5107e59244dc24c1a41a928b3aaf705b052..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image110.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image111.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image111.jpg deleted file mode 100644 index 6324111bd6e2bf8f9f62429058b4e3f5fd5b36b9..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image111.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image112.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image112.jpg deleted file mode 100644 index c71a61026a48cd7f1732b4b9d41c00d0c03bc521..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image112.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image113.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image113.jpg deleted file mode 100644 index d44ede801205f4bf4981cf751ca49f5df3324c2e..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image113.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image114.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image114.jpg deleted file mode 100644 index 43dbe99fc91ee7907928e9dabe5e50b1fd202fef..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image114.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image115.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image115.jpg deleted file mode 100644 index e676224e6be68f2b770d8effac56ab5b3f433e99..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image115.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image116.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image116.jpg deleted file mode 100644 index 3c0569c618cb45d19783b5034209050ad1dee716..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image116.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image117.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image117.jpg deleted file mode 100644 index 1f3d079ed853fef95ce56f258b489d31117da8cf..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image117.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image118.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image118.jpg deleted file mode 100644 index 2729fbe6133df67d4fab7438244182ba836f5908..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image118.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image119.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image119.jpg deleted file mode 100644 index 44fd8b841548900ab6ddc598b2ad0124d8341864..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image119.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image120.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image120.jpg deleted file 
mode 100644 index 19d88f70e7675b110b582164e43ad4f60419c7be..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image120.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image121.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image121.jpg deleted file mode 100644 index 4f997ac892b2ce7d1daecc358afabcccec0d04be..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image121.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image122.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image122.jpg deleted file mode 100644 index 9a16de514b7ec68c579cc02ec680df3db1292746..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image122.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image123.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image123.jpg deleted file mode 100644 index f5f4a82d076f0b22b2e85d139c2ac5b6572bb571..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image123.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image124.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image124.jpg deleted file mode 100644 index e835a4b22a7032069e7c2edb6eda1b012a15671f..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image124.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image125.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image125.jpg deleted file mode 100644 index 3b779d9ba44f054c3673a761417035a86f97db06..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image125.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image126.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image126.jpg deleted file mode 100644 index 93f72bd16af2289bdfdd24120acadb3506f37ab5..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image126.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image127.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image127.jpg deleted file mode 100644 index 09787b77310e4dabb74f48d154168c67204f2243..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image127.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image128.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image128.jpg deleted file mode 100644 index 074ab2fc864e572f1a4c59511316466ec72927e0..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image128.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image129.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image129.jpg deleted file mode 100644 index c144eb5cfdf77caa18567a909b27c55880960a08..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image129.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image130.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image130.jpg deleted file mode 100644 index a88fdddd286bd9b2eb6773734227471f5f2dd655..0000000000000000000000000000000000000000 Binary files 
a/docs/devtoolkit/docs/source_zh_cn/images/clip_image130.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image131.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image131.jpg deleted file mode 100644 index c51c4fdf2370cb292a811dce8b467ded6182d86c..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image131.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image132.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image132.jpg deleted file mode 100644 index 085df85bf7f3f65fb1e21e135cac3ccb18cbe3e0..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image132.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image133.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image133.jpg deleted file mode 100644 index dbd04ac802204eb9327552431a1f5b3fe213f10b..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image133.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image134.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image134.jpg deleted file mode 100644 index 2de14584f45993655ef626b7a76e0b22cab2eeb8..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image134.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image135.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image135.jpg deleted file mode 100644 index aaa7212f107ca3d81470328b37f25e4ba36cf199..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image135.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image136.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image136.jpg deleted file mode 100644 index 3aa85416624254075ae967ff13178df1922ed1b4..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image136.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image137.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image137.jpg deleted file mode 100644 index 5db1e281048f80b4e82cb7d375c4fa08729f4ea7..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image137.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image138.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image138.jpg deleted file mode 100644 index 4884213affb1aee8b54131690cf5042293803eb2..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image138.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image139.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image139.jpg deleted file mode 100644 index 0173dc81f0ba01639ef60397d45185323fe84440..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image139.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image140.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image140.jpg deleted file mode 100644 index 900204ac42a37068c1292b4779a08336e84a88ff..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image140.jpg and /dev/null differ diff --git 
a/docs/devtoolkit/docs/source_zh_cn/images/clip_image141.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image141.jpg deleted file mode 100644 index 49d46b7bbb3ff297f7975c5f96ccfa42b0c3fdc1..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image141.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image142.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image142.jpg deleted file mode 100644 index 12a49c84869560400bc3859aef3b80cea3bd7722..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image142.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/images/clip_image143.jpg b/docs/devtoolkit/docs/source_zh_cn/images/clip_image143.jpg deleted file mode 100644 index 7a2e5f268171c29cfc84c6f9748f5ebe3a7ef399..0000000000000000000000000000000000000000 Binary files a/docs/devtoolkit/docs/source_zh_cn/images/clip_image143.jpg and /dev/null differ diff --git a/docs/devtoolkit/docs/source_zh_cn/index.rst b/docs/devtoolkit/docs/source_zh_cn/index.rst deleted file mode 100644 index 55cf86cc6f52ce0661e93ad1c642bc37e009e81a..0000000000000000000000000000000000000000 --- a/docs/devtoolkit/docs/source_zh_cn/index.rst +++ /dev/null @@ -1,61 +0,0 @@ -MindSpore Dev Toolkit文档 -============================ - -MindSpore Dev Toolkit是一款支持MindSpore开发的 `PyCharm `_ (多平台Python IDE)插件,提供 `创建项目 `_ 、 `智能补全 `_ 、 `API互搜 `_ 和 `文档搜索 `_ 等功能。 - -MindSpore Dev Toolkit通过深度学习、智能搜索及智能推荐等技术,打造智能计算最佳体验,致力于全面提升MindSpore框架的易用性,助力MindSpore生态推广。 - -代码仓地址: - -系统需求 ------------------------------- - -- 插件支持的操作系统: - - - Windows 10 - - - Linux - - - MacOS(仅支持x86架构,补全功能暂未上线) - -- 插件支持的PyCharm版本: - - - 2020.3 - - - 2021.X - - - 2022.x - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: PyCharm插件使用指南 - :hidden: - - PyCharm_plugin_install - compiling - smart_completion - PyCharm_change_version - api_search - api_scanning - knowledge_search - mindspore_project_wizard - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: VSCode插件使用指南 - :hidden: - - VSCode_plugin_install - VSCode_smart_completion - VSCode_change_version - VSCode_api_search - VSCode_api_scan - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: RELEASE NOTES - - RELEASE \ No newline at end of file diff --git a/docs/devtoolkit/docs/source_zh_cn/knowledge_search.md b/docs/devtoolkit/docs/source_zh_cn/knowledge_search.md deleted file mode 100644 index 6ba1068fafc6861143763f2a30a347bad4ed3ced..0000000000000000000000000000000000000000 --- a/docs/devtoolkit/docs/source_zh_cn/knowledge_search.md +++ /dev/null @@ -1,22 +0,0 @@ -# 智能知识搜索 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_zh_cn/knowledge_search.md) - -## 功能介绍 - -* 定向推荐:根据用户使用习惯,提供更精准的搜索结果。 -* 沉浸式资料检索体验,避免在IDE和浏览器之间的互相切换。适配侧边栏,提供窄屏适配界面。 - -## 使用步骤 - -1. 打开侧边栏,显示搜索主页。 - - ![img](images/clip_image072.jpg) - -2. 输入"api映射",点击搜索,查看结果。 - - ![img](images/clip_image074.jpg) - -3. 
点击home按钮回到主页。
-
-   ![img](images/clip_image076.jpg)
\ No newline at end of file
diff --git a/docs/devtoolkit/docs/source_zh_cn/mindspore_project_wizard.md b/docs/devtoolkit/docs/source_zh_cn/mindspore_project_wizard.md
deleted file mode 100644
index ca5972171d1a96375155f6aef3b094cbf73ead13..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_zh_cn/mindspore_project_wizard.md
+++ /dev/null
@@ -1,103 +0,0 @@
-# 创建项目
-
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_zh_cn/mindspore_project_wizard.md)
-
-## 技术背景
-
-本功能的实现基于[conda](https://conda.io)。Conda是一个包管理和环境管理系统，是MindSpore推荐的安装方式之一。
-
-## 功能介绍
-
-* 创建conda环境或选择已有conda环境，并安装MindSpore二进制包至conda环境。
-* 部署最佳实践模板。不仅可以测试环境是否安装成功，也为新用户提供了一个MindSpore的入门介绍。
-* 在网络状况良好时，10分钟之内即可完成环境安装，开始体验MindSpore。最多可节约新用户80%的环境配置时间。
-
-## 使用步骤
-
-1. 选择**File** > **New Project**。
-
-   ![img](images/clip_image002.jpg)
-
-2. 选择**MindSpore**。
-
-   ![img](images/clip_image004.jpg)
-
-3. Miniconda下载安装。***已经安装过conda的可以跳过此步骤。***
-
-   3.1 点击Install Miniconda Automatically按钮。
-
-   ![img](images/clip_image006.jpg)
-
-   3.2 选择下载安装文件夹。**建议不修改路径，使用默认路径安装conda。**
-
-   ![img](images/clip_image008.jpg)
-
-   3.3 点击**Install**按钮，等待下载安装。
-
-   ![img](images/clip_image010.jpg)
-
-   ![img](images/clip_image012.jpg)
-
-   3.4 Miniconda下载安装完成。
-
-   ![img](images/clip_image014.jpg)
-
-   3.5 根据提示重新启动PyCharm，或者稍后自行重新启动PyCharm。***注意：接下来的步骤必须重启PyCharm后方可继续。***
-
-   ![img](images/clip_image015.jpg)
-
-4. 确认Conda executable路径已正确填充。如果Conda executable没有自动填充，点击文件夹按钮，选择本地已安装的conda的路径。
-
-   ![img](images/clip_image016.jpg)
-
-5. 创建或选择已有的conda环境。
-
-   * 创建新的conda环境。**建议不修改路径，使用默认路径创建conda环境。由于PyCharm限制，Linux系统下暂时无法选择默认目录以外的地址。**
-
-     ![img](images/clip_image018.jpg)
-
-   * 选择PyCharm中已有的conda环境。
-
-     ![img](images/clip_image019.jpg)
-
-6. 选择硬件环境和MindSpore项目最佳实践模板。
-
-   6.1 选择硬件环境。
-
-   ![img](images/clip_image020.jpg)
-
-   6.2 选择最佳实践模板。最佳实践模板是MindSpore提供的一些样例项目，供新用户熟悉MindSpore，且可以直接运行。
-
-   ![img](images/clip_image021.jpg)
-
-7. 点击**Create**按钮新建项目，等待MindSpore下载安装成功。
-
-   7.1 点击**Create**按钮创建MindSpore新项目。
-
-   ![img](images/clip_image022.jpg)
-
-   7.2 正在创建conda环境。
-
-   ![img](images/clip_image023.jpg)
-
-   7.3 正在通过conda配置MindSpore。
-
-   ![img](images/clip_image024.jpg)
-
-8. 创建MindSpore项目完成。
-
-   ![img](images/clip_image025.jpg)
-
-9. 验证MindSpore项目是否创建成功。
-
-   * 点击下方Terminal，输入 `python -c "import mindspore;mindspore.run_check()"`，查看输出（参见本文件差异后的示例脚本）。如下图，显示了版本号等信息，表示MindSpore环境可用。
-
-     ![img](images/clip_image026.jpg)
-
-   * 如果选择了最佳实践模板，可以通过运行最佳实践，测试MindSpore环境。
-
-     ![img](images/clip_image027.jpg)
-
-     ![img](images/clip_image028.jpg)
-
-     ![img](images/clip_image029.jpg)
\ No newline at end of file
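The verification step in the deleted wizard guide above reduces to importing MindSpore and invoking its built-in self-check. Below is a minimal standalone sketch of the same check, assuming only that the `mindspore` package is installed in the active conda environment; the script name `verify_env.py` is illustrative, and `mindspore.run_check()` is the same call the guide uses:

```python
# verify_env.py -- sanity-check a freshly created MindSpore conda environment.
# Equivalent to the guide's one-liner:
#   python -c "import mindspore;mindspore.run_check()"
import sys

try:
    import mindspore
except ImportError as err:
    sys.exit(f"MindSpore is not importable in this environment: {err}")

print(mindspore.__version__)  # the version string the guide says to look for
mindspore.run_check()         # runs a small test computation to validate the install
```

If the install is healthy, `run_check()` prints the version information the guide refers to; an exception instead points at a broken or mismatched environment.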
diff --git a/docs/devtoolkit/docs/source_zh_cn/smart_completion.md b/docs/devtoolkit/docs/source_zh_cn/smart_completion.md
deleted file mode 100644
index 5e4cf4edd5272715f7af955a697341e51f465eb5..0000000000000000000000000000000000000000
--- a/docs/devtoolkit/docs/source_zh_cn/smart_completion.md
+++ /dev/null
@@ -1,36 +0,0 @@
-# 代码补全
-
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/devtoolkit/docs/source_zh_cn/smart_completion.md)
-
-## 功能介绍
-
-* 提供基于MindSpore项目的AI代码补全。
-* 无需安装MindSpore环境，也可轻松开发MindSpore。
-
-## 使用步骤
-
-1. 打开Python文件编写代码。
-
-   ![img](images/clip_image088.jpg)
-
-2. 编码时，补全会自动生效。带有MindSpore图标的条目为MindSpore Dev Toolkit智能补全提供的代码。
-
-   ![img](images/clip_image090.jpg)
-
-   ![img](images/clip_image092.jpg)
-
-## 备注
-
-1. PyCharm 2021及以后的版本会根据机器学习重新排列补全内容，此行为可能导致插件的补全条目排序靠后。可以在设置中停用此功能，使用MindSpore Dev Toolkit提供的排序。
-
-   ![img](images/clip_image093.jpg)
-
-2. 关闭此选项前后的对比。
-
-   * 关闭后。
-
-     ![img](images/clip_image094.jpg)
-
-   * 关闭前。
-
-     ![img](images/clip_image096.jpg)
\ No newline at end of file
diff --git a/docs/federated/docs/Makefile b/docs/federated/docs/Makefile
deleted file mode 100644
index 1eff8952707bdfa503c8d60c1e9a903053170ba2..0000000000000000000000000000000000000000
--- a/docs/federated/docs/Makefile
+++ /dev/null
@@ -1,20 +0,0 @@
-# Minimal makefile for Sphinx documentation
-#
-
-# You can set these variables from the command line, and also
-# from the environment for the first two.
-SPHINXOPTS ?=
-SPHINXBUILD ?= sphinx-build
-SOURCEDIR = source_zh_cn
-BUILDDIR = build_zh_cn
-
-# Put it first so that "make" without argument is like "make help".
-help:
-	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-
-.PHONY: help Makefile
-
-# Catch-all target: route all unknown targets to Sphinx using the new
-# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
-%: Makefile
-	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/federated/docs/_ext/overwriteautosummary_generate.txt b/docs/federated/docs/_ext/overwriteautosummary_generate.txt
deleted file mode 100644
index 4b0a1b1dd2b410ecab971b13da9993c90d65ef0d..0000000000000000000000000000000000000000
--- a/docs/federated/docs/_ext/overwriteautosummary_generate.txt
+++ /dev/null
@@ -1,707 +0,0 @@
-"""
-    sphinx.ext.autosummary.generate
-    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-    Usable as a library or script to generate automatic RST source files for
-    items referred to in autosummary:: directives.
-
-    Each generated RST file contains a single auto*:: directive which
-    extracts the docstring of the referred item.
-
-    Example Makefile rule::
-
-        generate:
-            sphinx-autogen -o source/generated source/*.rst
-
-    :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS.
-    :license: BSD, see LICENSE for details.
-""" - -import argparse -import importlib -import inspect -import locale -import os -import pkgutil -import pydoc -import re -import sys -import warnings -from gettext import NullTranslations -from os import path -from typing import Any, Dict, List, NamedTuple, Sequence, Set, Tuple, Type, Union - -from jinja2 import TemplateNotFound -from jinja2.sandbox import SandboxedEnvironment - -import sphinx.locale -from sphinx import __display_version__, package_dir -from sphinx.application import Sphinx -from sphinx.builders import Builder -from sphinx.config import Config -from sphinx.deprecation import RemovedInSphinx50Warning -from sphinx.ext.autodoc import Documenter -from sphinx.ext.autodoc.importer import import_module -from sphinx.ext.autosummary import (ImportExceptionGroup, get_documenter, import_by_name, - import_ivar_by_name) -from sphinx.locale import __ -from sphinx.pycode import ModuleAnalyzer, PycodeError -from sphinx.registry import SphinxComponentRegistry -from sphinx.util import logging, rst, split_full_qualified_name, get_full_modname -from sphinx.util.inspect import getall, safe_getattr -from sphinx.util.osutil import ensuredir -from sphinx.util.template import SphinxTemplateLoader - -logger = logging.getLogger(__name__) - - -class DummyApplication: - """Dummy Application class for sphinx-autogen command.""" - - def __init__(self, translator: NullTranslations) -> None: - self.config = Config() - self.registry = SphinxComponentRegistry() - self.messagelog: List[str] = [] - self.srcdir = "/" - self.translator = translator - self.verbosity = 0 - self._warncount = 0 - self.warningiserror = False - - self.config.add('autosummary_context', {}, True, None) - self.config.add('autosummary_filename_map', {}, True, None) - self.config.add('autosummary_ignore_module_all', True, 'env', bool) - self.config.add('docs_branch', '', True, None) - self.config.add('branch', '', True, None) - self.config.add('cst_module_name', '', True, None) - self.config.add('copy_repo', '', True, None) - self.config.add('giturl', '', True, None) - self.config.add('repo_whl', '', True, None) - self.config.init_values() - - def emit_firstresult(self, *args: Any) -> None: - pass - - -class AutosummaryEntry(NamedTuple): - name: str - path: str - template: str - recursive: bool - - -def setup_documenters(app: Any) -> None: - from sphinx.ext.autodoc import (AttributeDocumenter, ClassDocumenter, DataDocumenter, - DecoratorDocumenter, ExceptionDocumenter, - FunctionDocumenter, MethodDocumenter, ModuleDocumenter, - NewTypeAttributeDocumenter, NewTypeDataDocumenter, - PropertyDocumenter) - documenters: List[Type[Documenter]] = [ - ModuleDocumenter, ClassDocumenter, ExceptionDocumenter, DataDocumenter, - FunctionDocumenter, MethodDocumenter, NewTypeAttributeDocumenter, - NewTypeDataDocumenter, AttributeDocumenter, DecoratorDocumenter, PropertyDocumenter, - ] - for documenter in documenters: - app.registry.add_documenter(documenter.objtype, documenter) - - -def _simple_info(msg: str) -> None: - warnings.warn('_simple_info() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - print(msg) - - -def _simple_warn(msg: str) -> None: - warnings.warn('_simple_warn() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - print('WARNING: ' + msg, file=sys.stderr) - - -def _underline(title: str, line: str = '=') -> str: - if '\n' in title: - raise ValueError('Can only underline single lines') - return title + '\n' + line * len(title) - - -class AutosummaryRenderer: - """A helper class for rendering.""" - - def 
__init__(self, app: Union[Builder, Sphinx], template_dir: str = None) -> None: - if isinstance(app, Builder): - warnings.warn('The first argument for AutosummaryRenderer has been ' - 'changed to Sphinx object', - RemovedInSphinx50Warning, stacklevel=2) - if template_dir: - warnings.warn('template_dir argument for AutosummaryRenderer is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - - system_templates_path = [os.path.join(package_dir, 'ext', 'autosummary', 'templates')] - loader = SphinxTemplateLoader(app.srcdir, app.config.templates_path, - system_templates_path) - - self.env = SandboxedEnvironment(loader=loader) - self.env.filters['escape'] = rst.escape - self.env.filters['e'] = rst.escape - self.env.filters['underline'] = _underline - - if isinstance(app, (Sphinx, DummyApplication)): - if app.translator: - self.env.add_extension("jinja2.ext.i18n") - self.env.install_gettext_translations(app.translator) - elif isinstance(app, Builder): - if app.app.translator: - self.env.add_extension("jinja2.ext.i18n") - self.env.install_gettext_translations(app.app.translator) - - def exists(self, template_name: str) -> bool: - """Check if template file exists.""" - warnings.warn('AutosummaryRenderer.exists() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - try: - self.env.get_template(template_name) - return True - except TemplateNotFound: - return False - - def render(self, template_name: str, context: Dict) -> str: - """Render a template file.""" - try: - template = self.env.get_template(template_name) - except TemplateNotFound: - try: - # objtype is given as template_name - template = self.env.get_template('autosummary/%s.rst' % template_name) - except TemplateNotFound: - # fallback to base.rst - template = self.env.get_template('autosummary/base.rst') - - return template.render(context) - - -# -- Generating output --------------------------------------------------------- - - -class ModuleScanner: - def __init__(self, app: Any, obj: Any) -> None: - self.app = app - self.object = obj - - def get_object_type(self, name: str, value: Any) -> str: - return get_documenter(self.app, value, self.object).objtype - - def is_skipped(self, name: str, value: Any, objtype: str) -> bool: - try: - return self.app.emit_firstresult('autodoc-skip-member', objtype, - name, value, False, {}) - except Exception as exc: - logger.warning(__('autosummary: failed to determine %r to be documented, ' - 'the following exception was raised:\n%s'), - name, exc, type='autosummary') - return False - - def scan(self, imported_members: bool) -> List[str]: - members = [] - for name in members_of(self.object, self.app.config): - try: - value = safe_getattr(self.object, name) - except AttributeError: - value = None - - objtype = self.get_object_type(name, value) - if self.is_skipped(name, value, objtype): - continue - - try: - if inspect.ismodule(value): - imported = True - elif safe_getattr(value, '__module__') != self.object.__name__: - imported = True - else: - imported = False - except AttributeError: - imported = False - - respect_module_all = not self.app.config.autosummary_ignore_module_all - if imported_members: - # list all members up - members.append(name) - elif imported is False: - # list not-imported members - members.append(name) - elif '__all__' in dir(self.object) and respect_module_all: - # list members that have __all__ set - members.append(name) - - return members - - -def members_of(obj: Any, conf: Config) -> Sequence[str]: - """Get the members of ``obj``, possibly ignoring the ``__all__`` 
module attribute - - Follows the ``conf.autosummary_ignore_module_all`` setting.""" - - if conf.autosummary_ignore_module_all: - return dir(obj) - else: - return getall(obj) or dir(obj) - - -def generate_autosummary_content(name: str, obj: Any, parent: Any, - template: AutosummaryRenderer, template_name: str, - imported_members: bool, app: Any, - recursive: bool, context: Dict, - modname: str = None, qualname: str = None) -> str: - doc = get_documenter(app, obj, parent) - - def skip_member(obj: Any, name: str, objtype: str) -> bool: - try: - return app.emit_firstresult('autodoc-skip-member', objtype, name, - obj, False, {}) - except Exception as exc: - logger.warning(__('autosummary: failed to determine %r to be documented, ' - 'the following exception was raised:\n%s'), - name, exc, type='autosummary') - return False - - def get_class_members(obj: Any) -> Dict[str, Any]: - members = sphinx.ext.autodoc.get_class_members(obj, [qualname], safe_getattr) - return {name: member.object for name, member in members.items()} - - def get_module_members(obj: Any) -> Dict[str, Any]: - members = {} - for name in members_of(obj, app.config): - try: - members[name] = safe_getattr(obj, name) - except AttributeError: - continue - return members - - def get_all_members(obj: Any) -> Dict[str, Any]: - if doc.objtype == "module": - return get_module_members(obj) - elif doc.objtype == "class": - return get_class_members(obj) - return {} - - def get_members(obj: Any, types: Set[str], include_public: List[str] = [], - imported: bool = True) -> Tuple[List[str], List[str]]: - items: List[str] = [] - public: List[str] = [] - - all_members = get_all_members(obj) - for name, value in all_members.items(): - documenter = get_documenter(app, value, obj) - if documenter.objtype in types: - # skip imported members if expected - if imported or getattr(value, '__module__', None) == obj.__name__: - skipped = skip_member(value, name, documenter.objtype) - if skipped is True: - pass - elif skipped is False: - # show the member forcedly - items.append(name) - public.append(name) - else: - items.append(name) - if name in include_public or not name.startswith('_'): - # considers member as public - public.append(name) - return public, items - - def get_module_attrs(members: Any) -> Tuple[List[str], List[str]]: - """Find module attributes with docstrings.""" - attrs, public = [], [] - try: - analyzer = ModuleAnalyzer.for_module(name) - attr_docs = analyzer.find_attr_docs() - for namespace, attr_name in attr_docs: - if namespace == '' and attr_name in members: - attrs.append(attr_name) - if not attr_name.startswith('_'): - public.append(attr_name) - except PycodeError: - pass # give up if ModuleAnalyzer fails to parse code - return public, attrs - - def get_modules(obj: Any) -> Tuple[List[str], List[str]]: - items: List[str] = [] - for _, modname, _ispkg in pkgutil.iter_modules(obj.__path__): - fullname = name + '.' 
+ modname - try: - module = import_module(fullname) - if module and hasattr(module, '__sphinx_mock__'): - continue - except ImportError: - pass - - items.append(fullname) - public = [x for x in items if not x.split('.')[-1].startswith('_')] - return public, items - - ns: Dict[str, Any] = {} - ns.update(context) - - if doc.objtype == 'module': - scanner = ModuleScanner(app, obj) - ns['members'] = scanner.scan(imported_members) - ns['functions'], ns['all_functions'] = \ - get_members(obj, {'function'}, imported=imported_members) - ns['classes'], ns['all_classes'] = \ - get_members(obj, {'class'}, imported=imported_members) - ns['exceptions'], ns['all_exceptions'] = \ - get_members(obj, {'exception'}, imported=imported_members) - ns['attributes'], ns['all_attributes'] = \ - get_module_attrs(ns['members']) - ispackage = hasattr(obj, '__path__') - if ispackage and recursive: - ns['modules'], ns['all_modules'] = get_modules(obj) - elif doc.objtype == 'class': - ns['members'] = dir(obj) - ns['inherited_members'] = \ - set(dir(obj)) - set(obj.__dict__.keys()) - ns['methods'], ns['all_methods'] = \ - get_members(obj, {'method'}, ['__init__']) - ns['attributes'], ns['all_attributes'] = \ - get_members(obj, {'attribute', 'property'}) - - if modname is None or qualname is None: - modname, qualname = split_full_qualified_name(name) - - if doc.objtype in ('method', 'attribute', 'property'): - ns['class'] = qualname.rsplit(".", 1)[0] - - if doc.objtype in ('class',): - shortname = qualname - else: - shortname = qualname.rsplit(".", 1)[-1] - - ns['fullname'] = name - ns['module'] = modname - ns['objname'] = qualname - ns['name'] = shortname - - ns['objtype'] = doc.objtype - ns['underline'] = len(name) * '=' - - if template_name: - return template.render(template_name, ns) - else: - return template.render(doc.objtype, ns) - - -def generate_autosummary_docs(sources: List[str], output_dir: str = None, - suffix: str = '.rst', base_path: str = None, - builder: Builder = None, template_dir: str = None, - imported_members: bool = False, app: Any = None, - overwrite: bool = True, encoding: str = 'utf-8') -> None: - - if builder: - warnings.warn('builder argument for generate_autosummary_docs() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - - if template_dir: - warnings.warn('template_dir argument for generate_autosummary_docs() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - - showed_sources = list(sorted(sources)) - if len(showed_sources) > 20: - showed_sources = showed_sources[:10] + ['...'] + showed_sources[-10:] - logger.info(__('[autosummary] generating autosummary for: %s') % - ', '.join(showed_sources)) - - if output_dir: - logger.info(__('[autosummary] writing to %s') % output_dir) - - if base_path is not None: - sources = [os.path.join(base_path, filename) for filename in sources] - - template = AutosummaryRenderer(app) - - # read - items = find_autosummary_in_files(sources) - - # keep track of new files - new_files = [] - - if app: - filename_map = app.config.autosummary_filename_map - else: - filename_map = {} - - # write - for entry in sorted(set(items), key=str): - if entry.path is None: - # The corresponding autosummary:: directive did not have - # a :toctree: option - continue - - path = output_dir or os.path.abspath(entry.path) - ensuredir(path) - - try: - name, obj, parent, modname = import_by_name(entry.name, grouped_exception=True) - qualname = name.replace(modname + ".", "") - except ImportExceptionGroup as exc: - try: - # try to import as an instance attribute - 
name, obj, parent, modname = import_ivar_by_name(entry.name) - qualname = name.replace(modname + ".", "") - except ImportError as exc2: - if exc2.__cause__: - exceptions: List[BaseException] = exc.exceptions + [exc2.__cause__] - else: - exceptions = exc.exceptions + [exc2] - - errors = list(set("* %s: %s" % (type(e).__name__, e) for e in exceptions)) - logger.warning(__('[autosummary] failed to import %s.\nPossible hints:\n%s'), - entry.name, '\n'.join(errors)) - continue - - context: Dict[str, Any] = {} - if app: - context.update(app.config.autosummary_context) - - content = generate_autosummary_content(name, obj, parent, template, entry.template, - imported_members, app, entry.recursive, context, - modname, qualname) - try: - py_source_rel = get_full_modname(modname, qualname).replace('.', '/') + '.py' - except: - logger.warning(name) - py_source_rel = '' - - re_view = f"\n.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/{app.config.docs_branch}/" + \ - f"resource/_static/logo_source_en.svg\n :target: " + app.config.giturl + \ - f"{app.config.copy_repo}/blob/{app.config.branch}/" + app.config.repo_whl + \ - py_source_rel.split(app.config.cst_module_name)[-1] + '\n :alt: View Source On Gitee\n\n' - - if re_view not in content and py_source_rel: - content = re.sub('([=]{5,})\n', r'\1\n' + re_view, content, 1) - filename = os.path.join(path, filename_map.get(name, name) + suffix) - if os.path.isfile(filename): - with open(filename, encoding=encoding) as f: - old_content = f.read() - - if content == old_content: - continue - elif overwrite: # content has changed - with open(filename, 'w', encoding=encoding) as f: - f.write(content) - new_files.append(filename) - else: - with open(filename, 'w', encoding=encoding) as f: - f.write(content) - new_files.append(filename) - - # descend recursively to new files - if new_files: - generate_autosummary_docs(new_files, output_dir=output_dir, - suffix=suffix, base_path=base_path, - builder=builder, template_dir=template_dir, - imported_members=imported_members, app=app, - overwrite=overwrite) - - -# -- Finding documented entries in files --------------------------------------- - -def find_autosummary_in_files(filenames: List[str]) -> List[AutosummaryEntry]: - """Find out what items are documented in source/*.rst. - - See `find_autosummary_in_lines`. - """ - documented: List[AutosummaryEntry] = [] - for filename in filenames: - with open(filename, encoding='utf-8', errors='ignore') as f: - lines = f.read().splitlines() - documented.extend(find_autosummary_in_lines(lines, filename=filename)) - return documented - - -def find_autosummary_in_docstring(name: str, module: str = None, filename: str = None - ) -> List[AutosummaryEntry]: - """Find out what items are documented in the given object's docstring. - - See `find_autosummary_in_lines`. 
- """ - if module: - warnings.warn('module argument for find_autosummary_in_docstring() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - - try: - real_name, obj, parent, modname = import_by_name(name, grouped_exception=True) - lines = pydoc.getdoc(obj).splitlines() - return find_autosummary_in_lines(lines, module=name, filename=filename) - except AttributeError: - pass - except ImportExceptionGroup as exc: - errors = list(set("* %s: %s" % (type(e).__name__, e) for e in exc.exceptions)) - print('Failed to import %s.\nPossible hints:\n%s' % (name, '\n'.join(errors))) - except SystemExit: - print("Failed to import '%s'; the module executes module level " - "statement and it might call sys.exit()." % name) - return [] - - -def find_autosummary_in_lines(lines: List[str], module: str = None, filename: str = None - ) -> List[AutosummaryEntry]: - """Find out what items appear in autosummary:: directives in the - given lines. - - Returns a list of (name, toctree, template) where *name* is a name - of an object and *toctree* the :toctree: path of the corresponding - autosummary directive (relative to the root of the file name), and - *template* the value of the :template: option. *toctree* and - *template* ``None`` if the directive does not have the - corresponding options set. - """ - autosummary_re = re.compile(r'^(\s*)\.\.\s+(ms[a-z]*)?autosummary::\s*') - automodule_re = re.compile( - r'^\s*\.\.\s+automodule::\s*([A-Za-z0-9_.]+)\s*$') - module_re = re.compile( - r'^\s*\.\.\s+(current)?module::\s*([a-zA-Z0-9_.]+)\s*$') - autosummary_item_re = re.compile(r'^\s+(~?[_a-zA-Z][a-zA-Z0-9_.]*)\s*.*?') - recursive_arg_re = re.compile(r'^\s+:recursive:\s*$') - toctree_arg_re = re.compile(r'^\s+:toctree:\s*(.*?)\s*$') - template_arg_re = re.compile(r'^\s+:template:\s*(.*?)\s*$') - - documented: List[AutosummaryEntry] = [] - - recursive = False - toctree: str = None - template = None - current_module = module - in_autosummary = False - base_indent = "" - - for line in lines: - if in_autosummary: - m = recursive_arg_re.match(line) - if m: - recursive = True - continue - - m = toctree_arg_re.match(line) - if m: - toctree = m.group(1) - if filename: - toctree = os.path.join(os.path.dirname(filename), - toctree) - continue - - m = template_arg_re.match(line) - if m: - template = m.group(1).strip() - continue - - if line.strip().startswith(':'): - continue # skip options - - m = autosummary_item_re.match(line) - if m: - name = m.group(1).strip() - if name.startswith('~'): - name = name[1:] - if current_module and \ - not name.startswith(current_module + '.'): - name = "%s.%s" % (current_module, name) - documented.append(AutosummaryEntry(name, toctree, template, recursive)) - continue - - if not line.strip() or line.startswith(base_indent + " "): - continue - - in_autosummary = False - - m = autosummary_re.match(line) - if m: - in_autosummary = True - base_indent = m.group(1) - recursive = False - toctree = None - template = None - continue - - m = automodule_re.search(line) - if m: - current_module = m.group(1).strip() - # recurse into the automodule docstring - documented.extend(find_autosummary_in_docstring( - current_module, filename=filename)) - continue - - m = module_re.match(line) - if m: - current_module = m.group(2) - continue - - return documented - - -def get_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - usage='%(prog)s [OPTIONS] ...', - epilog=__('For more information, visit .'), - description=__(""" -Generate ReStructuredText using autosummary directives. 
- -sphinx-autogen is a frontend to sphinx.ext.autosummary.generate. It generates -the reStructuredText files from the autosummary directives contained in the -given input files. - -The format of the autosummary directive is documented in the -``sphinx.ext.autosummary`` Python module and can be read using:: - - pydoc sphinx.ext.autosummary -""")) - - parser.add_argument('--version', action='version', dest='show_version', - version='%%(prog)s %s' % __display_version__) - - parser.add_argument('source_file', nargs='+', - help=__('source files to generate rST files for')) - - parser.add_argument('-o', '--output-dir', action='store', - dest='output_dir', - help=__('directory to place all output in')) - parser.add_argument('-s', '--suffix', action='store', dest='suffix', - default='rst', - help=__('default suffix for files (default: ' - '%(default)s)')) - parser.add_argument('-t', '--templates', action='store', dest='templates', - default=None, - help=__('custom template directory (default: ' - '%(default)s)')) - parser.add_argument('-i', '--imported-members', action='store_true', - dest='imported_members', default=False, - help=__('document imported members (default: ' - '%(default)s)')) - parser.add_argument('-a', '--respect-module-all', action='store_true', - dest='respect_module_all', default=False, - help=__('document exactly the members in module __all__ attribute. ' - '(default: %(default)s)')) - - return parser - - -def main(argv: List[str] = sys.argv[1:]) -> None: - sphinx.locale.setlocale(locale.LC_ALL, '') - sphinx.locale.init_console(os.path.join(package_dir, 'locale'), 'sphinx') - translator, _ = sphinx.locale.init([], None) - - app = DummyApplication(translator) - logging.setup(app, sys.stdout, sys.stderr) # type: ignore - setup_documenters(app) - args = get_parser().parse_args(argv) - - if args.templates: - app.config.templates_path.append(path.abspath(args.templates)) - app.config.autosummary_ignore_module_all = not args.respect_module_all # type: ignore - - generate_autosummary_docs(args.source_file, args.output_dir, - '.' + args.suffix, - imported_members=args.imported_members, - app=app) - - -if __name__ == '__main__': - main() diff --git a/docs/federated/docs/_ext/overwriteobjectiondirective.txt b/docs/federated/docs/_ext/overwriteobjectiondirective.txt deleted file mode 100644 index 8a58bf71191f77ca22097ea9de244c9df5c3d4fb..0000000000000000000000000000000000000000 --- a/docs/federated/docs/_ext/overwriteobjectiondirective.txt +++ /dev/null @@ -1,368 +0,0 @@ -""" - sphinx.directives - ~~~~~~~~~~~~~~~~~ - - Handlers for additional ReST directives. - - :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. - :license: BSD, see LICENSE for details. 
-""" - -import re -import inspect -import importlib -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Tuple, TypeVar, cast - -from docutils import nodes -from docutils.nodes import Node -from docutils.parsers.rst import directives, roles - -from sphinx import addnodes -from sphinx.addnodes import desc_signature -from sphinx.deprecation import RemovedInSphinx50Warning, deprecated_alias -from sphinx.util import docutils, logging -from sphinx.util.docfields import DocFieldTransformer, Field, TypedField -from sphinx.util.docutils import SphinxDirective -from sphinx.util.typing import OptionSpec - -if TYPE_CHECKING: - from sphinx.application import Sphinx - - -# RE to strip backslash escapes -nl_escape_re = re.compile(r'\\\n') -strip_backslash_re = re.compile(r'\\(.)') - -T = TypeVar('T') -logger = logging.getLogger(__name__) - -def optional_int(argument: str) -> int: - """ - Check for an integer argument or None value; raise ``ValueError`` if not. - """ - if argument is None: - return None - else: - value = int(argument) - if value < 0: - raise ValueError('negative value; must be positive or zero') - return value - -def get_api(fullname): - try: - module_name, api_name= ".".join(fullname.split('.')[:-1]), fullname.split('.')[-1] - module_import = importlib.import_module(module_name) - except ModuleNotFoundError: - module_name, api_name = ".".join(fullname.split('.')[:-2]), ".".join(fullname.split('.')[-2:]) - module_import = importlib.import_module(module_name) - api = eval(f"module_import.{api_name}") - return api - -def get_example(name: str): - try: - api_doc = inspect.getdoc(get_api(name)) - example_str = re.findall(r'Examples:\n([\w\W]*?)(\n\n|$)', api_doc) - if not example_str: - return [] - example_str = re.sub(r'\n\s+', r'\n', example_str[0][0]) - example_str = example_str.strip() - example_list = example_str.split('\n') - return ["", "**样例:**", ""] + example_list + [""] - except: - return [] - -def get_platforms(name: str): - try: - api_doc = inspect.getdoc(get_api(name)) - example_str = re.findall(r'Supported Platforms:\n\s+(.*?)\n\n', api_doc) - if not example_str: - example_str_leak = re.findall(r'Supported Platforms:\n\s+(.*)', api_doc) - if example_str_leak: - example_str = example_str_leak[0].strip() - example_list = example_str.split('\n') - example_list = [' ' + example_list[0]] - return ["", "支持平台:"] + example_list + [""] - return [] - example_str = example_str[0].strip() - example_list = example_str.split('\n') - example_list = [' ' + example_list[0]] - return ["", "支持平台:"] + example_list + [""] - except: - return [] - -class ObjectDescription(SphinxDirective, Generic[T]): - """ - Directive to describe a class, function or similar object. Not used - directly, but subclassed (in domain-specific directives) to add custom - behavior. - """ - - has_content = True - required_arguments = 1 - optional_arguments = 0 - final_argument_whitespace = True - option_spec: OptionSpec = { - 'noindex': directives.flag, - } # type: Dict[str, DirectiveOption] - - # types of doc fields that this directive handles, see sphinx.util.docfields - doc_field_types: List[Field] = [] - domain: str = None - objtype: str = None - indexnode: addnodes.index = None - - # Warning: this might be removed in future version. Don't touch this from extensions. 
- _doc_field_type_map = {} # type: Dict[str, Tuple[Field, bool]] - - def get_field_type_map(self) -> Dict[str, Tuple[Field, bool]]: - if self._doc_field_type_map == {}: - self._doc_field_type_map = {} - for field in self.doc_field_types: - for name in field.names: - self._doc_field_type_map[name] = (field, False) - - if field.is_typed: - typed_field = cast(TypedField, field) - for name in typed_field.typenames: - self._doc_field_type_map[name] = (field, True) - - return self._doc_field_type_map - - def get_signatures(self) -> List[str]: - """ - Retrieve the signatures to document from the directive arguments. By - default, signatures are given as arguments, one per line. - - Backslash-escaping of newlines is supported. - """ - lines = nl_escape_re.sub('', self.arguments[0]).split('\n') - if self.config.strip_signature_backslash: - # remove backslashes to support (dummy) escapes; helps Vim highlighting - return [strip_backslash_re.sub(r'\1', line.strip()) for line in lines] - else: - return [line.strip() for line in lines] - - def handle_signature(self, sig: str, signode: desc_signature) -> Any: - """ - Parse the signature *sig* into individual nodes and append them to - *signode*. If ValueError is raised, parsing is aborted and the whole - *sig* is put into a single desc_name node. - - The return value should be a value that identifies the object. It is - passed to :meth:`add_target_and_index()` unchanged, and otherwise only - used to skip duplicates. - """ - raise ValueError - - def add_target_and_index(self, name: Any, sig: str, signode: desc_signature) -> None: - """ - Add cross-reference IDs and entries to self.indexnode, if applicable. - - *name* is whatever :meth:`handle_signature()` returned. - """ - return # do nothing by default - - def before_content(self) -> None: - """ - Called before parsing content. Used to set information about the current - directive context on the build environment. - """ - pass - - def transform_content(self, contentnode: addnodes.desc_content) -> None: - """ - Called after creating the content through nested parsing, - but before the ``object-description-transform`` event is emitted, - and before the info-fields are transformed. - Can be used to manipulate the content. - """ - pass - - def after_content(self) -> None: - """ - Called after parsing content. Used to reset information about the - current directive context on the build environment. - """ - pass - - def check_class_end(self, content): - for i in content: - if not i.startswith('.. include::') and i != "\n" and i != "": - return False - return True - - def extend_items(self, rst_file, start_num, num): - ls = [] - for i in range(1, num+1): - ls.append((rst_file, start_num+i)) - return ls - - def run(self) -> List[Node]: - """ - Main directive entry function, called by docutils upon encountering the - directive. - - This directive is meant to be quite easily subclassable, so it delegates - to several additional methods. 
What it does: - - * find out if called as a domain-specific directive, set self.domain - * create a `desc` node to fit all description inside - * parse standard options, currently `noindex` - * create an index node if needed as self.indexnode - * parse all given signatures (as returned by self.get_signatures()) - using self.handle_signature(), which should either return a name - or raise ValueError - * add index entries using self.add_target_and_index() - * parse the content and handle doc fields in it - """ - if ':' in self.name: - self.domain, self.objtype = self.name.split(':', 1) - else: - self.domain, self.objtype = '', self.name - self.indexnode = addnodes.index(entries=[]) - - node = addnodes.desc() - node.document = self.state.document - node['domain'] = self.domain - # 'desctype' is a backwards compatible attribute - node['objtype'] = node['desctype'] = self.objtype - node['noindex'] = noindex = ('noindex' in self.options) - if self.domain: - node['classes'].append(self.domain) - node['classes'].append(node['objtype']) - - self.names: List[T] = [] - signatures = self.get_signatures() - for sig in signatures: - # add a signature node for each signature in the current unit - # and add a reference target for it - signode = addnodes.desc_signature(sig, '') - self.set_source_info(signode) - node.append(signode) - try: - # name can also be a tuple, e.g. (classname, objname); - # this is strictly domain-specific (i.e. no assumptions may - # be made in this base class) - name = self.handle_signature(sig, signode) - except ValueError: - # signature parsing failed - signode.clear() - signode += addnodes.desc_name(sig, sig) - continue # we don't want an index entry here - if name not in self.names: - self.names.append(name) - if not noindex: - # only add target and index entry if this is the first - # description of the object with this name in this desc block - self.add_target_and_index(name, sig, signode) - - contentnode = addnodes.desc_content() - node.append(contentnode) - if self.names: - # needed for association of version{added,changed} directives - self.env.temp_data['object'] = self.names[0] - self.before_content() - try: - example = get_example(self.names[0][0]) - platforms = get_platforms(self.names[0][0]) - except Exception as e: - example = '' - platforms = '' - logger.warning(f'Error API names in {self.arguments[0]}.') - logger.warning(f'{e}') - extra = platforms + example - if extra: - if self.objtype == "method": - self.content.data.extend(extra) - else: - index_num = 0 - for num, i in enumerate(self.content.data): - if i.startswith('.. py:method::') or self.check_class_end(self.content.data[num:]): - index_num = num - break - if index_num: - count = len(self.content.data) - for i in extra: - self.content.data.insert(index_num-count, i) - else: - self.content.data.extend(extra) - try: - self.content.items.extend(self.extend_items(self.content.items[0][0], self.content.items[-1][1], len(extra))) - except Exception as e: - logger.warning(f'{e}') - self.state.nested_parse(self.content, self.content_offset, contentnode) - self.transform_content(contentnode) - self.env.app.emit('object-description-transform', - self.domain, self.objtype, contentnode) - DocFieldTransformer(self).transform_all(contentnode) - self.env.temp_data['object'] = None - self.after_content() - return [self.indexnode, node] - - -class DefaultRole(SphinxDirective): - """ - Set the default interpreted text role. Overridden from docutils. 
- """ - - optional_arguments = 1 - final_argument_whitespace = False - - def run(self) -> List[Node]: - if not self.arguments: - docutils.unregister_role('') - return [] - role_name = self.arguments[0] - role, messages = roles.role(role_name, self.state_machine.language, - self.lineno, self.state.reporter) - if role: - docutils.register_role('', role) - self.env.temp_data['default_role'] = role_name - else: - literal_block = nodes.literal_block(self.block_text, self.block_text) - reporter = self.state.reporter - error = reporter.error('Unknown interpreted text role "%s".' % role_name, - literal_block, line=self.lineno) - messages += [error] - - return cast(List[nodes.Node], messages) - - -class DefaultDomain(SphinxDirective): - """ - Directive to (re-)set the default domain for this source file. - """ - - has_content = False - required_arguments = 1 - optional_arguments = 0 - final_argument_whitespace = False - option_spec = {} # type: Dict - - def run(self) -> List[Node]: - domain_name = self.arguments[0].lower() - # if domain_name not in env.domains: - # # try searching by label - # for domain in env.domains.values(): - # if domain.label.lower() == domain_name: - # domain_name = domain.name - # break - self.env.temp_data['default_domain'] = self.env.domains.get(domain_name) - return [] - -def setup(app: "Sphinx") -> Dict[str, Any]: - app.add_config_value("strip_signature_backslash", False, 'env') - directives.register_directive('default-role', DefaultRole) - directives.register_directive('default-domain', DefaultDomain) - directives.register_directive('describe', ObjectDescription) - # new, more consistent, name - directives.register_directive('object', ObjectDescription) - - app.add_event('object-description-transform') - - return { - 'version': 'builtin', - 'parallel_read_safe': True, - 'parallel_write_safe': True, - } - diff --git a/docs/federated/docs/_ext/overwriteviewcode.txt b/docs/federated/docs/_ext/overwriteviewcode.txt deleted file mode 100644 index 172780ec56b3ed90e7b0add617257a618cf38ee0..0000000000000000000000000000000000000000 --- a/docs/federated/docs/_ext/overwriteviewcode.txt +++ /dev/null @@ -1,378 +0,0 @@ -""" - sphinx.ext.viewcode - ~~~~~~~~~~~~~~~~~~~ - - Add links to module code in Python object descriptions. - - :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. - :license: BSD, see LICENSE for details. -""" - -import posixpath -import traceback -import warnings -from os import path -from typing import Any, Dict, Generator, Iterable, Optional, Set, Tuple, cast - -from docutils import nodes -from docutils.nodes import Element, Node - -import sphinx -from sphinx import addnodes -from sphinx.application import Sphinx -from sphinx.builders import Builder -from sphinx.builders.html import StandaloneHTMLBuilder -from sphinx.deprecation import RemovedInSphinx50Warning -from sphinx.environment import BuildEnvironment -from sphinx.locale import _, __ -from sphinx.pycode import ModuleAnalyzer -from sphinx.transforms.post_transforms import SphinxPostTransform -from sphinx.util import get_full_modname, logging, status_iterator -from sphinx.util.nodes import make_refnode - - -logger = logging.getLogger(__name__) - - -OUTPUT_DIRNAME = '_modules' - - -class viewcode_anchor(Element): - """Node for viewcode anchors. - - This node will be processed in the resolving phase. - For viewcode supported builders, they will be all converted to the anchors. - For not supported builders, they will be removed. 
- """ - - -def _get_full_modname(app: Sphinx, modname: str, attribute: str) -> Optional[str]: - try: - return get_full_modname(modname, attribute) - except AttributeError: - # sphinx.ext.viewcode can't follow class instance attribute - # then AttributeError logging output only verbose mode. - logger.verbose('Didn\'t find %s in %s', attribute, modname) - return None - except Exception as e: - # sphinx.ext.viewcode follow python domain directives. - # because of that, if there are no real modules exists that specified - # by py:function or other directives, viewcode emits a lot of warnings. - # It should be displayed only verbose mode. - logger.verbose(traceback.format_exc().rstrip()) - logger.verbose('viewcode can\'t import %s, failed with error "%s"', modname, e) - return None - - -def is_supported_builder(builder: Builder) -> bool: - if builder.format != 'html': - return False - elif builder.name == 'singlehtml': - return False - elif builder.name.startswith('epub') and not builder.config.viewcode_enable_epub: - return False - else: - return True - - -def doctree_read(app: Sphinx, doctree: Node) -> None: - env = app.builder.env - if not hasattr(env, '_viewcode_modules'): - env._viewcode_modules = {} # type: ignore - - def has_tag(modname: str, fullname: str, docname: str, refname: str) -> bool: - entry = env._viewcode_modules.get(modname, None) # type: ignore - if entry is False: - return False - - code_tags = app.emit_firstresult('viewcode-find-source', modname) - if code_tags is None: - try: - analyzer = ModuleAnalyzer.for_module(modname) - analyzer.find_tags() - except Exception: - env._viewcode_modules[modname] = False # type: ignore - return False - - code = analyzer.code - tags = analyzer.tags - else: - code, tags = code_tags - - if entry is None or entry[0] != code: - entry = code, tags, {}, refname - env._viewcode_modules[modname] = entry # type: ignore - _, tags, used, _ = entry - if fullname in tags: - used[fullname] = docname - return True - - return False - - for objnode in list(doctree.findall(addnodes.desc)): - if objnode.get('domain') != 'py': - continue - names: Set[str] = set() - for signode in objnode: - if not isinstance(signode, addnodes.desc_signature): - continue - modname = signode.get('module') - fullname = signode.get('fullname') - try: - if fullname and modname==None: - if fullname.split('.')[-1].lower() == fullname.split('.')[-1] and fullname.split('.')[-2].lower() != fullname.split('.')[-2]: - modname = '.'.join(fullname.split('.')[:-2]) - fullname = '.'.join(fullname.split('.')[-2:]) - else: - modname = '.'.join(fullname.split('.')[:-1]) - fullname = fullname.split('.')[-1] - fullname_new = fullname - except Exception: - logger.warning(f'error_modename:{modname}') - logger.warning(f'error_fullname:{fullname}') - refname = modname - if env.config.viewcode_follow_imported_members: - new_modname = app.emit_firstresult( - 'viewcode-follow-imported', modname, fullname, - ) - if not new_modname: - new_modname = _get_full_modname(app, modname, fullname) - modname = new_modname - # logger.warning(f'new_modename:{modname}') - if not modname: - continue - # fullname = signode.get('fullname') - # if fullname and modname==None: - fullname = fullname_new - if not has_tag(modname, fullname, env.docname, refname): - continue - if fullname in names: - # only one link per name, please - continue - names.add(fullname) - pagename = posixpath.join(OUTPUT_DIRNAME, modname.replace('.', '/')) - signode += viewcode_anchor(reftarget=pagename, refid=fullname, refdoc=env.docname) - - 
-def env_merge_info(app: Sphinx, env: BuildEnvironment, docnames: Iterable[str],
-                   other: BuildEnvironment) -> None:
-    if not hasattr(other, '_viewcode_modules'):
-        return
-    # create a _viewcode_modules dict on the main environment
-    if not hasattr(env, '_viewcode_modules'):
-        env._viewcode_modules = {}  # type: ignore
-    # now merge in the information from the subprocess
-    for modname, entry in other._viewcode_modules.items():  # type: ignore
-        if modname not in env._viewcode_modules:  # type: ignore
-            env._viewcode_modules[modname] = entry  # type: ignore
-        else:
-            if env._viewcode_modules[modname]:  # type: ignore
-                used = env._viewcode_modules[modname][2]  # type: ignore
-                for fullname, docname in entry[2].items():
-                    if fullname not in used:
-                        used[fullname] = docname
-
-
-def env_purge_doc(app: Sphinx, env: BuildEnvironment, docname: str) -> None:
-    modules = getattr(env, '_viewcode_modules', {})
-
-    for modname, entry in list(modules.items()):
-        if entry is False:
-            continue
-
-        code, tags, used, refname = entry
-        for fullname in list(used):
-            if used[fullname] == docname:
-                used.pop(fullname)
-
-        if len(used) == 0:
-            modules.pop(modname)
-
-
-class ViewcodeAnchorTransform(SphinxPostTransform):
-    """Convert or remove viewcode_anchor nodes, depending on the builder."""
-    default_priority = 100
-
-    def run(self, **kwargs: Any) -> None:
-        if is_supported_builder(self.app.builder):
-            self.convert_viewcode_anchors()
-        else:
-            self.remove_viewcode_anchors()
-
-    def convert_viewcode_anchors(self) -> None:
-        for node in self.document.findall(viewcode_anchor):
-            anchor = nodes.inline('', _('[源代码]'), classes=['viewcode-link'])
-            refnode = make_refnode(self.app.builder, node['refdoc'], node['reftarget'],
-                                   node['refid'], anchor)
-            node.replace_self(refnode)
-
-    def remove_viewcode_anchors(self) -> None:
-        for node in list(self.document.findall(viewcode_anchor)):
-            node.parent.remove(node)
-
-
-def missing_reference(app: Sphinx, env: BuildEnvironment, node: Element, contnode: Node
-                      ) -> Optional[Node]:
-    # resolve our "viewcode" reference nodes -- they need special treatment
-    if node['reftype'] == 'viewcode':
-        warnings.warn('The viewcode extension no longer uses the pending_xref node. '
-                      'Please update your extension.', RemovedInSphinx50Warning)
-        return make_refnode(app.builder, node['refdoc'], node['reftarget'],
-                            node['refid'], contnode)
-
-    return None
-
-
-def get_module_filename(app: Sphinx, modname: str) -> Optional[str]:
-    """Get module filename for *modname*."""
-    source_info = app.emit_firstresult('viewcode-find-source', modname)
-    if source_info:
-        return None
-    else:
-        try:
-            filename, source = ModuleAnalyzer.get_module_source(modname)
-            return filename
-        except Exception:
-            return None
-
-
-def should_generate_module_page(app: Sphinx, modname: str) -> bool:
-    """Check whether generation of the module page is needed."""
-    module_filename = get_module_filename(app, modname)
-    if module_filename is None:
-        # Always (re-)generate the module page when the module filename is not found.
-        return True
-
-    builder = cast(StandaloneHTMLBuilder, app.builder)
-    basename = modname.replace('.', '/') + builder.out_suffix
-    page_filename = path.join(app.outdir, '_modules/', basename)
-
-    try:
-        if path.getmtime(module_filename) <= path.getmtime(page_filename):
-            # generation is not needed if the HTML page is newer than the module file.
-            return False
-    except IOError:
-        pass
-
-    return True
-
-
-def collect_pages(app: Sphinx) -> Generator[Tuple[str, Dict[str, Any], str], None, None]:
-    env = app.builder.env
-    if not hasattr(env, '_viewcode_modules'):
-        return
-    if not is_supported_builder(app.builder):
-        return
-    highlighter = app.builder.highlighter  # type: ignore
-    urito = app.builder.get_relative_uri
-
-    modnames = set(env._viewcode_modules)  # type: ignore
-
-    for modname, entry in status_iterator(
-            sorted(env._viewcode_modules.items()),  # type: ignore
-            __('highlighting module code... '), "blue",
-            len(env._viewcode_modules),  # type: ignore
-            app.verbosity, lambda x: x[0]):
-        if not entry:
-            continue
-        if not should_generate_module_page(app, modname):
-            continue
-
-        code, tags, used, refname = entry
-        # construct a page name for the highlighted source
-        pagename = posixpath.join(OUTPUT_DIRNAME, modname.replace('.', '/'))
-        # highlight the source using the builder's highlighter
-        if env.config.highlight_language in ('python3', 'default', 'none'):
-            lexer = env.config.highlight_language
-        else:
-            lexer = 'python'
-        highlighted = highlighter.highlight_block(code, lexer, linenos=False)
-        # split the code into lines
-        lines = highlighted.splitlines()
-        # split off wrap markup from the first line of the actual code
-        before, after = lines[0].split('<pre>')
-        lines[0:1] = [before + '<pre>', after]
-        # nothing to do for the last line; it always starts with </pre> anyway
-        # now that we have code lines (starting at index 1), insert anchors for
-        # the collected tags (HACK: this only works if the tag boundaries are
-        # properly nested!)
-        maxindex = len(lines) - 1
-        for name, docname in used.items():
-            type, start, end = tags[name]
-            backlink = urito(pagename, docname) + '#' + refname + '.' + name
-            lines[start] = (
-                '<div class="viewcode-block" id="%s"><a class="viewcode-back" '
-                'href="%s">%s</a>' % (name, backlink, _('[文档]')) +
-                lines[start])
-            lines[min(end, maxindex)] += '</div>'
-        # try to find parents (for submodules)
-        parents = []
-        parent = modname
-        while '.' in parent:
-            parent = parent.rsplit('.', 1)[0]
-            if parent in modnames:
-                parents.append({
-                    'link': urito(pagename,
-                                  posixpath.join(OUTPUT_DIRNAME, parent.replace('.', '/'))),
-                    'title': parent})
-        parents.append({'link': urito(pagename, posixpath.join(OUTPUT_DIRNAME, 'index')),
-                        'title': _('Module code')})
-        parents.reverse()
-        # putting it all together
-        context = {
-            'parents': parents,
-            'title': modname,
-            'body': (_('<h1>Source code for %s</h1>') % modname +
-                     '\n'.join(lines)),
-        }
-        yield (pagename, context, 'page.html')
-
-    if not modnames:
-        return
-
-    html = ['\n']
-    # the stack logic is needed for using nested lists for submodules
-    stack = ['']
-    for modname in sorted(modnames):
-        if modname.startswith(stack[-1]):
-            stack.append(modname + '.')
-            html.append('<ul>')
-        else:
-            stack.pop()
-            while not modname.startswith(stack[-1]):
-                stack.pop()
-                html.append('</ul>')
-            stack.append(modname + '.')
-        html.append('<li><a href="%s">%s</a></li>\n' % (
-            urito(posixpath.join(OUTPUT_DIRNAME, 'index'),
-                  posixpath.join(OUTPUT_DIRNAME, modname.replace('.', '/'))),
-            modname))
-    html.append('</ul>' * (len(stack) - 1))
-    context = {
-        'title': _('Overview: module code'),
-        'body': (_('<h1>All modules for which code is available</h1>') +
-                 ''.join(html)),
-    }
-
-    yield (posixpath.join(OUTPUT_DIRNAME, 'index'), context, 'page.html')
-
-
-def setup(app: Sphinx) -> Dict[str, Any]:
-    app.add_config_value('viewcode_import', None, False)
-    app.add_config_value('viewcode_enable_epub', False, False)
-    app.add_config_value('viewcode_follow_imported_members', True, False)
-    app.connect('doctree-read', doctree_read)
-    app.connect('env-merge-info', env_merge_info)
-    app.connect('env-purge-doc', env_purge_doc)
-    app.connect('html-collect-pages', collect_pages)
-    app.connect('missing-reference', missing_reference)
-    # app.add_config_value('viewcode_include_modules', [], 'env')
-    # app.add_config_value('viewcode_exclude_modules', [], 'env')
-    app.add_event('viewcode-find-source')
-    app.add_event('viewcode-follow-imported')
-    app.add_post_transform(ViewcodeAnchorTransform)
-    return {
-        'version': sphinx.__display_version__,
-        'env_version': 1,
-        'parallel_read_safe': True
-    }
diff --git a/docs/federated/docs/requirements.txt b/docs/federated/docs/requirements.txt
deleted file mode 100644
index a1b6a69f6dbd9c6f78710f56889e14f0e85b27f4..0000000000000000000000000000000000000000
--- a/docs/federated/docs/requirements.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-sphinx == 4.4.0
-docutils == 0.17.1
-myst-parser == 0.18.1
-sphinx_rtd_theme == 1.0.0
-numpy
-IPython
-jieba
diff --git a/docs/federated/docs/source_en/Data_Join.rst b/docs/federated/docs/source_en/Data_Join.rst
deleted file mode 100644
index 85f2da5d13e7350bf3a9f748fa4cdad8e8efe815..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/Data_Join.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-Data Join
-=====================
-
-.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg
-   :target: https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/Data_Join.rst
-   :alt: View Source on Gitee
-
-.. toctree::
-   :maxdepth: 1
-
-   data_join/data_join
-   data_join/private_set_intersection
\ No newline at end of file
diff --git a/docs/federated/docs/source_en/communication_compression.md b/docs/federated/docs/source_en/communication_compression.md
deleted file mode 100644
index c797eb4d9cc29d5d75043b2ad80de46d09bb01b4..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/communication_compression.md
+++ /dev/null
@@ -1,139 +0,0 @@
-# Device-Cloud Federated Learning Communication Compression
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/communication_compression.md)
-
-During horizontal device-side federated learning training, the communication volume both affects the device-side user experience (user traffic, communication latency, number of participating FL-Clients) and is constrained by cloud-side performance limits (memory, bandwidth, CPU usage). To improve the user experience and reduce performance bottlenecks, the MindSpore federated learning framework provides traffic compression for upload and download in device-cloud federated scenarios.
-
-## Compression Method
-
-### Uploading Compression Method
-
-The upload compression method can be divided into three main parts: weight difference codec, sparse codec and quantization codec. The flowcharts on FL-Client and FL-Server are given below.
-
-![Upload compression client execution order](./images/upload_compression_client_en.png)
-
-Fig.1 Flowchart of the upload compression method on FL-Client
-
-![Upload compression server execution order](./images/upload_compress_server_en.png)
-
-Fig.2 Flowchart of the upload compression method on FL-Server
-
-### Weight Difference Codec
-
-The weight difference is the vector difference between the weight matrices before and after device-side training. Compared with the original weights, the distribution of the weight difference is closer to a Gaussian distribution and is therefore better suited to compression. FL-Client performs the encoding operation on the weight difference, while FL-Server performs the decoding operation. Note that, so that FL-Server can restore the weight difference to weights before aggregating them, FL-Client does not multiply the weights by its data volume when uploading; FL-Server multiplies the weights by the data volume when decoding.
-
-![Weight difference encoding](./images/weight_diff_encode_en.png)
-
-Fig.3 Flow chart of weight difference encoding on FL-Client
-
-![Weight difference decoding](./images/weight_diff_decode_en.png)
-
-Fig.4 Flow chart of weight difference decoding on FL-Server
-
-### Sparse Codec
-
-The device side and the cloud side follow the same random algorithm to generate a sparse mask matrix with the same shape as the original weights to be uploaded. The mask matrix contains only the values 0 and 1. Each FL-Client uploads to the FL-Server only the weight values at the positions where the mask matrix is non-zero.
-
-Take the sparse method with a sparse rate of sparse_rate=0.08 as an example. The parameters that FL-Client is required to upload are:
-
-| Parameters           | Length |
-| -------------------- | ------ |
-| albert.pooler.weight | 97344  |
-| albert.pooler.bias   | 312    |
-| classifier.weight    | 1560   |
-| classifier.bias      | 5      |
-
-Concatenate all parameters into a one-dimensional vector:
-
-| Parameters  | Length                 |
-| ----------- | ---------------------- |
-| merged_data | 97344+312+1560+5=99221 |
-
-Generate a mask vector with the same length as the concatenated parameter vector. 7937 of its values are 1, i.e., 7937 = int(sparse_rate * concatenated parameter length), and the rest are 0, i.e., mask_vector = (1,1,1,...,0,0,0,...):
-
-| Parameters  | Length |
-| ----------- | ------ |
-| mask_vector | 99221  |
-
-Use a pseudo-random algorithm to shuffle the mask_vector, with the current iteration number as the random seed. Take the indexes whose value in the mask_vector is 1, then take merged_data[indexes], i.e., the compressed vector:
-
-| Parameters        | Length |
-| ----------------- | ------ |
-| compressed_vector | 7937   |
-
-After sparse compression, the parameter that FL-Client needs to upload is the compressed_vector.
-
-After receiving the compressed_vector, FL-Server first constructs the mask_vector with the same pseudo-random algorithm and random seed as FL-Client, and takes the indexes whose value in the mask_vector is 1. It then generates an all-zero vector weight_vector with the same shape as the model and puts the values of compressed_vector into weight_vector[indexes] in turn. weight_vector is the sparsely decoded vector.
-
-### Quantization Codec
-
-The quantization compression method approximates floating-point communication data by a finite set of discrete fixed-point values.
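-
-The codec pair can be summarized in a few lines of NumPy. The following is an illustrative sketch of the formulas described in this section, not the framework's actual implementation; the names `quantize` and `dequantize` are chosen here only for clarity:
-
-```python
-import numpy as np
-
-def quantize(data, num_bits=8):
-    # FL-Client side: map float weights to integers in [-2^(num_bits-1), 2^(num_bits-1) - 1].
-    # Assumes max_val > min_val; a degenerate all-equal vector would need special handling.
-    data = np.asarray(data, dtype=np.float32)
-    min_val, max_val = float(data.min()), float(data.max())
-    scale = (max_val - min_val) / (2 ** num_bits - 1)
-    quant_data = np.round((data - min_val) / scale) - 2 ** (num_bits - 1)
-    return quant_data.astype(np.int8), min_val, max_val
-
-def dequantize(quant_data, min_val, max_val, num_bits=8):
-    # FL-Server side: apply the inverse quantization formula to recover the weights.
-    scale = (max_val - min_val) / (2 ** num_bits - 1)
-    return (quant_data.astype(np.float32) + 2 ** (num_bits - 1)) * scale + min_val
-```
-
-The worked 8-bit example below matches the values produced by this sketch.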
-
-Taking the 8-bit quantization as an example:
-
-The number of quantization bits is num_bits = 8.
-
-The floating-point data before compression is:
-
-data = [0.03356021, -0.01842778, -0.009684053, 0.025363436, -0.027571501, 0.0077043395, 0.016391572, -0.03598478, -0.0009508357]
-
-Compute the minimum and maximum values:
-
-min_val = -0.03598478
-
-max_val = 0.03356021
-
-Calculate the scaling factor:
-
-scale = (max_val - min_val) / (2 ^ num_bits - 1) = 0.000272725450980392
-
-Convert the pre-compression data to integers between -128 and 127 with the conversion formula quant_data = round((data - min_val) / scale) - 2 ^ (num_bits - 1), and cast the data type to int8:
-
-quant_data = [127, -64, -32, 97, -97, 32, 64, -128, 0]
-
-After quantization encoding, the parameters that FL-Client needs to upload are quant_data together with the minimum and maximum values min_val and max_val.
-
-After receiving quant_data, min_val and max_val, FL-Server uses the inverse quantization formula (quant_data + 2 ^ (num_bits - 1)) * (max_val - min_val) / (2 ^ num_bits - 1) + min_val to recover the weights.
-
-## Downloading Compression Method
-
-The download compression method is mainly a quantization codec operation. The flowcharts on FL-Server and FL-Client are given below.
-
-![Download compression server execution order](./images/download_compress_server_en.png)
-
-Fig.5 Flowchart of the download compression method on FL-Server
-
-![Download compression client execution order](./images/download_compress_client_en.png)
-
-Fig.6 Flowchart of the download compression method on FL-Client
-
-### Quantization Codec
-
-The quantization codec is the same as that in upload compression.
-
-## Code Implementation Preparation
-
-To use the upload and download compression methods, first successfully complete the training aggregation process for a device-cloud federated scenario, e.g. [Implementing a Sentiment Classification Application (Android)](https://www.mindspore.cn/federated/docs/en/master/sentiment_classification_application.html). The preparation work, including the datasets and network models, and the simulation of initiating multi-client participation in federated learning are described in detail in that document.
-
-## Enabling the Algorithms
-
-The upload and download compression methods are currently supported only in the device-cloud federated learning scenario. To enable them, set `upload_compress_type='DIFF_SPARSE_QUANT'` and `download_compress_type='QUANT'` in the corresponding yaml in the server startup script when starting the cloud-side service. These two hyperparameters turn the upload and download compression methods on and off, respectively.
-
-The relevant parameter configuration is given in the cloud-side [full startup script](https://gitee.com/mindspore/federated/tree/master/tests/st/cross_device_cloud/). After determining the parameter configuration, configure the corresponding parameters before executing the training, as follows:
-
-```yaml
-compression:
-    upload_compress_type: NO_COMPRESS
-    upload_sparse_rate: 0.4
-    download_compress_type: NO_COMPRESS
-```
-
-| Hyperparameter Names and Reference Values | Hyperparameter Description |
-| ---------------------- | ------------------------------------------------------------ |
-| upload_compress_type   | The upload compression type, string type, one of "NO_COMPRESS" or "DIFF_SPARSE_QUANT" |
-| upload_sparse_rate     | The sparse ratio, i.e., the weight retention rate, float type, with domain (0, 1] |
-| download_compress_type | The download compression type, string type, one of "NO_COMPRESS" or "QUANT" |
-
-## ALBERT Results
-
-The total number of federated learning iterations is 100, the number of client local training epochs is 1, the number of clients is 20, the batch size is set to 16, the learning rate is 1e-5, both upload and download compression methods are turned on, and the upload sparse ratio is 0.4. The final accuracy on the validation set is 72.5%, compared with 72.3% for the common federated scenario without compression.
diff --git a/docs/federated/docs/source_en/conf.py b/docs/federated/docs/source_en/conf.py
deleted file mode 100644
index e63a2a2d9ce5e931baad97629fdd909a7d7c71e4..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/conf.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# Configuration file for the Sphinx documentation builder.
-#
-# This file only contains a selection of the most common options. For a full
-# list see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
-
-# -- Path setup --------------------------------------------------------------
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-#
-import os
-import shutil
-import sys
-import IPython
-import re
-import sphinx.ext.autosummary.generate as g
-from sphinx.ext import autodoc as sphinx_autodoc
-
-import mindspore
-
-# -- Project information -----------------------------------------------------
-
-project = 'MindSpore'
-copyright = 'MindSpore'
-author = 'MindSpore'
-
-# The full version, including alpha/beta/rc tags
-release = 'master'
-
-# -- General configuration ---------------------------------------------------
-
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
-# ones.
-myst_enable_extensions = ["dollarmath", "amsmath"]
-
-
-myst_heading_anchors = 5
-extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.doctest',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.todo',
-    'sphinx.ext.coverage',
-    'sphinx.ext.napoleon',
-    'sphinx.ext.viewcode',
-    'myst_parser',
-    'sphinx.ext.mathjax',
-    'IPython.sphinxext.ipython_console_highlighting'
-]
-
-source_suffix = {
-    '.rst': 'restructuredtext',
-    '.md': 'markdown',
-}
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-# This pattern also affects html_static_path and html_extra_path.
-mathjax_path = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/mathjax/MathJax-3.2.2/es5/tex-mml-chtml.js' - -mathjax_options = { - 'async':'async' -} - -smartquotes_action = 'De' - -exclude_patterns = [] - -pygments_style = 'sphinx' - -autodoc_inherit_docstrings = False - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'sphinx_rtd_theme' - -import sphinx_rtd_theme -layout_target = os.path.join(os.path.dirname(sphinx_rtd_theme.__file__), 'layout.html') -layout_src = '../../../../resource/_static/layout.html' -if os.path.exists(layout_target): - os.remove(layout_target) -shutil.copy(layout_src, layout_target) - -html_search_language = 'en' - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = { - 'python': ('https://docs.python.org/', '../../../../resource/python_objects.inv'), - 'numpy': ('https://docs.scipy.org/doc/numpy/', '../../../../resource/numpy_objects.inv'), -} - -# overwriteautosummary_generate add view source for api and more autosummary class availably. -with open('../_ext/overwriteautosummary_generate.txt', 'r', encoding="utf8") as f: - exec(f.read(), g.__dict__) - -# Modify default signatures for autodoc. -autodoc_source_path = os.path.abspath(sphinx_autodoc.__file__) -autodoc_source_re = re.compile(r'stringify_signature\(.*?\)') -get_param_func_str = r"""\ -import re -import inspect as inspect_ - -def get_param_func(func): - try: - source_code = inspect_.getsource(func) - if func.__doc__: - source_code = source_code.replace(func.__doc__, '') - all_params_str = re.findall(r"def [\w_\d\-]+\(([\S\s]*?)(\):|\) ->.*?:)", source_code) - all_params = re.sub("(self|cls)(,|, )?", '', all_params_str[0][0].replace("\n", "").replace("'", "\"")) - return all_params - except: - return '' - -def get_obj(obj): - if isinstance(obj, type): - return obj.__init__ - - return obj -""" - -with open(autodoc_source_path, "r+", encoding="utf8") as f: - code_str = f.read() - code_str = autodoc_source_re.sub('"(" + get_param_func(get_obj(self.object)) + ")"', code_str, count=0) - exec(get_param_func_str, sphinx_autodoc.__dict__) - exec(code_str, sphinx_autodoc.__dict__) - -import mindspore_federated - -# Copy source files of en python api from mindspore repository. 
-src_dir_en = os.path.join(os.getenv("MF_PATH"), 'docs/api/api_python_en') -present_path = os.path.dirname(__file__) - -for i in os.listdir(src_dir_en): - if os.path.isfile(os.path.join(src_dir_en,i)): - if os.path.exists('./'+i): - os.remove('./'+i) - shutil.copy(os.path.join(src_dir_en,i),'./'+i) - else: - if os.path.exists('./'+i): - shutil.rmtree('./'+i) - shutil.copytree(os.path.join(src_dir_en,i),'./'+i) - -# get params for add view source -import json - -if os.path.exists('../../../../tools/generate_html/version.json'): - with open('../../../../tools/generate_html/version.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) -elif os.path.exists('../../../../tools/generate_html/daily_dev.json'): - with open('../../../../tools/generate_html/daily_dev.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) -elif os.path.exists('../../../../tools/generate_html/daily.json'): - with open('../../../../tools/generate_html/daily.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) - -if os.getenv("MF_PATH").split('/')[-1]: - copy_repo = os.getenv("MF_PATH").split('/')[-1] -else: - copy_repo = os.getenv("MF_PATH").split('/')[-2] - -branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if version_inf[i]['name'] == copy_repo][0] -docs_branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if version_inf[i]['name'] == 'tutorials'][0] -cst_module_name = 'mindspore_federated' -repo_whl = 'mindspore_federated' -giturl = 'https://gitee.com/mindspore/' - -sys.path.append(os.path.abspath('../../../../resource/sphinx_ext')) -# import anchor_mod -import nbsphinx_mod - -sys.path.append(os.path.abspath('../../../../resource/search')) -import search_code - - -sys.path.append(os.path.abspath('../../../../resource/custom_directives')) -from custom_directives import IncludeCodeDirective - -def setup(app): - app.add_directive('includecode', IncludeCodeDirective) - app.add_config_value('docs_branch', '', True) - app.add_config_value('branch', '', True) - app.add_config_value('cst_module_name', '', True) - app.add_config_value('copy_repo', '', True) - app.add_config_value('giturl', '', True) - app.add_config_value('repo_whl', '', True) - -src_release = os.path.join(os.getenv("MF_PATH"), 'RELEASE.md') -des_release = "./RELEASE.md" -with open(src_release, "r", encoding="utf-8") as f: - data = f.read() -if len(re.findall("\n## (.*?)\n",data)) > 1: - content = re.findall("(## [\s\S\n]*?)\n## ", data) -else: - content = re.findall("(## [\s\S\n]*)", data) -#result = content[0].replace('# MindSpore', '#', 1) -with open(des_release, "w", encoding="utf-8") as p: - p.write("# Release Notes"+"\n\n") - p.write(content[0]) \ No newline at end of file diff --git a/docs/federated/docs/source_en/cross_device.rst b/docs/federated/docs/source_en/cross_device.rst deleted file mode 100644 index fbbbf5f34ac67b7fa855b672393e32d67755a79b..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_en/cross_device.rst +++ /dev/null @@ -1,17 +0,0 @@ -Device-side Client -====================== - -.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg - :target: https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/cross_device.rst - :alt: View Source on Gitee - -.. 
toctree::
-   :maxdepth: 1
-
-   java_api_callback
-   java_api_client
-   java_api_clientmanager
-   java_api_dataset
-   java_api_flparameter
-   java_api_syncfljob
-   interface_description_federated_client
diff --git a/docs/federated/docs/source_en/data_join.md b/docs/federated/docs/source_en/data_join.md
deleted file mode 100644
index c76aab81230f589794fb91f515879802457cdafd..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/data_join.md
+++ /dev/null
@@ -1,241 +0,0 @@
-# Vertical Federated Learning Data Access
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/data_join.md)
-
-Unlike horizontal federated learning, the two participants (leader and follower) in vertical federated learning share the same sample space for training or inference. Therefore, before both parties initiate training or inference, the data intersection must be computed collaboratively. Both parties read their respective original data and extract the ID of each record (the unique identifier of each record, no two of which are identical) to find the intersection. Then, both parties obtain features or labels from the original data based on the intersected IDs. Finally, each side exports a persistence file and reads the data back for subsequent training or inference.
-
-## Overall Process
-
-Data access can be divided into two parts: data export and data reading.
-
-### Exporting Data
-
-The MindSpore Federated vertical federated learning data export process framework is shown in Figure 1:
-
-![](./images/data_join_en.png)
-
-Fig. 1 Vertical Federated Learning Data Export Process Framework Diagram
-
-In the data export process, the Leader Worker and the Follower Worker are the two participants in vertical federated learning. The Leader Worker runs as a resident process and listens for the Follower Worker, which can enter the data access process at any moment.
-
-After the Leader Worker receives a registration request from the Follower Worker, it checks the registration content. If the registration succeeds, the task-related hyperparameters (PSI-related hyperparameters, bucketing rules, ID field names, etc.) are sent to the Follower Worker.
-
-The Leader Worker and the Follower Worker read their respective raw data, extract the list of IDs from it and perform bucketing.
-
-For each bucket, the Leader Worker and the Follower Worker run the private set intersection (PSI) method to obtain the ID intersection of the two parties.
-
-Finally, the two parties extract the corresponding data from the original data based on the ID intersection and export it to files in MindRecord format.
-
-### Reading Data
-
-Vertical federated learning requires that, for each batch of training or inference, the two participants have the same data ID values in the same order. MindSpore Federated ensures that the data is read in the same order by using the same random seed and by applying dictionary sorting to the exported file sets when both parties read their respective data.
-
-## An Example for Quick Experience
-
-### Sample Data Preparation
-
-To use the data access method, the original data needs to be prepared first.
The user can use [random data generation script](https://gitee.com/mindspore/federated/blob/master/tests/st/data_join/generate_random_data.py) to generate forged data for each participant as a sample. - -```shell -python generate_random_data.py \ - --seed=0 \ - --total_output_path=vfl/input/total_data.csv \ - --intersection_output_path=vfl/input/intersection_data.csv \ - --leader_output_path=vfl/input/leader_data_*.csv \ - --follower_output_path=vfl/input/follower_data_*.csv \ - --leader_file_num=4 \ - --follower_file_num=2 \ - --leader_data_num=300 \ - --follower_data_num=200 \ - --overlap_num=100 \ - --id_len=20 \ - --feature_num=30 -``` - -The user can set the hyperparameter according to the actual situation: - -| Hyperparameter names | Hyperparameter description | -| -------------------- | ------------------------------------------------------------ | -| seed | Random seed, int type. | -| total_output_path | The output path of all data, str type. | -| intersection_output_path | The output path of intersection data, str type. | -| leader_output_path | The export path of the leader data. If the configuration includes the `*`, the `*` will be replaced by the serial number of 0, 1, 2 ...... in order when exporting multiple files. str type. | -| follower_output_path | The export path of the follower data. If the configuration includes the `*`, the `*` will be replaced by the serial number of 0, 1, 2 ...... in order when exporting multiple files. str type. | -| leader_file_num | The number of output files for leader data. int type. | -| follower_file_num | The number of output files for follower data. int type. | -| leader_data_num | The total number of leader data. int type. | -| follower_data_num | The total number of follower data. int type. | -| overlap_num | The total amount of data that overlaps between leader and follower data. int type. | -| id_len | The data ID is a string type. The hyperparameter is the length of the string. int type. | -| feature_num | The number of columns of the exported data | - -Multiple csv files are generated after running the data preparation: - -```text -follower_data_0.csv -follower_data_1.csv -intersection_data.csv -leader_data_0.csv -leader_data_1.csv -leader_data_2.csv -leader_data_3.csv -``` - -### Sample of Data Export - -Users can use [script of finding data intersections](https://gitee.com/mindspore/federated/blob/master/tests/st/data_join/run_data_join.py) to implement data intersections between two parties and export it to MindRecord format file. The users need to start Leader and Follower processes separately. 
- -Start Leader: - -```shell -python run_data_join.py \ - --role="leader" \ - --main_table_files="vfl/input/leader/" \ - --output_dir="vfl/output/leader/" \ - --data_schema_path="vfl/leader_schema.yaml" \ - --server_name=leader_node \ - --http_server_address="127.0.0.1:1086" \ - --remote_server_name=follower_node \ - --remote_server_address="127.0.0.1:1087" \ - --primary_key="oaid" \ - --bucket_num=5 \ - --store_type="csv" \ - --shard_num=1 \ - --join_type="psi" \ - --thread_num=0 -``` - -Start Follower: - -```shell -python run_data_join.py \ - --role="follower" \ - --main_table_files="vfl/input/follower/" \ - --output_dir="vfl/output/follower/" \ - --data_schema_path="vfl/follower_schema.yaml" \ - --server_name=follower_node \ - --http_server_address="127.0.0.1:1087" \ - --remote_server_name=leader_node \ - --remote_server_address="127.0.0.1:1086" \ - --store_type="csv" \ - --thread_num=0 -``` - -The user can set the hyperparameter according to the actual situation. - -| Hyperparameter names | Hyperparameter description | -| ------------------- | ------------------------------------------------------- | -| role | Role types of the worker. str type. Including: "leader", "follower". | -| main_table_files | The path of raw data, configure either single or multiple file paths, data directory paths, list or str types | -| output_dir | The directory path of the exported MindRecord related files, str type. | -| data_schema_path | The path of the super reference file to be configured during export, str type. | -| server_name |Name of local http server that used for communication, str type. | -| http_server_address | Local IP and port address, str type. | -| remote_server_name | Name of remote http server that used for communication, str type. | -| remote_server_address | Peer IP and port address, str type. | -| primary_key (Follower does not need to be configured) | The name of data ID, str type. | -| bucket_num (Follower does not need to be configured) | Find the number of sub-buckets when intersecting and exporting, int type. | -| store_type | Raw data storage type, str type. Including: "csv". | -| shard_num (Follower does not need to be configured) | The number of files exported from a single bucket, int type. | -| join_type (Follower does not need to be configured) | Algorithm of intersection finding, str type. Including: "psi". | -| thread_num | Calculate the number of threads required when using the PSI intersection algorithm, int type. | - -In the above sample, the files corresponding data_schema_path can be referred to the corresponding files configuration of [leader_schema.yaml](https://gitee.com/mindspore/federated/blob/master/tests/st/data_join/vfl/leader_schema.yaml) and [follower_schema.yaml](https://gitee.com/mindspore/federated/blob/master/tests/st/data_join/vfl/follower_schema.yaml). The user needs to provide the column names and types of the data to be exported in this file. - -After running the data export, generate multiple MindRecord related files. - -```text -mindrecord_0 -mindrecord_0.db -mindrecord_1 -mindrecord_1.db -mindrecord_2 -mindrecord_2.db -mindrecord_3 -mindrecord_3.db -mindrecord_4 -mindrecord_4.db -``` - -### Sample of Data Reading - -The user can use the [script of reading data](https://gitee.com/mindspore/federated/blob/master/tests/st/data_join/load_joined_data.py) to implement data reading after intersection. 
- -```shell -python load_joined_data.py \ - --seed=0 \ - --input_dir=vfl/output/leader/ \ - --shuffle=True -``` - -The user can set the hyperparameter according to the actual situation. - -| Hyperparameter names | Hyperparameter description | -| --------- | ----------------------------------------- | -| seed | Random seed. int type. | -| input_dir | The directory of the input MindRecord related files, str type. | -| shuffle | Whether the data order needs to be changed, bool type. | - -If the intersection result is correct, when each of the two parties reads the data, the OAID order of each data of the two parties is the same, while the data of the other columns in each data can be different values. Print the intersection data after running the data read: - -```text -Leader data export results: -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'uMbgxIMMwWhMGrVMVtM7')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'IwoGP08kWVtT4WHL2PLu')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'MSRe6mURtxgyEgWzDn0b')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'y7X0WcMKnTLrhxVcWfGF')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'DicKRIVvbOYSiv63TvcL')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'TCHgtynOhH3z11QYemsH')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'OWmhgIfC3k8UTteGUhni')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'NTV3qEYXBHqKBWyHGc7s')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'wuinSeN1bzYgXy4XmSlR')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'SSsCU0Pb46XGzUIa3Erg')} -…… - -Follower data export results: -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'uMbgxIMMwWhMGrVMVtM7')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'IwoGP08kWVtT4WHL2PLu')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'MSRe6mURtxgyEgWzDn0b')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'y7X0WcMKnTLrhxVcWfGF')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'DicKRIVvbOYSiv63TvcL')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'TCHgtynOhH3z11QYemsH')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'OWmhgIfC3k8UTteGUhni')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'NTV3qEYXBHqKBWyHGc7s')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'wuinSeN1bzYgXy4XmSlR')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'SSsCU0Pb46XGzUIa3Erg')} -…… -``` - -## An Example for Deep Experience - -For detailed API documentation for the following code, see [Data Access Documentation](https://www.mindspore.cn/federated/docs/en/master/data_join/data_join.html). 
- -### Data Export - -The user can implement data join and MindRecord related files export by using the encapsulated interface and yaml file in the following way: - -```python -from mindspore_federated import FLDataWorker -from mindspore_federated.common.config import get_config - - -if __name__ == '__main__': - current_dir = os.path.dirname(os.path.abspath(__file__)) - args = get_config(os.path.join(current_dir, "vfl/vfl_data_join_config.yaml")) - dict_cfg = args.__dict__ - - worker = FLDataWorker(config=dict_cfg) - worker.do_worker() -``` - -### Data Reading - -The user can implement data in exported MindRecord related files reading by using the encapsulated interface in the following way: - -```python -from mindspore_federated.data_join import load_mindrecord - - -if __name__ == "__main__": - dataset = load_mindrecord(input_dir="vfl/output/leader/", shuffle=True, seed=0) -``` diff --git a/docs/federated/docs/source_en/deploy_federated_client.md b/docs/federated/docs/source_en/deploy_federated_client.md deleted file mode 100644 index 479340ad31695118282d8acfe69108f925100d08..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_en/deploy_federated_client.md +++ /dev/null @@ -1,202 +0,0 @@ -# Horizontal Federated Device-side Deployment - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/deploy_federated_client.md) - -This document describes how to compile and deploy Federated-Client. - -## Linux Compilation Guidance - -### System Environment and Third-party Dependencies - -This section describes how to complete the device-side compilation of MindSpore federated learning. Currently, the federated learning device-side only provides compilation guidance on Linux, and other systems are not supported. The following table lists the system environment and third-party dependencies required for compilation. - -| Software Name | Version | Functions | -|-----------------------| ------------ | ------------ | -| Ubuntu | 18.04.02LTS | Compiling and running MindSpore operating system | -| [GCC](#installing-gcc) | Between 7.3.0 to 9.4.0 | C++ compiler for compiling MindSpore | -| [git](#installing-git) | - | Source code management tools used by MindSpore | -| [CMake](#installing-cmake) | 3.18.3 and above | Compiling and building MindSpore tools | -| [Gradle](#installing-gradle) | 6.6.1 | JVM-based building tools | -| [Maven](#installing-maven) | 3.3.1 and above | Tools for managing and building Java projects | -| [OpenJDK](#installing-openjdk) | Between 1.8 to 1.15 | Tools for managing and building Java projects | - -#### Installing GCC - -Install GCC with the following command. - -```bash -sudo apt-get install gcc-7 git -y -``` - -To install a higher version of GCC, use the following command to install GCC 8. - -```bash -sudo apt-get install gcc-8 -y -``` - -Or install GCC 9. - -```bash -sudo apt-get install software-properties-common -y -sudo add-apt-repository ppa:ubuntu-toolchain-r/test -sudo apt-get update -sudo apt-get install gcc-9 -y -``` - -#### Installing git - -Install git with the following command. - -```bash -sudo apt-get install git -y -``` - -#### Installing Cmake - -Install [CMake](https://cmake.org/) with the following command. 
- -```bash -wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | sudo apt-key add - -sudo apt-add-repository "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" -sudo apt-get install cmake -y -``` - -#### Installing Gradle - -Install [Gradle](https://gradle.org/releases/) with the following command. - -```bash -# Download the corresponding zip package and unzip it. -# Configure environment variables: - export GRADLE_HOME=GRADLE path - export GRADLE_USER_HOME=GRADLE path -# Add the bin directory to the PATH: - export PATH=${GRADLE_HOME}/bin:$PATH -``` - -#### Installing Maven - -Install [Maven](https://archive.apache.org/dist/maven/maven-3/) with the following command. - -```bash -# Download the corresponding zip package and unzip it. -# Configure environment variables: - export MAVEN_HOME=MAVEN path -# Add the bin directory to the PATH: - export PATH=${MAVEN_HOME}/bin:$PATH -``` - -#### Installing OpenJDK - -Install [OpenJDK](https://jdk.java.net/archive/) with the following command. - -```bash -# Download the corresponding zip package and unzip it. -# Configure environment variables: - export JAVA_HOME=JDK path -# Add the bin directory to the PATH: - export PATH=${JAVA_HOME}/bin:$PATH -``` - -### Verifying Installation - -Verify that the installation in [System environment and third-party dependencies](#system-environment-and-third-party-dependencies) is successful. - -```text -Open a command window and enter: gcc --version -The following output identifies a successful installation: - gcc version version number - -Open a command window and enter: git --version -The following output identifies a successful installation: - git version version number - -Open a command window and enter: cmake --version -The following output identifies a successful installation: - cmake version version number - -Open a command window and enter: gradle --version -The following output identifies a successful installation: - Gradle version number - -Open a command window and enter: mvn --version -The following output identifies a successful installation: - Apache Maven version number - -Open a command window and enter: java --version -The following output identifies a successful installation: - openjdk version version number - -``` - -### Compilation Options - -The `cli_build.sh` script in the federated learning device_client directory is used for compilation on the federated learning device-side. - -#### Instructions for Using cli_build.sh Parameters - -| Parameters | Parameter Description | Value Range | Default Values | -| ---- | ------------------------ | -------- | ------------ | -| -p | the download path of dependency external packages | string | third | -| -c | whether to reuse dependency packages previously downloaded | on and off | on | - -### Compilation Examples - -1. First, you need to download the source code from the gitee code repository before you can compile it. - - ```bash - git clone https://gitee.com/mindspore/federated.git ./ - ``` - -2. Go to the mindspore_federated/device_client directory and execute the following command: - - ```bash - bash cli_build.sh - ``` - -3. Since the end-side framework and the model are decoupled, the x86 architecture package we provide, mindspore-lite-{version}-linux-x64.tar.gz, does not contain model-related scripts, so the user needs to generate the jar package corresponding to the model scripts. 
The jar package corresponding to the model scripts we provide can be obtained in the following way: - - ```bash - cd federated/example/quick_start_flclient - bash build.sh -r mindspore-lite-java-flclient.jar # After -r, you need to give the absolute path to the latest x86 architecture package (generated in Step 2, federated/mindspore_federated/device_client/build/libs/jarX86/mindspore-lite-java-flclient.jar) - ``` - -After running the above command, the path of generated jar package is federated/example/quick_start_flclient/target/quick_start_flclient.jar. - -### Building Dependency Environment - -1. After extracting the file `federated/mindspore_federated/device_client/third/mindspore-lite-{version}-linux-x64.tar.gz`, the obtained directory structure is as follows(files that are not used in federated learning are not displayed here): - - ```sh - mindspore-lite-{version}-linux-x64 - ├── tools - └── runtime - ├── include # Header files of training framework - ├── lib # Training framework library - │ ├── libminddata-lite.a # Static library files for image processing - │ ├── libminddata-lite.so # Dynamic library files for image processing - │ ├── libmindspore-lite-jni.so # jni dynamic library relied by MindSpore Lite inference framework - │ ├── libmindspore-lite-train.a # Static library relied by MindSpore Lite training framework - │ ├── libmindspore-lite-train.so # Dynamic library relied by MindSpore Lite training framework - │ ├── libmindspore-lite-train-jni.so # jni dynamic library relied by MindSpore Lite training framework - │ ├── libmindspore-lite.a # Static library relied by MindSpore Lite inference framework - │ ├── libmindspore-lite.so # Dynamic library relied by MindSpore Lite inference framework - │ ├── mindspore-lite-java.jar # MindSpore Lite training framework jar package - └── third_party - ├── glog - │└── libmindspore_glog.so.0 # Dynamic library files of glog - └── libjpeg-turbo - └── lib - ├── libjpeg.so.62 # Dynamic library files for image processing - └── libturbojpeg.so.0 # Dynamic library files for image processing - ``` - -2. Put the so files relied by federated learning in paths `mindspore-lite-{version}-linux-x64/runtime/lib/`, `mindspore-lite-{version}-linux-x64/runtime/third_party/glog/` and `mindspore-lite-{version}-linux-x64/runtime/third_party/libjpeg-turbo/lib/` in a folder, e.g. `/resource/x86libs/`. Then set the environment variables in x86 (absolute paths need to be provided below): - - ```sh - export LD_LIBRARY_PATH=/resource/x86libs/:$LD_LIBRARY_PATH - ``` - -3. After setting up the dependency environment, you can simulate starting multiple clients in the x86 environment for federated learning by referring to the application practice tutorial [Implementing an end-cloud federation for image classification application (x86)](https://www.mindspore.cn/federated/docs/en/master/image_classification_application.html). 
-
-
diff --git a/docs/federated/docs/source_en/deploy_federated_server.md b/docs/federated/docs/source_en/deploy_federated_server.md
deleted file mode 100644
index 12bccc355abeb63f2b625565be03306cb25e3ba8..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/deploy_federated_server.md
+++ /dev/null
@@ -1,317 +0,0 @@
-# Horizontal Federated Cloud-based Deployment
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/deploy_federated_server.md)
-
-The following uses LeNet as an example to describe how to use MindSpore Federated to deploy a horizontal federated learning cluster.
-
-The following figure shows the physical architecture of the MindSpore Federated Learning (FL) Server cluster:
-
-![mindspore-federated-networking](./images/mindspore_federated_networking.png)
-
-As shown in the preceding figure, in the horizontal federated learning cloud cluster, there are three MindSpore process roles: `Federated Learning Scheduler`, `Federated Learning Server` and `Federated Learning Worker`:
-
-- Federated Learning Scheduler
-
-    `Scheduler` provides the following functions:
-
-    1. Cluster networking assistance: During cluster initialization, the `Scheduler` collects server information and ensures cluster consistency.
-    2. Open management plane: You can manage clusters through the `RESTful` APIs.
-
-    In a federated learning task, there is only one `Scheduler`, which communicates with the `Server` using the TCP proprietary protocol.
-
-- Federated Learning Server
-
-    `Server` executes federated learning tasks, receives and parses data from devices, and provides capabilities such as secure aggregation, time-limited communication, and model storage. In a federated learning task, users can configure multiple `Servers`, which communicate with each other through the TCP proprietary protocol and open HTTP ports for device-side connections.
-
-    In the MindSpore federated learning framework, `Server` also supports auto scaling and disaster recovery, and can dynamically schedule hardware resources without interrupting training tasks.
-
-- Federated Learning Worker
-
-    `Worker` is an accessory module for executing the federated learning task, which is used for supervised retraining of the model in the Server; the trained model is then distributed to the Server. In a federated learning task, there can be more than one `Worker` (the number is user-configurable), and the communication between `Worker` and `Server` is performed via the TCP protocol.
-
-`Scheduler` and `Server` must be deployed on a server or container with a single NIC and in the same network segment. MindSpore automatically obtains the first available IP address as the `Server` IP address.
-
-> The servers will verify the timestamp carried by the clients. It is necessary to ensure that the servers are periodically time-synchronized to avoid a large time offset.
-
-## Preparations
-
-> It is recommended to create a virtual environment for the following operations with [Anaconda](https://www.anaconda.com/).
-
-### Installing MindSpore
-
-The MindSpore horizontal federated learning cloud cluster supports deployment on x86 CPU and GPU CUDA hardware platforms. Run the commands provided by the [MindSpore Installation Guide](https://www.mindspore.cn/install) to install the latest MindSpore.
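-
-For instance, on a Linux x86_64 CPU environment the generated command is typically a plain pip install; always prefer the exact command the installation guide produces for your platform, CUDA version and Python version:
-
-```shell
-pip install mindspore
-```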
-
-### Installing MindSpore Federated
-
-Compile and install with the [source code](https://gitee.com/mindspore/federated).
-
-```shell
-git clone https://gitee.com/mindspore/federated.git -b master
-cd federated
-bash build.sh
-```
-
-For `bash build.sh`, compilation can be accelerated with the `-jn` option, e.g. `-j16`. The third-party dependencies can be downloaded from gitee instead of github with the `-S on` option.
-
-After compilation, find the whl installation package of Federated in the `build/package/` directory and install it:
-
-```bash
-pip install mindspore_federated-{version}-{python_version}-linux_{arch}.whl
-```
-
-### Verifying Installation
-
-Execute the following command to verify the installation result. The installation is successful if no error is reported when importing Python modules.
-
-```python
-from mindspore_federated import FLServerJob
-```
-
-### Installing and Starting the Redis Server
-
-Federated learning relies on a [Redis Server](https://gitee.com/link?target=https%3A%2F%2Fredis.io%2F) as the cached data middleware by default. To run the federated learning service, a Redis server needs to be installed and running.
-
-> Users must check the security of the Redis version to be used. Some versions may have security vulnerabilities.
-
-Install the Redis server:
-
-```bash
-sudo apt-get install redis
-```
-
-Run the Redis server with the port configured as 23456:
-
-```bash
-redis-server --port 23456 --save ""
-```
-
-## Starting a Cluster
-
-1. Go to the [examples](https://gitee.com/mindspore/federated/tree/master/example/cross_device_lenet_femnist/) directory:
-
-    ```bash
-    cd example/cross_device_lenet_femnist
-    ```
-
-2. Modify the yaml configuration file `default_yaml_config.yaml` according to the actual setup. The [sample configuration of LeNet](https://gitee.com/mindspore/federated/blob/master/example/cross_device_lenet_femnist/yamls/lenet/default_yaml_config.yaml) is as follows:
-
-    ```yaml
-    fl_name: Lenet
-    fl_iteration_num: 25
-    server_mode: FEDERATED_LEARNING
-    enable_ssl: False
-
-    distributed_cache:
-        type: redis
-        address: 127.0.0.1:23456 # ip:port of the actual redis machine
-        plugin_lib_path: ""
-
-    round:
-        start_fl_job_threshold: 2
-        start_fl_job_time_window: 30000
-        update_model_ratio: 1.0
-        update_model_time_window: 30000
-        global_iteration_time_window: 60000
-
-    summary:
-        metrics_file: "metrics.json"
-        failure_event_file: "event.txt"
-        continuous_failure_times: 10
-        data_rate_dir: ".."
-        participation_time_level: "5,15"
-
-    unsupervised:
-        cluster_client_num: 1000
-        eval_type: SILHOUETTE_SCORE
-
-    encrypt:
-        encrypt_train_type: NOT_ENCRYPT
-        pw_encrypt:
-            share_secrets_ratio: 1.0
-            cipher_time_window: 3000
-            reconstruct_secrets_threshold: 1
-        dp_encrypt:
-            dp_eps: 50.0
-            dp_delta: 0.01
-            dp_norm_clip: 1.0
-        signds:
-            sign_k: 0.01
-            sign_eps: 100
-            sign_thr_ratio: 0.6
-            sign_global_lr: 0.1
-            sign_dim_out: 0
-
-    compression:
-        upload_compress_type: NO_COMPRESS
-        upload_sparse_rate: 0.4
-        download_compress_type: NO_COMPRESS
-
-    ssl:
-        # when ssl_config is set
-        # for tcp/http server
-        server_cert_path: "server.p12"
-        # for tcp client
-        client_cert_path: "client.p12"
-        # common
-        ca_cert_path: "ca.crt"
-        crl_path: ""
-        cipher_list: "ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-CHACHA20-POLY1305:ECDHE-PSK-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-CCM:ECDHE-ECDSA-AES256-CCM:ECDHE-ECDSA-CHACHA20-POLY1305"
-        cert_expire_warning_time_in_day: 90
-
-    client_verify:
-        pki_verify: false
-        root_first_ca_path: ""
-        root_second_ca_path: ""
-        equip_crl_path: ""
-        replay_attack_time_diff: 600000
-
-    client:
-        http_url_prefix: ""
-        client_epoch_num: 20
-        client_batch_size: 32
-        client_learning_rate: 0.01
-        connection_num: 10000
-    ```
-
-3. Prepare the model file. The current example uses a weight-based start, so you need to provide the corresponding model weights.
-
-    Obtain the LeNet model weights:
-
-    ```bash
-    wget https://ms-release.obs.cn-north-4.myhuaweicloud.com/ms-dependencies/Lenet.ckpt
-    ```
-
-4. Run the Scheduler. The management plane address is `127.0.0.1:11202` by default.
-
-    ```bash
-    python run_sched.py \
-    --yaml_config="yamls/lenet.yaml" \
-    --scheduler_manage_address="10.*.*.*:18019"
-    ```
-
-5. Run the Server. This starts one Server, and the HTTP server address is `127.0.0.1:6666` by default.
-
-    ```bash
-    python run_server.py \
-    --yaml_config="yamls/lenet.yaml" \
-    --tcp_server_ip="10.*.*.*" \
-    --checkpoint_dir="fl_ckpt" \
-    --local_server_num=1 \
-    --http_server_address="10.*.*.*:8019"
-    ```
-
-6. Stop federated learning. The current version of the federated learning cluster runs as resident processes, and the `finish_cloud.py` script can be executed to terminate the federated learning service. In the following example, `redis_port` takes the same value as when Redis was started, meaning that the cluster corresponding to this `Scheduler` is stopped.
-
-    ```bash
-    python finish_cloud.py --redis_port=23456
-    ```
-
-    If the console prints the following:
-
-    ```text
-    killed $PID1
-    killed $PID2
-    killed $PID3
-    killed $PID4
-    killed $PID5
-    killed $PID6
-    killed $PID7
-    killed $PID8
-    ```
-
-    the termination of the service was successful.
-
-## Auto Scaling
-
-The MindSpore federated learning framework supports `Server` auto scaling and provides `RESTful` services externally through the `Scheduler` management port, enabling users to dynamically schedule hardware resources without interrupting training tasks.
-
-The following example describes how to control scale-out and scale-in of the cluster through the APIs.
-
-### Scale-out
-
-After the cluster starts, enter the machine where the scheduler node is deployed and send a request to the `Scheduler` to query the status and node information. A `RESTful` request can be constructed with the `curl` command:
-
-```sh
-curl -k 'http://10.*.*.*:18015/state'
-```
-
-`Scheduler` will return the query results in `json` format:
-
-```json
-{
-    "message":"Get cluster state successful.",
-    "cluster_state":"CLUSTER_READY",
-    "code":0,
-    "nodes":[
-        {"node_id":"{ip}:{port}::{timestamp}::{random}",
-         "tcp_address":"{ip}:{port}",
-         "role":"SERVER"}
-    ]
-}
-```
-
-You need to pull up 3 new `Server` processes and increase the `local_server_num` parameter by the scale-out count, so as to keep the global networking information correct; that is, after scale-out, `local_server_num` should be 4. An example of the command is as follows:
-
-```sh
-python run_server.py --yaml_config="yamls/lenet.yaml" --tcp_server_ip="10.*.*.*" --checkpoint_dir="fl_ckpt" --local_server_num=4 --http_server_address="10.*.*.*:18015"
-```
-
-This command starts four `Server` nodes, so that the total number of `Server` nodes is 4.
-
-### Scale-in
-
-Simulate a scale-in by directly running `kill -9 pid` on a Server process, then construct a `RESTful` request with the `curl` command and query the status; one node_id is now missing from the cluster, which achieves the scale-in.
-
-```sh
-curl -k \
-'http://10.*.*.*:18015/state'
-```
-
-`Scheduler` returns the query results in `json` format:
-
-```json
-{
-    "message":"Get cluster state successful.",
-    "cluster_state":"CLUSTER_READY",
-    "code":0,
-    "nodes":[
-        {"node_id":"{ip}:{port}::{timestamp}::{random}",
-         "tcp_address":"{ip}:{port}",
-         "role":"SERVER"},
-        {"node_id":"worker_fl_{timestamp}::{random}",
-         "tcp_address":"",
-         "role":"WORKER"},
-        {"node_id":"worker_fl_{timestamp}::{random}",
-         "tcp_address":"",
-         "role":"WORKER"}
-    ]
-}
-```
-
-> - After scale-out/scale-in of the cluster succeeds, the training tasks are automatically resumed without additional intervention.
-
-## Security
-
-The MindSpore federated learning framework supports SSL security authentication of the `Server`. To enable security authentication, add `enable_ssl=True` to the startup command, and add the following fields to the config.json configuration file specified by config_file_path:
-
-```json
-{
-    "server_cert_path": "server.p12",
-    "crl_path": "",
-    "client_cert_path": "client.p12",
-    "ca_cert_path": "ca.crt",
-    "cert_expire_warning_time_in_day": 90,
-    "cipher_list": "ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:DHE-DSS-AES128-GCM-SHA256:kEDH+AESGCM:ECDHE-RSA-AES128-SHA256:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA:ECDHE-ECDSA-AES256-SHA:DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-DSS-AES128-SHA256:DHE-RSA-AES256-SHA256:DHE-DSS-AES256-SHA:DHE-RSA-AES256-SHA:!aNULL:!eNULL:!EXPORT:!DES:!RC4:!3DES:!MD5:!PSK",
-    "connection_num":10000
-}
-```
-
-- server_cert_path: The path to the p12 file containing the ciphertext of the certificate and key on the server side.
-- crl_path: The file of the certificate revocation list.
-- client_cert_path: The path to the p12 file containing the ciphertext of the certificate and key on the client side.
-- ca_cert_path: The root certificate.
-- cipher_list: The cipher suites.
-- cert_expire_warning_time_in_day: The alarm lead time for certificate expiration.
-
-The key in the p12 file is stored as ciphertext.
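-
-For reference, a p12 bundle like the ones referenced above is typically produced with OpenSSL; the file names below are placeholders matching the configuration keys rather than paths required by the framework:
-
-```shell
-# Bundle an existing certificate and private key into a PKCS#12 (p12) file;
-# the export password chosen here encrypts the private key inside the bundle.
-openssl pkcs12 -export -in server.crt -inkey server.key -certfile ca.crt -out server.p12
-```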
diff --git a/docs/federated/docs/source_en/deploy_vfl.md b/docs/federated/docs/source_en/deploy_vfl.md deleted file mode 100644 index 9aeae5961bb201e55ec866cc9431a1ef8309c8dd..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_en/deploy_vfl.md +++ /dev/null @@ -1,69 +0,0 @@ -# Vertical Federated Deployment - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/deploy_vfl.md) - -This document explains how to use and deploy the vertical federated learning framework. - -The MindSpore Vertical Federated Learning (VFL) physical architecture is shown in the figure: - -![](./images/deploy_VFL_en.png) - -As shown above, there are two participants in the vertical federated interaction: the Leader node and the Follower node, each of which has processes in two roles: `FLDataWorker` and `VFLTrainer`: - -- FLDataWorker - - The functions of `FLDataWorker` mainly includes: - - 1. Dataset intersection: obtains a common user intersection for both vertical federated participants, and supports a privacy dataset intersection protocol that prevents federated learning participants from obtaining ID information outside the intersection. - 2. Training data generation: After obtaining the intersection ID, the data features are expanded to generate the mindrecord file for training. - 3. Open management surface: `RESTful` interface is provided to users for cluster management. - - In a federated learning task, there is only one `Scheduler`, which communicates with the `Server` through TCP protocol. - -- VFLTrainer - - `VFLTrainer` is the main body that performs the vertical federated training tasks, and performs the forward and reverse computation after model slicing, Embedding tensor transfer, gradient tensor transfer, and reverse optimizer update. The current version supports single-computer single-card and single-computer multi-card training modes. - - In the MindSpore federated learning framework, `Server` also supports elastic scaling and disaster recovery, enabling dynamic provisioning of hardware resources without interruption of training tasks. - -`FLDataWorker` and `VFLTrainer` are generally deployed in the same server or container. - -## Preparation - -> It is recommended to use [Anaconda](https://www.anaconda.com/) to create a virtual environment for the following operations. - -### Installing MindSpore - -MindSpore vertical federated supports deployment on x86 CPU, GPU CUDA and Ascend hardware platforms. The latest version of MindSpore can be installed by referring to [MindSpore Installation Guide](https://www.mindspore.cn/install). - -### Installing MindSpore Federated - -Compile and install via [source code](https://gitee.com/mindspore/federated). - -```shell -git clone https://gitee.com/mindspore/federated.git -b master -cd federated -bash build.sh -``` - -For `bash build.sh`, accelerate compilation through the `-jn` option, e.g. `-j16`, and download third-party dependencies from gitee instead of github by the `-S on` option. - -Once compiled, find the Federated whl installation package in the `build/package/` directory to install. - -```shell -pip install mindspore_federated-{version}-{python_version}-linux_{arch}.whl -``` - -#### Verifying installation - -Execute the following command to verify the installation. The installation is successful if no error is reported when importing Python modules. 
- -```python -from mindspore_federated import FLServerJob -``` - -## Running the Example - -A running sample of FLDataWorker can be found in [Vertical federated learning data access](https://www.mindspore.cn/federated/docs/en/master/data_join.html). - -A running sample of VFLTrainer can be found in [Vertical federated learning model training - Wide&Deep Recommended Application](https://www.mindspore.cn/federated/docs/en/master/split_wnd_application.html). diff --git a/docs/federated/docs/source_en/faq.md b/docs/federated/docs/source_en/faq.md deleted file mode 100644 index 912e7d07d46372092d9db4e179df939b99d1c420..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_en/faq.md +++ /dev/null @@ -1,9 +0,0 @@ -# FAQ - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/faq.md) - -**Q: If the cluster networking is unsuccessful, how to locate the cause?** - -A: Please check the server's network conditions, for example, check whether the firewall prohibits port access, please set the firewall to allow port access. - -
    \ No newline at end of file diff --git a/docs/federated/docs/source_en/federated_install.md b/docs/federated/docs/source_en/federated_install.md deleted file mode 100644 index 4a93f9e2293910567913f6606da4027b951eab0e..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_en/federated_install.md +++ /dev/null @@ -1,25 +0,0 @@ -# Obtaining MindSpore Federated - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/federated_install.md) - -Currently, the [MindSpore Federated](https://gitee.com/mindspore/federated) framework code has been built independently, divided into device-side and cloud-side. Its cloud-side capability relies on MindSpore and MindSpore Federated, using MindSpore for cloud-side cluster aggregation training and communication with device-side, so it needs to get MindSpore whl package and MindSpore Federated whl package respectively. The device-side capability relies on MindSpore Lite and MindSpore Federated java packages, where MindSpore Federated java is mainly responsible for data pre-processing, model training and inference by calling MindSpore Lite for, as well as model-related uploads and downloads by using privacy protection mechanisms and the cloud side. - -## Obtaining the MindSpore WHL Package - -You can use the source code or download the release version to install MindSpore on hardware platforms such as the x86 CPU and GPU CUDA. For details about the installation process, see [Install](https://www.mindspore.cn/install/en) on the MindSpore website. - -## Obtaining the MindSpore Lite Java Package - -You can use the source code or download the release version. Currently, only the Linux and Android platforms are supported, and only the CPU hardware architecture is supported. For details about the installation process, see [Downloading MindSpore Lite](https://www.mindspore.cn/lite/docs/en/master/use/downloads.html) and [Building MindSpore Lite](https://www.mindspore.cn/lite/docs/en/master/build/build.html). - -## Obtaining MindSpore Federated WHL Package - -You can use the source code or download the release version to install MindSpore on hardware platforms such as the x86 CPU and GPU CUDA. For details about the installation process, see [Building MindSpore Federated whl](https://www.mindspore.cn/federated/docs/en/master/deploy_federated_server.html). - -## Obtaining MindSpore Federated Java Package - -You can use the source code or download the release version. Currently, MindSpore Federated Learing supports the Linux and Android platforms. For details about the installation process, see [Building MindSpore Federated java](https://www.mindspore.cn/federated/docs/en/master/deploy_federated_client.html). - -## Requirements for Building the Linux Environment - -Currently, the source code build is supported only in the Linux environment. For details about the environment requirements, see [MindSpore Source Code Build](https://www.mindspore.cn/install/en) and [MindSpore Lite Source Code Build](https://www.mindspore.cn/lite/docs/en/master/build/build.html). 
diff --git a/docs/federated/docs/source_en/horizontal_server.rst b/docs/federated/docs/source_en/horizontal_server.rst deleted file mode 100644 index a2f1a2d2aa2f7511f4e4401c3148775232110754..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_en/horizontal_server.rst +++ /dev/null @@ -1,12 +0,0 @@ -Federated Server -================ - -.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg - :target: https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/horizontal_server.rst - :alt: View Source on Gitee - -.. toctree:: - :maxdepth: 1 - - horizontal/federated_server - horizontal/federated_server_yaml \ No newline at end of file diff --git a/docs/federated/docs/source_en/image_classfication_dataset_process.md b/docs/federated/docs/source_en/image_classfication_dataset_process.md deleted file mode 100644 index e9578b14a58db46c8ecceb68167a35aa892393cd..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_en/image_classfication_dataset_process.md +++ /dev/null @@ -1,450 +0,0 @@ -# Federated Learning Image Classification Dataset Process - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/image_classfication_dataset_process.md) - -This tutorial uses the federated learning dataset `FEMNIST` in the `leaf` dataset, which contains 62 different categories of handwritten digits and letters (digits 0 to 9, 26 lowercase letters, and 26 uppercase letters) with an image size of `28 x 28` pixels. The dataset contains handwritten digits and letters from 3500 users (up to 3500 clients can be simulated to participate in federated learning). The total data volume is 805,263, the average data volume per user is 226.83, and the variance of the data volume for all users is 88.94. - -Refer to [leaf dataset instruction](https://github.com/TalwalkarLab/leaf) to download the dataset. - -1. Environmental requirements before downloading the dataset. - - ```sh - numpy==1.16.4 - scipy # conda install scipy - tensorflow==1.13.1 # pip install tensorflow - Pillow # pip install Pillow - matplotlib # pip install matplotlib - jupyter # conda install jupyter notebook==5.7.8 tornado==4.5.3 - pandas # pip install pandas - ``` - -2. Use git to download the official dataset generation script. - - ```sh - git clone https://github.com/TalwalkarLab/leaf.git - ``` - - After downloading the project, the directory structure is as follows: - - ```sh - leaf/data/femnist - ├── data # Used to store the dataset generated by the command - ├── preprocess # Store the code related to data pre-processing - ├── preprocess.sh # shell script generated by femnist dataset - └── README.md # Official dataset download guidance - ``` - -3. Taking `femnist` dataset as an example, run the following command to enter the specified path. - - ```sh - cd leaf/data/femnist - ``` - -4. Using the command `. /preprocess.sh -s niid --sf 1.0 -k 0 -t sample` generates a dataset containing 3500 users, and the training sets and the test sets are divided in a ratio of 9:1 for each user's data. - - The meaning of the parameters in the command can be found in the `leaf/data/femnist/README.md` file. 
- - The directory structure after running is as follows: - - ```text - leaf/data/femnist/35_client_sf1_data/ - ├── all_data # All datasets are mixed together, without distinguishing the training sets and test sets, containing a total of 35 json files, and each json file contains the data of 100 users - ├── test # The test sets are divided into the training sets and the test sets in a ratio of 9:1 for each user's data, containing a total of 35 json files, and each json file contains the data of 100 users - ├── train # The training sets are divided into the training sets and the test sets in a ratio of 9:1 for each user's data, containing a total of 35 json files, and each json file contains the data of 100 users - └── ... # Other documents do not need to use, and details are not described herein - ``` - - Each json file contains the following three parts: - - - `users`: User list. - - `num_samples`: The sample number list of each user. - - `user_data`: A dictionary object with user names as key and their respective data as value. For each user, the data is represented as a list of images, with each image represented as a list of integers of size 784 (obtained by spreading the `28 x 28` image array). - - Before rerunning `preprocess.sh`, make sure to delete the `rem_user_data`, `sampled_data`, `test` and `train` subfolders from the data directory. - -5. Divide the 35 json files into 3500 json files (each json file represents a user). - - The code is as follows: - - ```python - import os - import json - - def mkdir(path): - if not os.path.exists(path): - os.mkdir(path) - - def partition_json(root_path, new_root_path): - """ - partition 35 json files to 3500 json file - - Each raw .json file is an object with 3 keys: - 1. 'users', a list of users - 2. 'num_samples', a list of the number of samples for each user - 3. 'user_data', an object with user names as keys and their respective data as values; for each user, data is represented as a list of images, with each image represented as a size-784 integer list (flattened from 28 by 28) - - Each new .json file is an object with 3 keys: - 1. 'user_name', the name of user - 2. 'num_samples', the number of samples for the user - 3. 
        Args:
            root_path (str): raw root path of the 35 json files
            new_root_path (str): new root path of the 3500 json files
        """
        paths = os.listdir(root_path)
        count = 0
        file_num = 0
        for i in paths:
            file_num += 1
            file_path = os.path.join(root_path, i)
            print('======== process ' + str(file_num) + ' file: ' + str(file_path) + '======================')
            with open(file_path, 'r') as load_f:
                load_dict = json.load(load_f)
                users = load_dict['users']
                num_users = len(users)
                num_samples = load_dict['num_samples']
            for j in range(num_users):
                count += 1
                print('---processing user: ' + str(count) + '---')
                cur_out = {'user_name': None, 'num_samples': None, 'user_data': {}}
                cur_user_id = users[j]
                cur_data_num = num_samples[j]
                cur_user_path = os.path.join(new_root_path, cur_user_id + '.json')
                cur_out['user_name'] = cur_user_id
                cur_out['num_samples'] = cur_data_num
                cur_out['user_data'].update(load_dict['user_data'][cur_user_id])
                with open(cur_user_path, 'w') as f:
                    json.dump(cur_out, f)
        f = os.listdir(new_root_path)
        print(len(f), ' users have been processed!')

    # partition train json files
    partition_json("leaf/data/femnist/35_client_sf1_data/train", "leaf/data/femnist/3500_client_json/train")
    # partition test json files
    partition_json("leaf/data/femnist/35_client_sf1_data/test", "leaf/data/femnist/3500_client_json/test")
    ```

    where `root_path` is `leaf/data/femnist/35_client_sf1_data/{train,test}` and `new_root_path` is a path of your choice for storing the generated 3500 user json files; the train and test folders need to be processed separately.

    Each of the 3500 newly generated user json files contains the following three parts:

    - `user_name`: The user name.
    - `num_samples`: The number of samples for the user.
    - `user_data`: A dictionary object with 'x' as the key for the user data and 'y' as the key for the corresponding labels.

    If the script prints the following after running, it ran successfully:

    ```sh
    ======== process 1 file: /leaf/data/femnist/35_client_sf1_data/train/all_data_16_niid_0_keep_0_train_9.json======================
    ---processing user: 1---
    ---processing user: 2---
    ---processing user: 3---
    ......
    ```

6. Convert the json files to image files.
- - Refer to the following code: - - ```python - import os - import json - import numpy as np - from PIL import Image - - name_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', - 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', - 'V', 'W', 'X', 'Y', 'Z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', - 'v', 'w', 'x', 'y', 'z' - ] - - def mkdir(path): - if not os.path.exists(path): - os.mkdir(path) - - def json_2_numpy(img_size, file_path): - """ - read json file to numpy - Args: - img_size (list): contain three elements: the height, width, channel of image - file_path (str): root path of 3500 json files - return: - image_numpy (numpy) - label_numpy (numpy) - """ - # open json file - with open(file_path, 'r') as load_f_train: - load_dict = json.load(load_f_train) - num_samples = load_dict['num_samples'] - x = load_dict['user_data']['x'] - y = load_dict['user_data']['y'] - size = (num_samples, img_size[0], img_size[1], img_size[2]) - image_numpy = np.array(x, dtype=np.float32).reshape(size) # mindspore doesn't support float64 and int64 - label_numpy = np.array(y, dtype=np.int32) - return image_numpy, label_numpy - - def json_2_img(json_path, save_path): - """ - transform single json file to images - - Args: - json_path (str): the path json file - save_path (str): the root path to save images - - """ - data, label = json_2_numpy([28, 28, 1], json_path) - for i in range(data.shape[0]): - img = data[i] * 255 # PIL don't support the 0/1 image ,need convert to 0~255 image - im = Image.fromarray(np.squeeze(img)) - im = im.convert('L') - img_name = str(label[i]) + '_' + name_list[label[i]] + '_' + str(i) + '.png' - path1 = os.path.join(save_path, str(label[i])) - mkdir(path1) - img_path = os.path.join(path1, img_name) - im.save(img_path) - print('-----', i, '-----') - - def all_json_2_img(root_path, save_root_path): - """ - transform json files to images - Args: - json_path (str): the root path of 3500 json files - save_path (str): the root path to save images - """ - usage = ['train', 'test'] - for i in range(2): - x = usage[i] - files_path = os.path.join(root_path, x) - files = os.listdir(files_path) - - for name in files: - user_name = name.split('.')[0] - json_path = os.path.join(files_path, name) - save_path1 = os.path.join(save_root_path, user_name) - mkdir(save_path1) - save_path = os.path.join(save_path1, x) - mkdir(save_path) - print('=============================' + name + '=======================') - json_2_img(json_path, save_path) - - all_json_2_img("leaf/data/femnist/3500_client_json/", "leaf/data/femnist/3500_client_img/") - ``` - - Print the result as following after running the script, which means a successful run: - - ```sh - =============================f0644_19.json======================= - ----- 0 ----- - ----- 1 ----- - ----- 2 ----- - ...... - ``` - -7. Since the dataset under some user folders is small, if the number is smaller than the batch size, random expansion is required. 
- - The entire dataset `"leaf/data/femnist/3500_client_img/"` can be checked and expanded by referring to the following code: - - ```python - import os - import shutil - from random import choice - - def count_dir(path): - num = 0 - for root, dirs, files in os.walk(path): - for file in files: - num += 1 - return num - - def get_img_list(path): - img_path_list = [] - label_list = os.listdir(path) - for i in range(len(label_list)): - label = label_list[i] - imgs_path = os.path.join(path, label) - imgs_name = os.listdir(imgs_path) - for j in range(len(imgs_name)): - img_name = imgs_name[j] - img_path = os.path.join(imgs_path, img_name) - img_path_list.append(img_path) - return img_path_list - - def data_aug(data_root_path, batch_size = 32): - users = os.listdir(data_root_path) - tags = ["train", "test"] - aug_users = [] - for i in range(len(users)): - user = users[i] - for tag in tags: - data_path = os.path.join(data_root_path, user, tag) - num_data = count_dir(data_path) - if num_data < batch_size: - aug_users.append(user + "_" + tag) - print("user: ", user, " ", tag, " data number: ", num_data, " < ", batch_size, " should be aug") - aug_num = batch_size - num_data - img_path_list = get_img_list(data_path) - for j in range(aug_num): - img_path = choice(img_path_list) - info = img_path.split(".") - aug_img_path = info[0] + "_aug_" + str(j) + ".png" - shutil.copy(img_path, aug_img_path) - print("[aug", j, "]", "============= copy file:", img_path, "to ->", aug_img_path) - print("the number of all aug users: " + str(len(aug_users))) - print("aug user name: ", end=" ") - for k in range(len(aug_users)): - print(aug_users[k], end = " ") - - if __name__ == "__main__": - data_root_path = "leaf/data/femnist/3500_client_img/" - batch_size = 32 - data_aug(data_root_path, batch_size) - ``` - -8. Convert the expanded image dataset into a bin file format usable in the Federated Learning Framework. 
- - Refer to the following code: - - ```python - import numpy as np - import os - import mindspore.dataset as ds - import mindspore.dataset.vision as vision - import mindspore.dataset.transforms as transforms - import mindspore - - def mkdir(path): - if not os.path.exists(path): - os.mkdir(path) - - def count_id(path): - files = os.listdir(path) - ids = {} - for i in files: - ids[i] = int(i) - return ids - - def create_dataset_from_folder(data_path, img_size, batch_size=32, repeat_size=1, num_parallel_workers=1, shuffle=False): - """ create dataset for train or test - Args: - data_path: Data path - batch_size: The number of data records in each group - repeat_size: The number of replicated data records - num_parallel_workers: The number of parallel workers - """ - # define dataset - ids = count_id(data_path) - mnist_ds = ds.ImageFolderDataset(dataset_dir=data_path, decode=False, class_indexing=ids) - # define operation parameters - resize_height, resize_width = img_size[0], img_size[1] # 32 - - transform = [ - vision.Decode(True), - vision.Grayscale(1), - vision.Resize(size=(resize_height, resize_width)), - vision.Grayscale(3), - vision.ToTensor(), - ] - compose = transforms.Compose(transform) - - # apply map operations on images - mnist_ds = mnist_ds.map(input_columns="label", operations=transforms.TypeCast(mindspore.int32)) - mnist_ds = mnist_ds.map(input_columns="image", operations=compose) - - # apply DatasetOps - buffer_size = 10000 - if shuffle: - mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script - mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True) - mnist_ds = mnist_ds.repeat(repeat_size) - return mnist_ds - - def img2bin(root_path, root_save): - """ - transform images to bin files - - Args: - root_path: the root path of 3500 images files - root_save: the root path to save bin files - - """ - - use_list = [] - train_batch_num = [] - test_batch_num = [] - mkdir(root_save) - users = os.listdir(root_path) - for user in users: - use_list.append(user) - user_path = os.path.join(root_path, user) - train_test = os.listdir(user_path) - for tag in train_test: - data_path = os.path.join(user_path, tag) - dataset = create_dataset_from_folder(data_path, (32, 32, 1), 32) - batch_num = 0 - img_list = [] - label_list = [] - for data in dataset.create_dict_iterator(): - batch_x_tensor = data['image'] - batch_y_tensor = data['label'] - trans_img = np.transpose(batch_x_tensor.asnumpy(), [0, 2, 3, 1]) - img_list.append(trans_img) - label_list.append(batch_y_tensor.asnumpy()) - batch_num += 1 - - if tag == "train": - train_batch_num.append(batch_num) - elif tag == "test": - test_batch_num.append(batch_num) - - imgs = np.array(img_list) # (batch_num, 32,3,32,32) - labels = np.array(label_list) - path1 = os.path.join(root_save, user) - mkdir(path1) - image_path = os.path.join(path1, user + "_" + "bn_" + str(batch_num) + "_" + tag + "_data.bin") - label_path = os.path.join(path1, user + "_" + "bn_" + str(batch_num) + "_" + tag + "_label.bin") - - imgs.tofile(image_path) - labels.tofile(label_path) - print("user: " + user + " " + tag + "_batch_num: " + str(batch_num)) - print("total " + str(len(use_list)) + " users finished!") - - root_path = "leaf/data/femnist/3500_client_img/" - root_save = "leaf/data/femnist/3500_clients_bin" - img2bin(root_path, root_save) - ``` - - Print the result as following after running the script, which means a successful run: - - ```sh - user: f0141_43 test_batch_num: 1 - user: f0141_43 train_batch_num: 10 - user: f0137_14 
test_batch_num: 1 - user: f0137_14 train_batch_num: 11 - ...... - total 3500 users finished! - ``` - -9. Generate `3500_clients_bin` folder containing a total of 3500 user folders with the following directory structure: - - ```sh - leaf/data/femnist/3500_clients_bin - ├── f0000_14 # User number - │ ├── f0000_14_bn_10_train_data.bin # The training data of user f0000_14 (The number 10 after bn_ represents the batch number) - │ ├── f0000_14_bn_10_train_label.bin # Training tag for user f0000_14 - │ ├── f0000_14_bn_1_test_data.bin # Test data of user f0000_14 (the number 1 after bn_ represents batch number) - │ └── f0000_14_bn_1_test_label.bin # Test tag for user f0000_14 - ├── f0001_41 # User number - │ ├── f0001_41_bn_11_train_data.bin # The training data of user f0001_41 (The number 11 after bn_ represents the batch number) - │ ├── f0001_41_bn_11_train_label.bin # Training tag for user f0001_41 - │ ├── f0001_41_bn_1_test_data.bin # Test data of user f0001_41 (the number 1 after bn_ represents batch number) - │ └── f0001_41_bn_1_test_label.bin # Test tag for user f0001_41 - │ ... - └── f4099_10 # User number - ├── f4099_10_bn_4_train_data.bin # The training data of user f4099_10 (the number 4 after bn_ represents the batch number) - ├── f4099_10_bn_4_train_label.bin # Training tag for user f4099_10 - ├── f4099_10_bn_1_test_data.bin # Test data of user f4099_10 (the number 1 after bn_ represents batch number) - └── f4099_10_bn_1_test_label.bin # Test tag for user f4099_10 - ``` - -The `3500_clients_bin` folder generated according to steps 1 to 9 above can be directly used as the input data for the device-cloud federated image classification task. diff --git a/docs/federated/docs/source_en/image_classification_application.md b/docs/federated/docs/source_en/image_classification_application.md deleted file mode 100644 index 04a7b00dc093f330a7b333f175cfe99a94255446..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_en/image_classification_application.md +++ /dev/null @@ -1,331 +0,0 @@ -# Implementing an Image Classification Application of Cross-device Federated Learning (x86) - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/image_classification_application.md) - -Federated learning can be divided into cross-silo federated learning and cross-device federated learning according to different participating clients. In the cross-silo federated learning scenario, the clients participating in federated learning are different organizations (for example, medical or financial) or data centers geographically distributed, that is, training models on multiple data islands. The clients participating in the cross-device federated learning scenario are a large number of mobiles or IoT devices. This framework will introduce how to use the network LeNet to implement an image classification application on the MindSpore cross-silo federated framework, and provides related tutorials for simulating to start multi-client participation in federated learning in the x86 environment. - -Before you start, check whether MindSpore has been correctly installed. If not, install MindSpore on your computer by referring to [Install](https://www.mindspore.cn/install/en) on the MindSpore website. 
- -## Preparation - -We provide [Federated Learning Image Classification Dataset FEMNIST](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/federated/3500_clients_bin.zip) and the [device-side model file](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/lenet_train.ms) of the `.ms` format for users to use directly. Users can also refer to the following tutorials to generate the datasets and models based on actual needs. - -### Generating a Device-side Model File - -1. Define the network and training process. - - For the definition of the specific network and training process, please refer to [Beginners Getting Started](https://www.mindspore.cn/tutorials/en/master/beginner/quick_start.html). - -2. Export a model as a MindIR file. - - The code snippet is as follows: - - ```python - import argparse - import numpy as np - import mindspore as ms - import mindspore.nn as nn - - def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): - """weight initial for conv layer""" - weight = weight_variable() - return nn.Conv2d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - weight_init=weight, - has_bias=False, - pad_mode="valid", - ) - - def fc_with_initialize(input_channels, out_channels): - """weight initial for fc layer""" - weight = weight_variable() - bias = weight_variable() - return nn.Dense(input_channels, out_channels, weight, bias) - - def weight_variable(): - """weight initial""" - return ms.common.initializer.TruncatedNormal(0.02) - - class LeNet5(nn.Cell): - def __init__(self, num_class=10, channel=3): - super(LeNet5, self).__init__() - self.num_class = num_class - self.conv1 = conv(channel, 6, 5) - self.conv2 = conv(6, 16, 5) - self.fc1 = fc_with_initialize(16 * 5 * 5, 120) - self.fc2 = fc_with_initialize(120, 84) - self.fc3 = fc_with_initialize(84, self.num_class) - self.relu = nn.ReLU() - self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.flatten = nn.Flatten() - - def construct(self, x): - x = self.conv1(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.flatten(x) - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.relu(x) - x = self.fc3(x) - return x - - parser = argparse.ArgumentParser(description="export mindir for lenet") - parser.add_argument("--device_target", type=str, default="CPU") - parser.add_argument("--mindir_path", type=str, - default="lenet_train.mindir") # the mindir file path of the model to be export - - args, _ = parser.parse_known_args() - device_target = args.device_target - mindir_path = args.mindir_path - - ms.set_context(mode=ms.GRAPH_MODE, device_target=device_target) - - if __name__ == "__main__": - np.random.seed(0) - network = LeNet5(62) - criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction="mean") - net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) - net_with_criterion = nn.WithLossCell(network, criterion) - train_network = nn.TrainOneStepCell(net_with_criterion, net_opt) - train_network.set_train() - - data = ms.Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32)) - label = ms.Tensor(np.random.randint(0, 1, (32, 62)).astype(np.float32)) - ms.export(train_network, data, label, file_name=mindir_path, - file_format='MINDIR') # Add the export statement to obtain the model file in MindIR format. - ``` - - The parameter `--mindir_path` is used to set the path of the generated file in MindIR format. - -3. 
Convert the MindIR file into an .ms file that can be used by the federated learning device-side framework. - - For details about model conversion, see [Training Model Conversion Tutorial](https://www.mindspore.cn/lite/docs/en/master/train/converter_train.html). - - The following is an example of model conversion: - - Assume that the model file to be converted is `lenet_train.mindir`. Run the following command: - - ```sh - ./converter_lite --fmk=MINDIR --trainModel=true --modelFile=lenet_train.mindir --outputFile=lenet_train - ``` - - If the conversion is successful, the following information is displayed: - - ```sh - CONVERT RESULT SUCCESS:0 - ``` - - This indicates that the MindSpore model is successfully converted to the MindSpore device-side model and the new file `lenet_train.ms` is generated. If the conversion fails, the following information is displayed: - - ```sh - CONVERT RESULT FAILED: - ``` - - The generated model file in `.ms` format is the model file required by subsequent clients. - -## Simulating Multi-client Participation in Federated Learning - -### Preparing a Model File for the Client - -This example uses lenet on the device-side to simulate the actual network used, where[device-side model file](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/lenet_train.ms) in `.ms` format of lenet. As the real scenario where a client contains only one model file in .ms format, in the simulation scenario, multiple copies of the .ms file need to be copied and named according to the `lenet_train{i}.ms` format, where i represents the client number, since the .ms file has been automatically copied for each client in `run_client_x86.py`. - -See the copy_ms function in [startup script](https://gitee.com/mindspore/federated/blob/master/example/cross_device_lenet_femnist/simulate_x86/run_client_x86.py) for details. - -### Starting the Cloud Side Service - -Users can first refer to [cloud-side deployment tutorial](https://www.mindspore.cn/federated/docs/en/master/deploy_federated_server.html) to deploy the cloud-side environment and start the cloud-side service. - -### Starting the Client - -Before starting the client, please refer to the section [Device-side deployment tutotial](https://www.mindspore.cn/federated/docs/en/master/deploy_federated_client.html) for deployment of device environment. - -We provide a reference script [run_client_x86.py](https://gitee.com/mindspore/federated/blob/master/example/cross_device_lenet_femnist/simulate_x86/run_client_x86.py), users can set relevant parameters to start different federated learning interfaces. -After the cloud-side service is successfully started, the script providing run_client_x86.py is used to call the federated learning framework jar package `mindspore-lite-java-flclient.jar` and the corresponding jar package `quick_start_flclient.jar` of the model script, obtaining in [Compiling package Flow in device-side deployment](https://www.mindspore.cn/federated/docs/en/master/deploy_federated_client.html) to simulate starting multiple clients to participate in the federated learning task. 
- -Taking the LeNet network as an example, some of the input parameters in the `run_client_x86.py` script have the following meanings, and users can set them according to the actual situation: - -- `--fl_jar_path` - - For setting the federated learning jar package path and obtaining x86 environment federated learning jar package, refer to [Compile package process in device-side deployment](https://www.mindspore.cn/federated/docs/en/master/deploy_federated_client.html). - -- `--case_jar_path` - - For setting the path of jar package `quick_start_flclient.jar` generated by model script and obtaining the JAR package in the x86 environment, see [Compile package process in device-side deployment](https://www.mindspore.cn/federated/docs/en/master/deploy_federated_client.html). - -- `--lite_jar_path` - - For setting the path of jar package `mindspore-lite-java.jar` of mindspore lite, which is located in `mindspore-lite-{version}-linux-x64.tar.gz`. For x86 environment federated learning jar package acquisition, see [Compile package process in device-side deployment](https://www.mindspore.cn/federated/docs/en/master/deploy_federated_client.html). - -- `--train_data_dir` - - The root path of the training dataset in which the LeNet image classification task is stored is the training data.bin file and label.bin file for each client, e.g. `data/femnist/3500_clients_bin/`. - -- `--fl_name` - - Specifies the package path of model script used by federated learning. We provide two types of model scripts for your reference ([Supervised sentiment classification task](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert), [Lenet image classification task](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet)). For supervised sentiment classification tasks, this parameter can be set to the package path of the provided script file [AlBertClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/AlbertClient.java), like as `com.mindspore.flclient.demo.albert.AlbertClient`. For Lenet image classification tasks, this parameter can be set to the package path of the provided script file [LenetClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java), like as `com.mindspore.flclient.demo.lenet.LenetClient`. At the same time, users can refer to these two types of model scripts, define the model script by themselves, and then set the parameter to the package path of the customized model file ModelClient.java (which needs to inherit from the class [Client.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/model/Client.java)). - -- `--train_model_dir` - - Specifies the training model path used for federated learning. The path is the directory where multiple .ms files copied in the preceding tutorial are stored, for example, `ms/lenet`. The path must be an absolute path. - -- `--domain_name` - - Used to set the url for device-cloud communication. Currently, https and http communication are supported, and the corresponding formats are like as: https://......, http://....... 
When `if_use_elb` is set to true, the format must be: or , where `127.0.0.1` corresponds to the ip of the machine ip providing cloud-side services (corresponding to the cloud-side parameter `--scheduler_ip`), and `6666` corresponds to the cloud-side parameter `--fl_server_port`. - - Note 1: When this parameter is set to `http://......`, it means that HTTP communication is used, and there may be communication security risks. - - Note 2: When this parameter is set to `https://......`, it means the use of HTTPS communication. At this time, SSL certificate authentication must be performed, and the certificate path needs to be set by the parameter `-cert_path`. - -- `--task` - - Specifies the type of the task to be started. `train` indicates that a training task is started. `inference` indicates that multiple data inference tasks are started. `getModel` indicates that the task for obtaining the cloud model is started. Other character strings indicate that the inference task of a single data record is started. The default value is `train`. The initial model file (.ms file) is not trained. Therefore, you are advised to start the training task first. After the training is complete, start the inference task. (Note that the values of client_num in the two startups must be the same to ensure that the model file used by `inference` is the same as that used by `train`.) - -- `--batch_size` - - Specifies the number of single-step training samples used in federated learning training and inference, that is, batch size. It needs to be consistent with the batch size of the input data of the model. - -- `--client_num` - - Specifies the number of clients. The value must be the same as that of `start_fl_job_cnt` when the server is started. This parameter is not required in actual scenarios. - -If you want to know more about the meaning of other parameters in the `run_client_x86.py` script, you can refer to the comments in the script. - -The basic startup instructions of the federated learning interface are as follows: - -```sh - rm -rf client_*\ - && rm -rf ms/* \ - && python3 run_client_x86.py \ - --fl_jar_path="federated/mindspore_federated/device_client/build/libs/jarX86/mindspore-lite-java-flclient.jar" \ - --case_jar_path="federated/example/quick_start_flclient/target/case_jar/quick_start_flclient.jar" \ - --lite_jar_path="federated/mindspore_federated/device_client/third/mindspore-lite-2.0.0-linux-x64/runtime/lib/mindspore-lite-java.jar" \ - --train_data_dir="federated/tests/st/simulate_x86/data/3500_clients_bin/" \ - --eval_data_dir="null" \ - --infer_data_dir="null" \ - --vocab_path="null" \ - --ids_path="null" \ - --path_regex="," \ - --fl_name="com.mindspore.flclient.demo.lenet.LenetClient" \ - --origin_train_model_path="federated/tests/st/simulate_x86/ms_files/lenet/lenet_train.ms" \ - --origin_infer_model_path="null" \ - --train_model_dir="ms" \ - --infer_model_dir="ms" \ - --ssl_protocol="TLSv1.2" \ - --deploy_env="x86" \ - --domain_name="http://10.*.*.*:8010" \ - --cert_path="CARoot.pem" --use_elb="false" \ - --server_num=1 \ - --task="train" \ - --thread_num=1 \ - --cpu_bind_mode="NOT_BINDING_CORE" \ - --train_weight_name="null" \ - --infer_weight_name="null" \ - --name_regex="::" \ - --server_mode="FEDERATED_LEARNING" \ - --batch_size=32 \ - --input_shape="null" \ - --client_num=8 -``` - -Note that the related path in the startup command must give an absolute path. - -The above commands indicate that eight clients are started to participate in federated learning. 
If the startup is successful, log files corresponding to the eight clients are generated in the current folder. You can view the log files to learn the running status of each client: - -```text -./ -├── client_0 -│ └── client.log # Log file of client 0. -│ ...... -└── client_7 - └── client.log # Log file of client 7. -``` - -For different interfaces and scenarios, you only need to modify specific parameter values according to the meaning of the parameters, such as: - -- Start federated learning and training tasks: SyncFLJob.flJobRun() - - When `--task` in `Basic Start Command` is set to `train`, it means to start the task. - - You can use the command `grep -r "average loss:" client_0/client.log` to view the average loss of each epoch of `client_0` during the training process. It will be printed as follows: - - ```sh - INFO: ----------epoch:0,average loss:4.1258564 ---------- - ...... - ``` - - You can also use the command `grep -r "evaluate acc:" client_0/client.log` to view the verification accuracy of the model after the aggregation in each federated learning iteration for `client_0` . It will be printed like the following: - - ```sh - INFO: [evaluate] evaluate acc: 0.125 - ...... - ``` - - On the cloud side, the number of client group ids and algorithm type for unsupervised cluster index statistics can be specified by setting the 'cluster_client_num' parameter and 'eval_type' parameter of yaml configuration file. The 'metrics.json' statistical file generated on the cloud side can query the unsupervised indicator information: - - ```text - "unsupervisedEval":0.640 - "unsupervisedEval":0.675 - "unsupervisedEval":0.677 - "unsupervisedEval":0.706 - ...... - ``` - -- Start the inference task: SyncFLJob.modelInference() - - When `--task` in `Basic Start Command` is set to `inference`, it means to start the task. - - You can view the inference result of `client_0` through the command `grep -r "the predicted labels:" client_0/client.log`: - - ```sh - INFO: [model inference] the predicted labels: [0, 0, 0, 1, 1, 1, 2, 2, 2] - ...... - ``` - -- Start the task of obtaining the latest model on the cloud side: SyncFLJob.getModel() - - When `--task` in `Basic Start Command` is set to `inference`, it means to start the task. - - If there is the following content in the log file, it means that the latest model on the cloud side is successfully obtained: - - ```sh - INFO: [getModel] get response from server ok! - ``` - -### Stopping the Client Process - -For details, see the [finish.py](https://gitee.com/mindspore/federated/blob/master/example/cross_device_lenet_femnist/simulate_x86/finish.py) script. The details are as follows: - -The command of stopping the client process: - -```sh -python finish.py --kill_tag=mindspore-lite-java-flclient -``` - -The parameter `--kill_tag` is used to search for the keyword to kill the client process. You only need to set the special keyword in `--jarPath`. The default value is `mindspore-lite-java-flclient`, that is, the name of the federated learning JAR package. The user can check whether the process still exists through the command `ps -ef |grep "mindspore-lite-java-flclient"`. - -Experimental results of 50 clients participating in federated learning and training tasks. - -Currently, the `3500_clients_bin` folder contains data of 3500 clients. This script can simulate a maximum of 3500 clients to participate in federated learning. - -The following figure shows the accuracy of the test dataset for federated learning on 50 clients (set `server_num` to 16). 
- -![lenet_50_clients_acc](images/lenet_50_clients_acc_en.png) - -The total number of federated learning iterations is 100, the number of epochs for local training on the client is 20, and the value of batchSize is 32. - -The test accuracy in the figure refers to the accuracy of each client test dataset on the aggregated model on the cloud for each federated learning iteration: - -AVG: average accuracy of 50 clients in the test dataset for each federated learning iteration. - -TOP5: average accuracy of the 5 clients with the highest accuracy in the test dataset for each federated learning iteration. - -LOW5: average accuracy of the 5 clients with the lowest accuracy in the test dataset for each federated learning iteration. diff --git a/docs/federated/docs/source_en/image_classification_application_in_cross_silo.md b/docs/federated/docs/source_en/image_classification_application_in_cross_silo.md deleted file mode 100644 index 86196f4866a41b9cd97e68edc0ecdd4089b50df2..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_en/image_classification_application_in_cross_silo.md +++ /dev/null @@ -1,313 +0,0 @@ -# Implementing a Cloud-Slio Federated Image Classification Application (x86) - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/image_classification_application_in_cross_silo.md) - -Based on the type of participating clients, federated learning can be classified into cross-silo federated learning and cross-device federated learning. In a cross-silo federated learning scenario, the clients involved in federated learning are different organizations (e.g., healthcare or finance) or geographically distributed data centers, i.e., training models on multiple data silos. In the cross-device federated learning scenario, the participating clients are a large number of mobile or IoT devices. This framework will describe how to implement an image classification application by using the network LeNet on the MindSpore Federated cross-silo federated framework. - -The full script to launch cross-silo federated image classification application can be found [here](https://gitee.com/mindspore/federated/tree/master/example/cross_silo_femnist). - -## Downloading the Dataset - -This example uses the federated learning dataset `FEMNIST` from [leaf dataset](https://github.com/TalwalkarLab/leaf), which contains 62 different categories of handwritten numbers and letters (numbers 0 to 9, 26 lowercase letters, 26 uppercase letters) with an image size of `28 x 28` pixels. The dataset contains handwritten digits and letters from 3500 users (up to 3500 clients can be simulated to participate in federated learning). The total data volume is 805263, the average amount of data contained per user is 226.83, and the variance of the data volume for all users is 88.94. - -You can refer to [Image classfication dataset process](https://www.mindspore.cn/federated/docs/en/master/image_classfication_dataset_process.html) in steps 1 to 7 to obtain the 3500 user datasets `3500_client_img` in the form of images. - -Due to the relatively small amount of data per user in the original 3500 user dataset, it will converge too fast in the cross-silo federated task to obviously reflect the convergence effect of the cross-silo federated framework. 
The following provides a reference script to merge the specified number of user data into one user to increase the amount of individual user data participating in the cross-silo federated task and better simulate the cross-silo federated framework experiment. - -```python -import os -import shutil - - -def mkdir(path): - if not os.path.exists(path): - os.mkdir(path) - - -def combine_users(root_data_path, new_data_path, raw_user_num, new_user_num): - mkdir(new_data_path) - user_list = os.listdir(root_data_path) - num_per_user = int(raw_user_num / new_user_num) - for i in range(new_user_num): - print( - "========================== combine the raw {}~{} users to the new user: dataset_{} ==========================".format( - i * num_per_user, i * num_per_user + num_per_user - 1, i)) - new_user = "dataset_" + str(i) - new_user_path = os.path.join(new_data_path, new_user) - mkdir(new_user_path) - for j in range(num_per_user): - index = i * new_user_num + j - user = user_list[index] - user_path = os.path.join(root_data_path, user) - tags = os.listdir(user_path) - print("------------- process the raw user: {} -------------".format(user)) - for t in tags: - tag_path = os.path.join(user_path, t) - label_list = os.listdir(tag_path) - new_tag_path = os.path.join(new_user_path, t) - mkdir(new_tag_path) - for label in label_list: - label_path = os.path.join(tag_path, label) - img_list = os.listdir(label_path) - new_label_path = os.path.join(new_tag_path, label) - mkdir(new_label_path) - - for img in img_list: - img_path = os.path.join(label_path, img) - new_img_name = user + "_" + img - new_img_path = os.path.join(new_label_path, new_img_name) - shutil.copy(img_path, new_img_path) - -if __name__ == "__main__": - root_data_path = "cross_silo_femnist/femnist/3500_clients_img" - new_data_path = "cross_silo_femnist/femnist/35_7_client_img" - raw_user_num = 35 - new_user_num = 7 - combine_users(root_data_path, new_data_path, raw_user_num, new_user_num) -``` - -where `root_data_path` is the path to the original 3500 user datasets, `new_data_path` is the path to the merged dataset, `raw_user_num` specifies the total number of user datasets to be merged (needs to be <= 3500), and `new_user_num` is used to set the number of users merged by the original datasets. For example, the sample code will select the first 35 users from `cross_silo_femnist/femnist/3500_clients_img`, merge them into 7 user datasets and store them in the path `cross_silo_femnist/femnist/35_7_client_img` (the merged 7 users each contains the original 5 user dataset). - -The following print represents a successful merge of the datasets. 
- -```sh -========================== combine the raw 0~4 users to the new user: dataset_0 ========================== -------------- process the raw user: f1798_42 ------------- -------------- process the raw user: f2149_81 ------------- -------------- process the raw user: f4046_46 ------------- -------------- process the raw user: f1093_13 ------------- -------------- process the raw user: f1124_24 ------------- -========================== combine the raw 5~9 users to the new user: dataset_1 ========================== -------------- process the raw user: f0586_11 ------------- -------------- process the raw user: f0721_31 ------------- -------------- process the raw user: f3527_33 ------------- -------------- process the raw user: f0146_33 ------------- -------------- process the raw user: f1272_09 ------------- -========================== combine the raw 10~14 users to the new user: dataset_2 ========================== -------------- process the raw user: f0245_40 ------------- -------------- process the raw user: f2363_77 ------------- -------------- process the raw user: f3596_19 ------------- -------------- process the raw user: f2418_82 ------------- -------------- process the raw user: f2288_58 ------------- -========================== combine the raw 15~19 users to the new user: dataset_3 ========================== -------------- process the raw user: f2249_75 ------------- -------------- process the raw user: f3681_31 ------------- -------------- process the raw user: f3766_48 ------------- -------------- process the raw user: f0537_35 ------------- -------------- process the raw user: f0614_14 ------------- -========================== combine the raw 20~24 users to the new user: dataset_4 ========================== -------------- process the raw user: f2302_58 ------------- -------------- process the raw user: f3472_19 ------------- -------------- process the raw user: f3327_11 ------------- -------------- process the raw user: f1892_07 ------------- -------------- process the raw user: f3184_11 ------------- -========================== combine the raw 25~29 users to the new user: dataset_5 ========================== -------------- process the raw user: f1692_18 ------------- -------------- process the raw user: f1473_30 ------------- -------------- process the raw user: f0909_04 ------------- -------------- process the raw user: f1956_19 ------------- -------------- process the raw user: f1234_26 ------------- -========================== combine the raw 30~34 users to the new user: dataset_6 ========================== -------------- process the raw user: f0031_02 ------------- -------------- process the raw user: f0300_24 ------------- -------------- process the raw user: f4064_46 ------------- -------------- process the raw user: f2439_77 ------------- -------------- process the raw user: f1717_16 ------------- -``` - -The following directory structure of the folder `cross_silo_femnist/femnist/35_7_client_img` is as follows: - -```text -35_7_client_img # Merge the 35 users in the FeMnist dataset into 7 client data (each containing 5 pieces of user data) -├── dataset_0 # The dataset of Client 0 -│ ├── train # Training dataset -│ │ ├── 0 # Store image data corresponding to category 0 -│ │ ├── 1 # Store image data corresponding to category 1 -│ │ │ ...... -│ │ └── 61 # Store image data corresponding to category 61 -│ └── test # Test dataset, with the same directory structure as train -│ ...... 
-│ -└── dataset_6 # The dataset of Client 6 - ├── train # Training dataset - │ ├── 0 # Store image data corresponding to category 0 - │ ├── 1 # Store image data corresponding to category 1 - │ │ ...... - │ └── 61 # Store image data corresponding to category 61 - └── test # Test dataset, with the same directory structure as train -``` - -## Defining the Network - -We choose the relatively simple LeNet network, which has seven layers without the input layer: two convolutional layers, two downsampling layers (pooling layers), and three fully connected layers. Each layer contains a different number of training parameters, as shown in the following figure: - -![LeNet5](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/federated/docs/source_zh_cn/images/LeNet_5.jpg) - -> More information about LeNet network is not described herein. For more details, please refer to . - -The network used for this task can be found in the script [test_cross_silo_femnist.py](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_femnist/test_cross_silo_femnist.py). - -For a specific understanding of the network definition process in MindSpore, please refer to [quick start](https://www.mindspore.cn/tutorials/en/master/beginner/quick_start.html#building-network). - -## Launching the Cross-Silo Federated Task - -### Installing MindSpore and MindSpore Federated - -Both source code and downloadable distribution are included. Support CPU, GPU, Ascend hardware platforms, just choose to install according to the hardware platforms. The installation steps can be found in [MindSpore Installation Guide](https://www.mindspore.cn/install), [MindSpore Federated Installation Guide](https://www.mindspore.cn/federated/docs/en/master/federated_install.html). - -Currently the federated learning framework is only supported for deployment in Linux environments. Cross-silo federated learning framework requires MindSpore version number >= 1.5.0. - -### Launching the Task - -Refer to [Example](https://gitee.com/mindspore/federated/tree/master/example/cross_silo_femnist) to launch cluster. The reference example directory structure is as follows. - -```text -cross_silo_femnist/ -├── config.json # Configuration file -├── finish_cross_silo_femnist.py # Close the cross-silo federated task script -├── run_cross_silo_femnist_sched.py # Start cross-silo federated scheduler script -├── run_cross_silo_femnist_server.py # Start cross-silo federated server script -├── run_cross_silo_femnist_worker.py # Start cross-silo federated worker script -├── run_cross_silo_femnist_worker_distributed.py # Start the cloud Federation distributed training worker script -└── test_cross_silo_femnist.py # Training scripts used by the client -``` - -1. Start Scheduler - - `run_cross_silo_femnist_sched.py` is a Python script provided for the user to start the `Scheduler` and supports modifying the configuration via argument passing `argparse`. The following command is executed, representing the `Scheduler` that starts this federated learning task with TCP port `5554`. - - ```sh - python run_cross_silo_femnist_sched.py --scheduler_manage_address=127.0.0.1:5554 - ``` - - The following print represents a successful start-up: - - ```sh - [INFO] FEDERATED(35566,7f4275895740,python):2022-10-09-15:23:22.450.205 [mindspore_federated/fl_arch/ccsrc/scheduler/scheduler.cc:35] Run] Scheduler started successfully. 
-   [INFO] FEDERATED(35566,7f41f259d700,python):2022-10-09-15:23:22.450.357 [mindspore_federated/fl_arch/ccsrc/common/communicator/http_request_handler.cc:90] Run] Start http server!
-   ```
-
-2. Start Server
-
-   `run_cross_silo_femnist_server.py` is a Python script provided for the user to start a number of `Server` processes, and it supports modifying the configuration via `argparse` arguments. Executing the following command starts the `Server` of this federated learning task, with HTTP starting port `5555` and the number of `Server` processes set to `4`:
-
-   ```sh
-   python run_cross_silo_femnist_server.py --local_server_num=4 --http_server_address=10.*.*.*:5555
-   ```
-
-   The above command is equivalent to starting four `Server` processes, whose federated learning service ports are `5555`, `5556`, `5557` and `5558`.
-
-3. Start Worker
-
-   `run_cross_silo_femnist_worker.py` is a Python script provided for the user to start a number of `worker` processes, and it supports modifying the configuration via `argparse` arguments. Executing the following command starts the `worker` of this federated learning task, with HTTP starting port `5555` and the number of `worker` processes set to `4`:
-
-   ```sh
-   python run_cross_silo_femnist_worker.py --dataset_path=/data_nfs/code/fed_user_doc/federated/tests/st/cross_silo_femnist/35_7_client_img/ --http_server_address=10.*.*.*:5555
-   ```
-
-   At present, the `worker` node of the cloud federation supports the single-machine multi-card and multi-machine multi-card distributed training modes. `run_cross_silo_femnist_worker_distributed.py` is a Python script provided for users to start the distributed training of a `worker` node, and it also supports modifying the configuration via `argparse` arguments. Executing the following command starts the distributed `worker` of this federated learning task, where `device_num` specifies the number of processes started by the `worker` cluster and `run_distribute` enables distributed training in the cluster; the HTTP starting port is `5555` and the number of `worker` processes is `4`:
-
-   ```sh
-   python run_cross_silo_femnist_worker_distributed.py --device_num=4 --run_distribute=True --dataset_path=/data_nfs/code/fed_user_doc/federated/tests/st/cross_silo_femnist/35_7_client_img/ --http_server_address=10.*.*.*:5555
-   ```
-
-After executing the above three commands, go to the `worker_0` folder in the current directory and check the `worker_0` log with the command `grep -rn "test acc" *`; you will see output similar to the following:
-
-```sh
-local epoch: 0, loss: 3.787421340711655, train acc: 0.05342741935483871, test acc: 0.075
-```
-
-This means that cross-silo federated learning has started successfully and `worker_0` is training; the other workers can be checked in a similar way.
-
-If the worker has been started in distributed multi-card training mode, enter the folder `worker_distributed/log_output/` in the current directory and run the command `grep -rn "test acc" *` to view the log of the distributed `worker` cluster. You will see output like the following:
-
-```text
-local epoch: 0, loss: 2.3467453340711655, train acc: 0.06532451988877687, test acc: 0.076
-```
-
-Please refer to the [yaml configuration notes](https://www.mindspore.cn/federated/docs/zh-CN/master/horizontal/federated_server_yaml.html) for a description of the parameter configuration in the above scripts.
-
-### Viewing Logs
-
-After the task is started successfully, the corresponding log files are generated under the current directory `cross_silo_femnist`, with the following directory structure:
-
-```text
-cross_silo_femnist
-├── scheduler
-│ └── scheduler.log # Log printed while the scheduler is running
-├── server_0
-│ └── server.log # Log printed while server_0 is running
-├── server_1
-│ └── server.log # Log printed while server_1 is running
-├── server_2
-│ └── server.log # Log printed while server_2 is running
-├── server_3
-│ └── server.log # Log printed while server_3 is running
-├── worker_0
-│ ├── ckpt # Store the aggregated model ckpt obtained by worker_0 at the end of each federated learning iteration
-│ │ ├── 0-fl-ms-bs32-0epoch.ckpt
-│ │ ├── 0-fl-ms-bs32-1epoch.ckpt
-│ │ │
-│ │ │ ......
-│ │ │
-│ │ └── 0-fl-ms-bs32-19epoch.ckpt
-│ └── worker.log # Record the logs output while worker_0 participates in the federated learning task
-└── worker_1
- ├── ckpt # Store the aggregated model ckpt obtained by worker_1 at the end of each federated learning iteration
- │ ├── 1-fl-ms-bs32-0epoch.ckpt
- │ ├── 1-fl-ms-bs32-1epoch.ckpt
- │ │
- │ │ ......
- │ │
- │ └── 1-fl-ms-bs32-19epoch.ckpt
- └── worker.log # Record the logs output while worker_1 participates in the federated learning task
-```
-
-### Closing the Task
-
-If you want to exit in the middle of training, run the following command:
-
-```sh
-python finish_cross_silo_femnist.py --redis_port=2345
-```
-
-Alternatively, wait until the training task finishes; the cluster then exits automatically and does not need to be closed manually.
-
-### Results
-
-- Used data:
-
-  The `35_7_client_img/` dataset generated in the `download dataset` section above
-
-- The number of client-side local training epochs: 20
-
-- The total number of cross-silo federated learning iterations: 20
-
-- Results (accuracy of the model on each client's test set after each iteration of aggregation)
-
-`worker_0` result:
-
-```sh
-worker_0/worker.log:7409:local epoch: 0, loss: 3.787421340711655, train acc: 0.05342741935483871, test acc: 0.075
-worker_0/worker.log:14419:local epoch: 1, loss: 3.725699281115686, train acc: 0.05342741935483871, test acc: 0.075
-worker_0/worker.log:21429:local epoch: 2, loss: 3.5285709657335795, train acc: 0.19556451612903225, test acc: 0.16875
-worker_0/worker.log:28439:local epoch: 3, loss: 3.0393165519160608, train acc: 0.4889112903225806, test acc: 0.4875
-worker_0/worker.log:35449:local epoch: 4, loss: 2.575952764115026, train acc: 0.6854838709677419, test acc: 0.60625
-worker_0/worker.log:42459:local epoch: 5, loss: 2.2081101375296512, train acc: 0.7782258064516129, test acc: 0.6875
-worker_0/worker.log:49470:local epoch: 6, loss: 1.9229739431736557, train acc: 0.8054435483870968, test acc: 0.69375
-worker_0/worker.log:56480:local epoch: 7, loss: 1.7005576549999293, train acc: 0.8296370967741935, test acc: 0.65625
-worker_0/worker.log:63490:local epoch: 8, loss: 1.5248727620766704, train acc: 0.8407258064516129, test acc: 0.6375
-worker_0/worker.log:70500:local epoch: 9, loss: 1.3838803705352127, train acc: 0.8568548387096774, test acc: 0.7
-worker_0/worker.log:77510:local epoch: 10, loss: 1.265225578921041, train acc: 0.8679435483870968, test acc: 0.7125
-worker_0/worker.log:84520:local epoch: 11, loss: 1.167484122101638, train acc: 0.8659274193548387, test acc: 0.70625
-worker_0/worker.log:91530:local epoch: 12, loss: 1.082880981700859, train acc: 0.8770161290322581, test acc: 0.65625
-worker_0/worker.log:98540:local epoch: 13, loss: 1.0097520119572772, train acc: 0.8840725806451613, test acc: 0.64375
-worker_0/worker.log:105550:local epoch: 14, loss: 0.9469810053708015, train acc: 0.9022177419354839, test acc: 0.7
-worker_0/worker.log:112560:local epoch: 15, loss: 0.8907848935604703, train acc: 0.9022177419354839, test acc: 0.6625
-worker_0/worker.log:119570:local epoch: 16, loss: 0.8416629644123349, train acc: 0.9082661290322581, test acc: 0.70625
-worker_0/worker.log:126580:local epoch: 17, loss: 0.798475691030866, train acc: 0.9122983870967742, test acc: 0.70625
-worker_0/worker.log:133591:local epoch: 18, loss: 0.7599438544427897, train acc: 0.9243951612903226, test acc: 0.6875
-worker_0/worker.log:140599:local epoch: 19, loss: 0.7250227383907605, train acc: 0.9294354838709677, test acc: 0.7125
-```
-
-The test results of the other clients are largely similar and are therefore not listed in detail here.
diff --git a/docs/federated/docs/source_en/images/HFL_en.png b/docs/federated/docs/source_en/images/HFL_en.png
deleted file mode 100644
index f7b5adac95b8dff2fc010fa49607c706b67daab7..0000000000000000000000000000000000000000
Binary files a/docs/federated/docs/source_en/images/HFL_en.png and /dev/null differ
diff --git a/docs/federated/docs/source_en/images/VFL_en.png b/docs/federated/docs/source_en/images/VFL_en.png
deleted file mode 100644
index 818bb3de139b9ae2d499c18383d08f324e2b634d..0000000000000000000000000000000000000000
Binary files a/docs/federated/docs/source_en/images/VFL_en.png and /dev/null differ
diff --git a/docs/federated/docs/source_en/images/create_android_project.png b/docs/federated/docs/source_en/images/create_android_project.png
deleted file mode 100644
index a519264c4158fba67eb1ff5f5fbc3eae65b32363..0000000000000000000000000000000000000000
Binary files a/docs/federated/docs/source_en/images/create_android_project.png and /dev/null differ
diff --git a/docs/federated/docs/source_en/images/data_join_en.png b/docs/federated/docs/source_en/images/data_join_en.png
deleted file mode 100644
index 9cd2b73335d611e9532867f08e675054244561aa..0000000000000000000000000000000000000000
Binary files a/docs/federated/docs/source_en/images/data_join_en.png and /dev/null differ
diff --git a/docs/federated/docs/source_en/images/deploy_VFL_en.png b/docs/federated/docs/source_en/images/deploy_VFL_en.png
deleted file mode 100644
index 390edb1cd1be92e8234a79b36bd10571fba092f0..0000000000000000000000000000000000000000
Binary files a/docs/federated/docs/source_en/images/deploy_VFL_en.png and /dev/null differ
diff --git a/docs/federated/docs/source_en/images/download_compress_client_en.png b/docs/federated/docs/source_en/images/download_compress_client_en.png
deleted file mode 100644
index ab2b84783c1478091c2f9d600e5b051a2f4149dd..0000000000000000000000000000000000000000
Binary files a/docs/federated/docs/source_en/images/download_compress_client_en.png and /dev/null differ
diff --git a/docs/federated/docs/source_en/images/download_compress_server_en.png b/docs/federated/docs/source_en/images/download_compress_server_en.png
deleted file mode 100644
index 7917007287e2a7cf84adf3ecd559ee46b49b879b..0000000000000000000000000000000000000000
Binary files a/docs/federated/docs/source_en/images/download_compress_server_en.png and /dev/null differ
diff --git a/docs/federated/docs/source_en/images/label_dp_en.png b/docs/federated/docs/source_en/images/label_dp_en.png
deleted file mode 100644
index 059f9bb61fdb129697d561cf476ef6e09adf4ecb..0000000000000000000000000000000000000000
Binary files
a/docs/federated/docs/source_en/images/label_dp_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/lenet_50_clients_acc_en.png b/docs/federated/docs/source_en/images/lenet_50_clients_acc_en.png deleted file mode 100644 index 49d69ab1e14d37c9ae54562f6be09bddae24a570..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/lenet_50_clients_acc_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/lenet_signds_loss_auc.png b/docs/federated/docs/source_en/images/lenet_signds_loss_auc.png deleted file mode 100644 index 7304b69c4d0abf039549dce758b906d688213e4f..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/lenet_signds_loss_auc.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/mindspore_federated_networking.png b/docs/federated/docs/source_en/images/mindspore_federated_networking.png deleted file mode 100644 index 4340cb66b638e072ffdb11167743cc45c36a9536..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/mindspore_federated_networking.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/splitnn_pangu_alpha_en.png b/docs/federated/docs/source_en/images/splitnn_pangu_alpha_en.png deleted file mode 100644 index 6a423ab3dbe0d43e858d48bd3f6de7ad92c3f040..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/splitnn_pangu_alpha_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/splitnn_wide_and_deep_en.png b/docs/federated/docs/source_en/images/splitnn_wide_and_deep_en.png deleted file mode 100644 index 0a281a046fbb44c59be02a31ee0b15092cc041ec..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/splitnn_wide_and_deep_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/start_android_project.png b/docs/federated/docs/source_en/images/start_android_project.png deleted file mode 100644 index 3a9336add10acbbef60dc429b8a3bad1ca198c38..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/start_android_project.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/upload_compress_server_en.png b/docs/federated/docs/source_en/images/upload_compress_server_en.png deleted file mode 100644 index 02e554d3c3977c7f041e6c68fc434a794c2e8a8a..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/upload_compress_server_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/upload_compression_client_en.png b/docs/federated/docs/source_en/images/upload_compression_client_en.png deleted file mode 100644 index 8690363d2c4e29dfbe19927c4e65d2e80bf82e79..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/upload_compression_client_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/vfl_1_en.png b/docs/federated/docs/source_en/images/vfl_1_en.png deleted file mode 100644 index 2cefd4c529b4df0dc5f6a36c667f02fcaa608fdc..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/vfl_1_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/vfl_backward_en.png b/docs/federated/docs/source_en/images/vfl_backward_en.png deleted file mode 100644 index 867a4698dac1cac5a4233f20635074669747e96a..0000000000000000000000000000000000000000 Binary 
files a/docs/federated/docs/source_en/images/vfl_backward_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/vfl_feature_reconstruction_defense_en.png b/docs/federated/docs/source_en/images/vfl_feature_reconstruction_defense_en.png deleted file mode 100644 index 608017209c42d910796a461337271855eebaad80..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/vfl_feature_reconstruction_defense_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/vfl_feature_reconstruction_en.png b/docs/federated/docs/source_en/images/vfl_feature_reconstruction_en.png deleted file mode 100644 index 985eadb43d776efac54127e17e92dafd76106321..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/vfl_feature_reconstruction_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/vfl_forward_en.png b/docs/federated/docs/source_en/images/vfl_forward_en.png deleted file mode 100644 index 37e6fbec1ff5b7f0c85bdc4fb86e0a7957b6ca54..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/vfl_forward_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/vfl_normal_communication_compress_en.png b/docs/federated/docs/source_en/images/vfl_normal_communication_compress_en.png deleted file mode 100644 index 38bd38616ff55da6ae4439fd8f0ca1951990eb78..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/vfl_normal_communication_compress_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/vfl_pangu_communication_compress_en.png b/docs/federated/docs/source_en/images/vfl_pangu_communication_compress_en.png deleted file mode 100644 index a9c10552966f3f8d1247665347bbc11f72287588..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/vfl_pangu_communication_compress_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/vfl_with_tee_en.png b/docs/federated/docs/source_en/images/vfl_with_tee_en.png deleted file mode 100644 index 78732f01117e0e5e6e1f358e94563ae3f2ce44b5..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/vfl_with_tee_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/weight_diff_decode_en.png b/docs/federated/docs/source_en/images/weight_diff_decode_en.png deleted file mode 100644 index c0a48f84a02a828b3a7b9e4c1206502a8786bdd5..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/weight_diff_decode_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/images/weight_diff_encode_en.png b/docs/federated/docs/source_en/images/weight_diff_encode_en.png deleted file mode 100644 index 54df35c2ac7ac2169e6f91e25bd17c797e2586a0..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_en/images/weight_diff_encode_en.png and /dev/null differ diff --git a/docs/federated/docs/source_en/index.rst b/docs/federated/docs/source_en/index.rst deleted file mode 100644 index 510c30a9b9a46c6dd24033eb7b07616e44ff20f1..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_en/index.rst +++ /dev/null @@ -1,177 +0,0 @@ -.. MindSpore documentation master file, created by - sphinx-quickstart on Thu Mar 24 11:00:00 2020. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. 
-
-MindSpore Federated Documents
-================================
-
-MindSpore Federated is an open source federated learning tool for MindSpore, and it enables all-scenario intelligent applications while user data stays stored locally.
-
-Federated learning is an encrypted distributed machine learning technique for breaking down data silos and performing efficient, secure, and reliable machine learning across multiple parties or computing nodes. It enables the participants in machine learning to build AI models together without directly sharing their local data, covering mainstream deep learning models such as those for ad recommendation, classification, and detection, and it is mainly applied in the finance, medical, and recommendation fields.
-
-MindSpore Federated provides horizontal federated learning, in which samples are federated, and vertical federated learning, in which features are federated. It supports commercial deployment on millions of stateless end devices, as well as cloud federation between data centers across trust zones.
-
-Code repository address: 
-
-Advantages of the MindSpore Federated Horizontal Framework
------------------------------------------------------------
-
-Horizontal Federated Architecture:
-
-.. raw:: html
-
-
-1. Privacy Protection
-
-   It supports an accuracy-lossless secure aggregation solution based on secure multi-party computation (MPC) to prevent model theft.
-
-   It supports performance-lossless encryption based on local differential privacy to prevent private data leakage from models.
-
-   It supports a gradient protection scheme based on sign-based dimension selection (SignDS), which prevents leakage of model privacy data while reducing communication overhead by 99%.
-
-2. Distributed Federated Aggregation
-
-   The loosely coupled cluster processing mode on the cloud and the distributed secondary gradient aggregation paradigm support the deployment of tens of millions of heterogeneous devices, implement high-performance and high-availability federated aggregation computing, and can cope with network instability and sudden load changes.
-
-3. Federated Learning Efficiency Improvement
-
-   An adaptive frequency-adjustment strategy and a gradient compression algorithm are supported to improve federated learning efficiency and save bandwidth resources.
-
-   Multiple federated aggregation policies are supported to improve the smoothness of federated learning convergence and optimize both global and local accuracies.
-
-4. Easy to Use
-
-   Only one line of code is required to switch between the standalone training and federated learning modes.
-
-   The network models, aggregation algorithms, and security algorithms are programmable, and the security level can be customized.
-
-   It supports the effectiveness evaluation of federated training models and provides monitoring capabilities for federated tasks.
-
-Advantages of the MindSpore Federated Vertical Framework
------------------------------------------------------------
-
-Vertical Federated Architecture:
-
-.. raw:: html
-
-
-1. Privacy Protection
-
-   Supports a high-performance private set intersection (PSI) protocol, which prevents federated participants from obtaining ID information outside the intersection and can cope with data imbalance scenarios.
-
-   Supports a software feature protection solution that combines quantization and differential privacy, to prevent attackers from reconstructing original private data from intermediate features.
-
-   Supports a hardware feature protection solution based on a trusted execution environment, providing high-strength and efficient feature protection.
-
-   Supports a label protection solution based on differential privacy, to prevent the leakage of user label data.
-
-2. Federated Training
-
-   Supports multiple types of split learning network structures.
-
-   Supports cross-domain training of large models with pipelined parallel optimization.
-
-MindSpore Federated Working Process
-------------------------------------
-
-1. `Scenario Identification and Data Accumulation `_
-
-   Identify scenarios where federated learning is used and accumulate local data for federated tasks on the client.
-
-2. `Model Selection and Framework Deployment `_
-
-   Select or develop a model prototype and use a tool to generate a federated learning model that is easy to deploy.
-
-3. `Application Deployment `_
-
-   Deploy the corresponding components to the business application and set up federated configuration tasks and deployment scripts on the server.
-
-Common Application Scenarios
-----------------------------
-
-1. `Image Classification `_
-
-   Use federated learning to implement image classification applications.
-
-2. `Text Classification `_
-
-   Use federated learning to implement text classification applications.
-
-.. toctree::
-   :maxdepth: 1
-   :caption: Deployment
-
-   federated_install
-   deploy_federated_server
-   deploy_federated_client
-   deploy_vfl
-
-.. toctree::
-   :maxdepth: 1
-   :caption: Horizontal Application
-
-   image_classfication_dataset_process
-   image_classification_application
-   sentiment_classification_application
-   image_classification_application_in_cross_silo
-   object_detection_application_in_cross_silo
-
-.. toctree::
-   :maxdepth: 1
-   :caption: Vertical Application
-
-   data_join
-   split_wnd_application
-   split_pangu_alpha_application
-
-.. toctree::
-   :maxdepth: 1
-   :caption: Security and Privacy
-
-   local_differential_privacy_training_noise
-   local_differential_privacy_training_signds
-   local_differential_privacy_eval_laplace
-   pairwise_encryption_training
-   private_set_intersection
-   secure_vertical_federated_learning_with_EmbeddingDP
-   secure_vertical_federated_learning_with_TEE
-   secure_vertical_federated_learning_with_DP
-
-.. toctree::
-   :maxdepth: 1
-   :caption: Communication Compression
-
-   communication_compression
-   vfl_communication_compress
-
-.. toctree::
-   :maxdepth: 1
-   :caption: Horizontal Federated API Reference
-
-   horizontal_server
-   cross_device
-   horizontal/cross_silo
-
-.. toctree::
-   :maxdepth: 1
-   :caption: Vertical Federated API Reference
-
-   Data_Join
-   vertical/vertical_communicator
-   vertical_federated_trainer
-
-.. toctree::
-   :maxdepth: 1
-   :caption: References
-
-   faq
-
-.. toctree::
-   :glob:
-   :maxdepth: 1
-   :caption: RELEASE NOTES
-
-   RELEASE
diff --git a/docs/federated/docs/source_en/interface_description_federated_client.md b/docs/federated/docs/source_en/interface_description_federated_client.md
deleted file mode 100644
index abc50a417e61d83886203b57c6b04c5cdc319ea8..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/interface_description_federated_client.md
+++ /dev/null
@@ -1,350 +0,0 @@
-# Examples
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/interface_description_federated_client.md)
-
-Before using the following interfaces, refer to the document [on-device deployment](https://www.mindspore.cn/federated/docs/en/master/deploy_federated_client.html) to deploy the related environment.
-
-## flJobRun() for Starting Federated Learning
-
-Before calling the flJobRun() API, instantiate the parameter class FLParameter and set the related parameters as follows:
-
-| Parameter | Type | Mandatory | Description | Remarks |
-| -------------------- | ---------------------------- | --------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
-| dataMap | Map\<RunType, List\<String\>\> | Y | The paths of the federated learning datasets. | A dataset map of type Map\<RunType, List\<String\>\>: the key is the RunType enumeration type and the value is the corresponding list of dataset files. When the key is RunType.TRAINMODE, the value is the list of training-related datasets; when the key is RunType.EVALMODE, the value is the list of evaluation-related datasets; and when the key is RunType.INFERMODE, the value is the list of inference-related datasets. |
-| flName | String | Y | The package path of the model script used by federated learning. | We provide two types of model scripts for your reference ([Supervised sentiment classification task](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert)), ([LeNet image classification task](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet)). For supervised sentiment classification tasks, this parameter can be set to the package path of the provided script file [AlBertClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/AlbertClient.java), such as `com.mindspore.flclient.demo.albert.AlbertClient`; for LeNet image classification tasks, this parameter can be set to the package path of the provided script file [LenetClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java), such as `com.mindspore.flclient.demo.lenet.LenetClient`. Users can also refer to these two model scripts, define a model script by themselves, and then set this parameter to the package path of the customized model file ModelClient.java (which needs to inherit from the class [Client.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/model/Client.java)). |
-| trainModelPath | String | Y | Path of the training model used for federated learning, which is an absolute path of the .ms file. | It is recommended to set the path to the training App's own directory to protect the data access security of the model itself. |
-| inferModelPath | String | Y | Path of the inference model used for federated learning, which is an absolute path of the .ms file. | For the normal federated learning mode (training and inference use the same model), the value of this parameter needs to be the same as that of `trainModelPath`; for the hybrid learning mode (training and inference use different models, and the server side also includes the training process), this parameter is set to the path of the actual inference model. It is recommended to set the path to the training App's own directory to protect the data access security of the model itself. |
-| sslProtocol | String | N | The TLS protocol version used by the device-cloud HTTPS communication. | A whitelist is set, and currently only "TLSv1.3" or "TLSv1.2" is supported. It only needs to be set in the HTTPS communication scenario. |
-| deployEnv | String | Y | The deployment environment for federated learning. | A whitelist is set; currently only "x86" and "android" are supported. |
-| certPath | String | N | The self-signed root certificate path used for device-cloud HTTPS communication. | When the deployment environment is "x86" and the device-cloud uses a self-signed certificate for HTTPS communication authentication, this parameter needs to be set. The certificate must be consistent with the CA root certificate used to generate the cloud-side self-signed certificate to pass the verification. This parameter is used for non-Android scenarios. |
-| domainName | String | Y | The URL for device-cloud communication. | Currently, HTTPS and HTTP communication are supported, with the corresponding formats https://...... and http://......; when `useElb` is set to true, the format must be https://127.0.0.0:6666 or http://127.0.0.0:6666, where `127.0.0.0` corresponds to the IP of the machine providing cloud-side services (corresponding to the cloud-side parameter `--scheduler_ip`), and `6666` corresponds to the cloud-side parameter `--fl_server_port`. |
-| ifUseElb | boolean | N | Used in multi-server scenarios to set whether to randomly send client requests to different servers within a certain range. | Setting it to true means that the client will randomly send requests to a certain range of server addresses, and false means that the client's requests will be sent to a fixed server address. This parameter is used in non-Android scenarios, and the default value is false. |
-| serverNum | int | N | The number of servers that the client can choose to connect to. | When `ifUseElb` is set to true, it can be set to be consistent with the `server_num` parameter used when the server is started on the cloud side. It is used to randomly select different servers to send information. This parameter is used in non-Android scenarios. The default value is 1. |
-| ifPkiVerify | boolean | N | The switch of device-cloud identity authentication. | Set it to true to enable device-cloud security authentication, and to false to disable it; the default value is false. Identity authentication requires HUKS to provide a certificate. This parameter is only used in the Android environment (currently only HUAWEI phones are supported). |
-| threadNum | int | N | The number of threads used in federated learning training and inference. | The default value is 1. |
-| cpuBindMode | BindMode | N | The CPU core that threads need to bind to during federated learning training and inference. | It is the enumeration type `BindMode`, where BindMode.NOT_BINDING_CORE represents no bound core (cores are automatically assigned by the system), BindMode.BIND_LARGE_CORE represents binding to a large core, and BindMode.BIND_MIDDLE_CORE represents binding to a middle core. The default value is BindMode.NOT_BINDING_CORE. |
-| batchSize | int | Y | The number of single-step training samples used in federated learning training and inference, that is, the batch size. | It needs to be consistent with the batch size of the input data of the model. |
-| iflJobResultCallback | IFLJobResultCallback | N | The federated learning callback function object `iflJobResultCallback`. | The user can implement the specific methods of the interface class [IFLJobResultCallback.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/IFLJobResultCallback.java) in the project according to the needs of the actual scene, and set it as the callback function object of the federated learning task. We provide a simple implementation use case [FLJobResultCallback.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/IFLJobResultCallback.java) as the default value of this parameter. |
-
-Note 1: When using HTTP communication, there may be communication security risks; please be aware of this.
-
-Note 2: In the Android environment, the following parameters need to be set when using HTTPS communication. The setting examples are as follows:
-
-```java
-FLParameter flParameter = FLParameter.getInstance();
-SecureSSLSocketFactory sslSocketFactory = SecureSSLSocketFactory.getInstance(applicationContext);
-SecureX509TrustManager x509TrustManager = new SecureX509TrustManager(applicationContext);
-flParameter.setSslSocketFactory(sslSocketFactory);
-flParameter.setX509TrustManager(x509TrustManager);
-```
-
-The two classes `SecureSSLSocketFactory` and `SecureX509TrustManager` need to be implemented in the Android project, and users need to design them according to the type of certificate on the phone.
-
-Note 3: In the x86 environment, currently only self-signed certificate authentication is supported when using HTTPS communication, and the following parameter needs to be set. The setting example is as follows:
-
-```java
-FLParameter flParameter = FLParameter.getInstance();
-String certPath = "CARoot.pem"; // the self-signed root certificate path used for device-cloud HTTPS communication
-flParameter.setCertPath(certPath);
-```
-
-Note 4: In the Android environment, when `pkiVerify` is set to true and `encrypt_train_type` is set to PW_ENCRYPT on the cloud side, the following parameters need to be set. The setting examples are as follows:
-
-```java
-FLParameter flParameter = FLParameter.getInstance();
-String equipCrlPath = certPath;
-long validIterInterval = 3600000;
-flParameter.setEquipCrlPath(equipCrlPath);
-flParameter.setValidInterval(validIterInterval);
-```
-
-Here, `equipCrlPath` is the CRL certificate required for certificate verification among devices, that is, the certificate revocation list.
Generally, the device certificate CRL in "Huawei CBG Certificate Revocation Lists" can be preset. `validIterInterval`, which helps prevent replay attacks in PW_ENCRYPT mode, can generally be set to the time required for each round of device-cloud aggregation (unit: milliseconds; the default value is 3600000).
-
-Note 5: Before each federated learning task is started, the FLParameter class is instantiated for the related parameter settings. When FLParameter is instantiated, a clientID is automatically generated at random and is used to uniquely identify the client during the interaction with the cloud side. If users need to set the clientID themselves, they can call the setClientID method after instantiating the FLParameter class; the federated learning task started afterwards then uses the clientID set by the user.
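-
-As a minimal illustrative sketch of the customization described in Note 5 (the ID string below is a hypothetical placeholder, not a required value):
-
-```java
-FLParameter flParameter = FLParameter.getInstance();
-// "worker_client_001" is a hypothetical user-defined unique ID;
-// call setClientID before starting the federated learning task.
-flParameter.setClientID("worker_client_001");
-```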
-
-Create a SyncFLJob object and use the flJobRun() method of the SyncFLJob class to start a federated learning task.
-
-The sample code (basic HTTP communication) is as follows:
-
-1. Sample code of a supervised sentiment classification task
-
-   ```java
-   // create dataMap
-   String trainTxtPath = "data/albert/supervise/client/1.txt";
-   String evalTxtPath = "data/albert/supervise/eval/eval.txt"; // Not necessary; if you don't need to verify model accuracy after getModel, you don't need to set this parameter
-   String vocabFile = "data/albert/supervise/vocab.txt"; // Path of the dictionary file for data preprocessing.
-   String idsFile = "data/albert/supervise/vocab_map_ids.txt"; // Path of the mapping ID file of a dictionary.
-   Map<RunType, List<String>> dataMap = new HashMap<>();
-   List<String> trainPath = new ArrayList<>();
-   trainPath.add(trainTxtPath);
-   trainPath.add(vocabFile);
-   trainPath.add(idsFile);
-   List<String> evalPath = new ArrayList<>(); // Not necessary; if you don't need to verify model accuracy after getModel, you don't need to set this parameter
-   evalPath.add(evalTxtPath); // Not necessary
-   evalPath.add(vocabFile); // Not necessary
-   evalPath.add(idsFile); // Not necessary
-   dataMap.put(RunType.TRAINMODE, trainPath);
-   dataMap.put(RunType.EVALMODE, evalPath); // Not necessary; if you don't need to verify model accuracy after getModel, you don't need to set this parameter
-
-   String flName = "com.mindspore.flclient.demo.albert.AlbertClient"; // The package path of AlBertClient.java
-   String trainModelPath = "ms/albert/train/albert_ad_train.mindir0.ms"; // Absolute path
-   String inferModelPath = "ms/albert/train/albert_ad_train.mindir0.ms"; // Absolute path, consistent with trainModelPath
-   String sslProtocol = "TLSv1.2";
-   String deployEnv = "android";
-   String domainName = "http://10.*.*.*:6668";
-   boolean ifUseElb = true;
-   int serverNum = 4;
-   int threadNum = 4;
-   BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
-   int batchSize = 32;
-
-   FLParameter flParameter = FLParameter.getInstance();
-   flParameter.setFlName(flName);
-   flParameter.setDataMap(dataMap);
-   flParameter.setTrainModelPath(trainModelPath);
-   flParameter.setInferModelPath(inferModelPath);
-   flParameter.setSslProtocol(sslProtocol);
-   flParameter.setDeployEnv(deployEnv);
-   flParameter.setDomainName(domainName);
-   flParameter.setUseElb(ifUseElb);
-   flParameter.setServerNum(serverNum);
-   flParameter.setThreadNum(threadNum);
-   flParameter.setCpuBindMode(cpuBindMode); // cpuBindMode is already a BindMode value, so it is passed directly
-   flParameter.setBatchSize(batchSize); // batchSize is declared above and must also be passed to flParameter
-
-   // start FLJob
-   SyncFLJob syncFLJob = new SyncFLJob();
-   syncFLJob.flJobRun();
-   ```
-
-2. Sample code of a LeNet image classification task
-
-   ```java
-   // create dataMap
-   String trainImagePath = "SyncFLClient/data/3500_clients_bin/f0178_39/f0178_39_bn_9_train_data.bin";
-   String trainLabelPath = "SyncFLClient/data/3500_clients_bin/f0178_39/f0178_39_bn_9_train_label.bin";
-   String evalImagePath = "SyncFLClient/data/3500_clients_bin/f0178_39/f0178_39_bn_1_test_data.bin"; // Not necessary; if you don't need to verify model accuracy after getModel, you don't need to set this parameter
-   String evalLabelPath = "SyncFLClient/data/3500_clients_bin/f0178_39/f0178_39_bn_1_test_label.bin"; // Not necessary
-   Map<RunType, List<String>> dataMap = new HashMap<>();
-   List<String> trainPath = new ArrayList<>();
-   trainPath.add(trainImagePath);
-   trainPath.add(trainLabelPath);
-   List<String> evalPath = new ArrayList<>(); // Not necessary
-   evalPath.add(evalImagePath); // Not necessary
-   evalPath.add(evalLabelPath); // Not necessary
-   dataMap.put(RunType.TRAINMODE, trainPath);
-   dataMap.put(RunType.EVALMODE, evalPath); // Not necessary; if you don't need to verify model accuracy after getModel, you don't need to set this parameter
-
-   String flName = "com.mindspore.flclient.demo.lenet.LenetClient"; // The package path of LenetClient.java
-   String trainModelPath = "SyncFLClient/lenet_train.mindir0.ms"; // Absolute path
-   String inferModelPath = "SyncFLClient/lenet_train.mindir0.ms"; // Absolute path, consistent with trainModelPath
-   String sslProtocol = "TLSv1.2";
-   String deployEnv = "android";
-   String domainName = "http://10.*.*.*:6668";
-   boolean ifUseElb = true;
-   int serverNum = 4;
-   int threadNum = 4;
-   BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
-   int batchSize = 32;
-
-   FLParameter flParameter = FLParameter.getInstance();
-   flParameter.setFlName(flName);
-   flParameter.setDataMap(dataMap);
-   flParameter.setTrainModelPath(trainModelPath);
-   flParameter.setInferModelPath(inferModelPath);
-   flParameter.setSslProtocol(sslProtocol);
-   flParameter.setDeployEnv(deployEnv);
-   flParameter.setDomainName(domainName);
-   flParameter.setUseElb(ifUseElb);
-   flParameter.setServerNum(serverNum);
-   flParameter.setThreadNum(threadNum);
-   flParameter.setCpuBindMode(cpuBindMode); // cpuBindMode is already a BindMode value, so it is passed directly
-   flParameter.setBatchSize(batchSize);
-
-   // start FLJob
-   SyncFLJob syncFLJob = new SyncFLJob();
-   syncFLJob.flJobRun();
-   ```
-
-## modelInference() for Inferring Multiple Input Data Records
-
-Before calling the modelInference() API, instantiate the parameter class FLParameter and set the related parameters as follows:
-
-| Parameter | Type | Mandatory | Description | Remarks |
-| -------------- | ---------------------------- | --------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
-| flName | String | Y | The package path of the model script used by federated learning. | We provide two types of model scripts for your reference ([Supervised sentiment classification task](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert), [LeNet image classification task](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet)). For supervised sentiment classification tasks, this parameter can be set to the package path of the provided script file [AlBertClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/AlbertClient.java), such as `com.mindspore.flclient.demo.albert.AlbertClient`; for LeNet image classification tasks, this parameter can be set to the package path of the provided script file [LenetClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java), such as `com.mindspore.flclient.demo.lenet.LenetClient`. Users can also refer to these two model scripts, define a model script by themselves, and then set this parameter to the package path of the customized model file ModelClient.java (which needs to inherit from the class [Client.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/model/Client.java)). |
-| dataMap | Map\<RunType, List\<String\>\> | Y | The paths of the federated learning datasets. | A dataset map of type Map\<RunType, List\<String\>\>: the key is the RunType enumeration type and the value is the corresponding list of dataset files. When the key is RunType.TRAINMODE, the value is the list of training-related datasets; when the key is RunType.EVALMODE, the value is the list of evaluation-related datasets; and when the key is RunType.INFERMODE, the value is the list of inference-related datasets. |
-| inferModelPath | String | Y | Path of the inference model used for federated learning, which is an absolute path of the .ms file. | For the normal federated learning mode (training and inference use the same model), the value of this parameter needs to be the same as that of `trainModelPath`; for the hybrid learning mode (training and inference use different models, and the server side also includes the training process), this parameter is set to the path of the actual inference model. It is recommended to set the path to the training App's own directory to protect the data access security of the model itself. |
-| threadNum | int | N | The number of threads used in federated learning training and inference. | The default value is 1. |
-| cpuBindMode | BindMode | N | The CPU core that threads need to bind to during federated learning training and inference. | It is the enumeration type `BindMode`, where BindMode.NOT_BINDING_CORE represents no bound core (cores are automatically assigned by the system), BindMode.BIND_LARGE_CORE represents binding to a large core, and BindMode.BIND_MIDDLE_CORE represents binding to a middle core. The default value is BindMode.NOT_BINDING_CORE. |
-| batchSize | int | Y | The number of single-step training samples used in federated learning training and inference, that is, the batch size. | It needs to be consistent with the batch size of the input data of the model. |
-
-Create a SyncFLJob object and use the modelInference() method of the SyncFLJob class to start an inference task on the device. The inferred label array is returned.
-
-The sample code is as follows:
-
-1. Sample code of a supervised sentiment classification task
-
-   ```java
-   // create dataMap
-   String inferTxtPath = "data/albert/supervise/eval/eval.txt";
-   String vocabFile = "data/albert/supervise/vocab.txt";
-   String idsFile = "data/albert/supervise/vocab_map_ids.txt";
-   Map<RunType, List<String>> dataMap = new HashMap<>();
-   List<String> inferPath = new ArrayList<>();
-   inferPath.add(inferTxtPath);
-   inferPath.add(vocabFile);
-   inferPath.add(idsFile);
-   dataMap.put(RunType.INFERMODE, inferPath);
-
-   String flName = "com.mindspore.flclient.demo.albert.AlbertClient"; // The package path of AlBertClient.java
-   String inferModelPath = "ms/albert/train/albert_ad_train.mindir0.ms"; // Absolute path, consistent with trainModelPath
-   int threadNum = 4;
-   BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
-   int batchSize = 32;
-
-   FLParameter flParameter = FLParameter.getInstance();
-   flParameter.setFlName(flName);
-   flParameter.setDataMap(dataMap);
-   flParameter.setInferModelPath(inferModelPath);
-   flParameter.setThreadNum(threadNum);
-   flParameter.setCpuBindMode(cpuBindMode); // cpuBindMode is already a BindMode value, so it is passed directly
-   flParameter.setBatchSize(batchSize);
-
-   // inference
-   SyncFLJob syncFLJob = new SyncFLJob();
-   int[] labels = syncFLJob.modelInference();
-   ```
-
-2. Sample code of a LeNet image classification task
-
-   ```java
-   // create dataMap
-   String inferImagePath = "SyncFLClient/data/3500_clients_bin/f0178_39/f0178_39_bn_1_test_data.bin";
-   String inferLabelPath = "SyncFLClient/data/3500_clients_bin/f0178_39/f0178_39_bn_1_test_label.bin";
-   Map<RunType, List<String>> dataMap = new HashMap<>();
-   List<String> inferPath = new ArrayList<>();
-   inferPath.add(inferImagePath);
-   inferPath.add(inferLabelPath);
-   dataMap.put(RunType.INFERMODE, inferPath);
-
-   String flName = "com.mindspore.flclient.demo.lenet.LenetClient"; // The package path of LenetClient.java
-   String inferModelPath = "SyncFLClient/lenet_train.mindir0.ms"; // Absolute path, consistent with trainModelPath
-   int threadNum = 4;
-   BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
-   int batchSize = 32;
-
-   FLParameter flParameter = FLParameter.getInstance();
-   flParameter.setFlName(flName);
-   flParameter.setDataMap(dataMap);
-   flParameter.setInferModelPath(inferModelPath);
-   flParameter.setThreadNum(threadNum);
-   flParameter.setCpuBindMode(cpuBindMode); // cpuBindMode is already a BindMode value, so it is passed directly
-   flParameter.setBatchSize(batchSize);
-
-   // inference
-   SyncFLJob syncFLJob = new SyncFLJob();
-   int[] labels = syncFLJob.modelInference();
-   ```
-
-## getModel() for Obtaining the Latest Model on the Cloud
-
-Before calling the getModel() API, instantiate the parameter class FLParameter and set the related parameters as follows:
-
-| Parameter | Type | Mandatory | Description | Remarks |
-| -------------- | --------- | --------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
-| flName | String | Y | The package path of the model script used by federated learning. | We provide two types of model scripts for your reference ([Supervised sentiment classification task](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert), [LeNet image classification task](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet)). For supervised sentiment classification tasks, this parameter can be set to the package path of the provided script file [AlBertClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/AlbertClient.java), such as `com.mindspore.flclient.demo.albert.AlbertClient`; for LeNet image classification tasks, this parameter can be set to the package path of the provided script file [LenetClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java), such as `com.mindspore.flclient.demo.lenet.LenetClient`. Users can also refer to these two model scripts, define a model script by themselves, and then set this parameter to the package path of the customized model file ModelClient.java (which needs to inherit from the class [Client.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/model/Client.java)). |
-| trainModelPath | String | Y | Path of the training model used for federated learning, which is an absolute path of the .ms file. | It is recommended to set the path to the training App's own directory to protect the data access security of the model itself. |
-| inferModelPath | String | Y | Path of the inference model used for federated learning, which is an absolute path of the .ms file. | For the normal federated learning mode (training and inference use the same model), the value of this parameter needs to be the same as that of `trainModelPath`; for the hybrid learning mode (training and inference use different models, and the server side also includes the training process), this parameter is set to the path of the actual inference model. It is recommended to set the path to the training App's own directory to protect the data access security of the model itself. |
-| sslProtocol | String | N | The TLS protocol version used by the device-cloud HTTPS communication. | A whitelist is set, and currently only "TLSv1.3" or "TLSv1.2" is supported. It only needs to be set in the HTTPS communication scenario. |
-| deployEnv | String | Y | The deployment environment for federated learning. | A whitelist is set; currently only "x86" and "android" are supported. |
-| certPath | String | N | The self-signed root certificate path used for device-cloud HTTPS communication. | When the deployment environment is "x86" and the device-cloud uses a self-signed certificate for HTTPS communication authentication, this parameter needs to be set. The certificate must be consistent with the CA root certificate used to generate the cloud-side self-signed certificate to pass the verification. This parameter is used for non-Android scenarios. |
-| domainName | String | Y | The URL for device-cloud communication. | Currently, HTTPS and HTTP communication are supported, with the corresponding formats https://...... and http://......; when `useElb` is set to true, the format must be https://127.0.0.0:6666 or http://127.0.0.0:6666, where `127.0.0.0` corresponds to the IP of the machine providing cloud-side services (corresponding to the cloud-side parameter `--scheduler_ip`), and `6666` corresponds to the cloud-side parameter `--fl_server_port`. |
-| ifUseElb | boolean | N | Used in multi-server scenarios to set whether to randomly send client requests to different servers within a certain range. | Setting it to true means that the client will randomly send requests to a certain range of server addresses, and false means that the client's requests will be sent to a fixed server address. This parameter is used in non-Android scenarios, and the default value is false. |
-| serverNum | int | N | The number of servers that the client can choose to connect to. | When `ifUseElb` is set to true, it can be set to be consistent with the `server_num` parameter used when the server is started on the cloud side. It is used to randomly select different servers to send information. This parameter is used in non-Android scenarios. The default value is 1. |
-| serverMod | ServerMod | Y | The federated learning training mode. | The federated learning training mode of the ServerMod enumeration type, where ServerMod.FEDERATED_LEARNING represents the normal federated learning mode (training and inference use the same model), and ServerMod.HYBRID_TRAINING represents the hybrid learning mode (training and inference use different models, and the server side also includes the training process). |
-
-Note 1: When using HTTP communication, there may be communication security risks; please be aware of this.
-
-Note 2: In the Android environment, the following parameters need to be set when using HTTPS communication. The setting examples are as follows:
-
-```java
-FLParameter flParameter = FLParameter.getInstance();
-SecureSSLSocketFactory sslSocketFactory = SecureSSLSocketFactory.getInstance(applicationContext);
-SecureX509TrustManager x509TrustManager = new SecureX509TrustManager(applicationContext);
-flParameter.setSslSocketFactory(sslSocketFactory);
-flParameter.setX509TrustManager(x509TrustManager);
-```
-
-The two classes `SecureSSLSocketFactory` and `SecureX509TrustManager` need to be implemented in the Android project, and users need to design them according to the type of certificate on the phone.
-
-Note 3: In the x86 environment, currently only self-signed certificate authentication is supported when using HTTPS communication, and the following parameter needs to be set. The setting example is as follows:
-
-```java
-FLParameter flParameter = FLParameter.getInstance();
-String certPath = "CARoot.pem"; // the self-signed root certificate path used for device-cloud HTTPS communication
-flParameter.setCertPath(certPath);
-```
-
-Note 4: Before calling the getModel method, the FLParameter class is instantiated for the related parameter settings. When FLParameter is instantiated, a clientID is automatically generated at random and is used to uniquely identify the client during the interaction with the cloud side. If users need to set the clientID themselves, they can call the setClientID method after instantiating the FLParameter class; the getModel task started afterwards then uses the clientID set by the user.
-
-Create a SyncFLJob object and use the getModel() method of the SyncFLJob class to start a task that obtains the latest model on the cloud. The status code of the getModel request is returned.
-
-The sample code is as follows:
-
-1. Supervised sentiment classification task
-
-   ```java
-   String flName = "com.mindspore.flclient.demo.albert.AlbertClient"; // The package path of AlBertClient.java
-   String trainModelPath = "ms/albert/train/albert_ad_train.mindir0.ms"; // Absolute path
-   String inferModelPath = "ms/albert/train/albert_ad_train.mindir0.ms"; // Absolute path, consistent with trainModelPath
-   String sslProtocol = "TLSv1.2";
-   String deployEnv = "android";
-   String domainName = "http://10.*.*.*:6668";
-   boolean ifUseElb = true;
-   int serverNum = 4;
-   ServerMod serverMod = ServerMod.FEDERATED_LEARNING;
-
-   FLParameter flParameter = FLParameter.getInstance();
-   flParameter.setFlName(flName);
-   flParameter.setTrainModelPath(trainModelPath);
-   flParameter.setInferModelPath(inferModelPath);
-   flParameter.setSslProtocol(sslProtocol);
-   flParameter.setDeployEnv(deployEnv);
-   flParameter.setDomainName(domainName);
-   flParameter.setUseElb(ifUseElb);
-   flParameter.setServerNum(serverNum);
-   flParameter.setServerMod(serverMod); // serverMod is already a ServerMod value, so it is passed directly
-
-   // getModel
-   SyncFLJob syncFLJob = new SyncFLJob();
-   syncFLJob.getModel();
-   ```
-
-2. LeNet image classification task
-
-   ```java
-   String flName = "com.mindspore.flclient.demo.lenet.LenetClient"; // The package path of LenetClient.java
-   String trainModelPath = "SyncFLClient/lenet_train.mindir0.ms"; // Absolute path
-   String inferModelPath = "SyncFLClient/lenet_train.mindir0.ms"; // Absolute path, consistent with trainModelPath
-   String sslProtocol = "TLSv1.2";
-   String deployEnv = "android";
-   String domainName = "http://10.*.*.*:6668";
-   boolean ifUseElb = true;
-   int serverNum = 4;
-   ServerMod serverMod = ServerMod.FEDERATED_LEARNING;
-
-   FLParameter flParameter = FLParameter.getInstance();
-   flParameter.setFlName(flName);
-   flParameter.setTrainModelPath(trainModelPath);
-   flParameter.setInferModelPath(inferModelPath);
-   flParameter.setSslProtocol(sslProtocol);
-   flParameter.setDeployEnv(deployEnv);
-   flParameter.setDomainName(domainName);
-   flParameter.setUseElb(ifUseElb);
-   flParameter.setServerNum(serverNum);
-   flParameter.setServerMod(serverMod); // serverMod is already a ServerMod value, so it is passed directly
-
-   // getModel
-   SyncFLJob syncFLJob = new SyncFLJob();
-   syncFLJob.getModel();
-   ```
diff --git a/docs/federated/docs/source_en/java_api_callback.md b/docs/federated/docs/source_en/java_api_callback.md
deleted file mode 100644
index b1fc6d1a83771313bc8f678d7b623b5f2969097f..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/java_api_callback.md
+++ /dev/null
@@ -1,66 +0,0 @@
-# Callback
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/java_api_callback.md)
-
-```java
-import com.mindspore.flclient.model.Callback
-```
-
-Callback defines the hook functions used to record the results of the training, evaluation, and prediction phases in device-side federated learning.
-
-## Public Member Functions
-
-| function |
-| -------------------------------- |
-| [abstract Status stepBegin()](#stepbegin) |
-| [abstract Status stepEnd()](#stepend) |
-| [abstract Status epochBegin()](#epochbegin) |
-| [abstract Status epochEnd()](#epochend) |
-
-## stepBegin
-
-```java
-public abstract Status stepBegin()
-```
-
-Execute the step begin function.
-
-- Returns
-
-  Whether the execution is successful.
-
-## stepEnd
-
-```java
-public abstract Status stepEnd()
-```
-
-Execute the step end function.
-
-- Returns
-
-  Whether the execution is successful.
-
-## epochBegin
-
-```java
-public abstract Status epochBegin()
-```
-
-Execute the epoch begin function.
-
-- Returns
-
-  Whether the execution is successful.
-
-## epochEnd
-
-```java
-public abstract Status epochEnd()
-```
-
-Execute the epoch end function.
-
-- Returns
-
-  Whether the execution is successful.
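-
-As an illustrative sketch only (not an official example), a custom callback that counts completed steps and epochs might look as follows. It assumes that `Callback` can be subclassed with a no-argument constructor and that the `Status` enumeration provides a `SUCCESS` value; depending on the version, the actual class may require additional constructor arguments:
-
-```java
-import com.mindspore.flclient.model.Callback;
-import com.mindspore.flclient.model.Status;
-
-public class CounterCallback extends Callback {
-    private int stepNum = 0;  // number of finished steps
-    private int epochNum = 0; // number of finished epochs
-
-    @Override
-    public Status stepBegin() {
-        // Nothing to prepare before a step in this sketch.
-        return Status.SUCCESS; // Status.SUCCESS is an assumed enum value
-    }
-
-    @Override
-    public Status stepEnd() {
-        stepNum++; // record that one step has finished
-        return Status.SUCCESS;
-    }
-
-    @Override
-    public Status epochBegin() {
-        return Status.SUCCESS;
-    }
-
-    @Override
-    public Status epochEnd() {
-        epochNum++; // record that one epoch has finished
-        return Status.SUCCESS;
-    }
-}
-```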
-
-- Returns
-
-    Whether the execution is successful.
-
-## epochBegin
-
-```java
-public abstract Status epochBegin()
-```
-
-Execute epoch begin function.
-
-- Returns
-
-    Whether the execution is successful.
-
-## epochEnd
-
-```java
-public abstract Status epochEnd()
-```
-
-Execute epoch end function.
-
-- Returns
-
-    Whether the execution is successful.
diff --git a/docs/federated/docs/source_en/java_api_client.md b/docs/federated/docs/source_en/java_api_client.md
deleted file mode 100644
index 92816018eafe3c27f6b0e72cdb75080b8be53c83..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/java_api_client.md
+++ /dev/null
@@ -1,173 +0,0 @@
-# Client
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/java_api_client.md)
-
-```java
-import com.mindspore.flclient.model.Client
-```
-
-Client defines the execution flow object of the device-side federated learning algorithm.
-
-## Public Member Functions
-
-| function |
-| -------------------------------- |
-| [abstract List initCallbacks(RunType runType, DataSet dataSet)](#initcallbacks) |
-| [abstract Map initDataSets(Map> files)](#initdatasets) |
-| [abstract float getEvalAccuracy(List evalCallbacks)](#getevalaccuracy) |
-| [abstract List getInferResult(List inferCallbacks)](#getinferresult) |
-| [Status trainModel(int epochs)](#trainmodel) |
-| [float evalModel()](#evalmodel) |
-| [Map genUnsupervisedEvalData(List evalCallbacks)](#genunsupervisedevaldata) |
-| [List inferModel()](#infermodel) |
-| [Status setLearningRate(float lr)](#setlearningrate) |
-| [void setBatchSize(int batchSize)](#setbatchsize) |
-
-## initCallbacks
-
-```java
-public abstract List initCallbacks(RunType runType, DataSet dataSet)
-```
-
-Initialize the callback list.
-
-- Parameters
-
-    - `runType`: RunType class, identifies whether the current phase is training, evaluation, or prediction.
-    - `dataSet`: DataSet class, the dataset used in the corresponding training, evaluation, or prediction phase.
-
-- Returns
-
-    The initialized callback list.
-
-## initDataSets
-
-```java
-public abstract Map initDataSets(Map> files)
-```
-
-Initialize the dataset list.
-
-- Parameters
-
-    - `files`: Data files used in the training, evaluation, or prediction phase.
-
-- Returns
-
-    Data counts for each run type.
-
-## getEvalAccuracy
-
-```java
-public abstract float getEvalAccuracy(List evalCallbacks)
-```
-
-Get the model accuracy in the eval phase.
-
-- Parameters
-
-    - `evalCallbacks`: Callback used in the eval phase.
-
-- Returns
-
-    The accuracy in the eval phase.
-
-## getInferResult
-
-```java
-public abstract List getInferResult(List inferCallbacks)
-```
-
-Get the results of the infer phase.
-
-- Parameters
-
-    - `inferCallbacks`: Callback used in the prediction phase.
-
-- Returns
-
-    The prediction results.
-
-## trainModel
-
-```java
-public Status trainModel(int epochs)
-```
-
-Execute the train model process.
-
-- Parameters
-
-    - `epochs`: Epoch num used in the train process.
-
-- Returns
-
-    Whether the model training is successful.
-
-## evalModel
-
-```java
-public float evalModel()
-```
-
-Execute the eval model process.
-
-- Returns
-
-    The accuracy in the eval process.
-
-## genUnsupervisedEvalData
-
-```java
-public Map genUnsupervisedEvalData(List evalCallbacks)
-```
-
-Generate unsupervised training evaluation data; subclasses need to override this function.
-
-- Parameters
-
-    - `evalCallbacks`: the eval Callback that generates data.
-
-- Returns
-
-    Unsupervised training evaluation data.
-
-## inferModel
-
-```java
-public List inferModel()
-```
-
-Execute the model prediction process.
-
-- Returns
-
-    The prediction result.
-
-## setLearningRate
-
-```java
-public Status setLearningRate(float lr)
-```
-
-Set the learning rate.
-
-- Parameters
-
-    - `lr`: Learning rate.
-
-- Returns
-
-    Whether the setting is successful.
-
-## setBatchSize
-
-```java
-public void setBatchSize(int batchSize)
-```
-
-Set the batch size.
-
-- Parameters
-
-    - `batchSize`: batch size.
diff --git a/docs/federated/docs/source_en/java_api_clientmanager.md b/docs/federated/docs/source_en/java_api_clientmanager.md
deleted file mode 100644
index 03d7ae2f519c476631ab90a58f19b1742a51fb8f..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/java_api_clientmanager.md
+++ /dev/null
@@ -1,44 +0,0 @@
-# ClientManager
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/java_api_clientmanager.md)
-
-```java
-import com.mindspore.flclient.model.ClientManager
-```
-
-ClientManager defines the management object for custom algorithm models in device-side federated learning.
-
-## Public Member Functions
-
-| function |
-| -------------------------------- |
-| [static void registerClient(Client client)](#registerclient) |
-| [static Client getClient(String name)](#getclient) |
-
-## registerClient
-
-```java
-public static void registerClient(Client client)
-```
-
-Register a client object.
-
-- Parameters
-
-    - `client`: The client object to register.
-
-## getClient
-
-```java
-public static Client getClient(String name)
-```
-
-Get a client object.
-
-- Parameters
-
-    - `name`: Client object name.
-
-- Returns
-
-    Client object.
diff --git a/docs/federated/docs/source_en/java_api_dataset.md b/docs/federated/docs/source_en/java_api_dataset.md
deleted file mode 100644
index 990bcbfff14fbc3779a78a12f0d65e42ba3e0ca2..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/java_api_dataset.md
+++ /dev/null
@@ -1,63 +0,0 @@
-# DataSet
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/java_api_dataset.md)
-
-```java
-import com.mindspore.flclient.model.DataSet
-```
-
-DataSet defines the dataset object of device-side federated learning.
-
-## Public Member Functions
-
-| function |
-| -------------------------------- |
-| [abstract void fillInputBuffer(List var1, int var2)](#fillinputbuffer) |
-| [abstract void shuffle()](#shuffle) |
-| [abstract void padding()](#padding) |
-| [abstract Status dataPreprocess(List var1)](#datapreprocess) |
-
-## fillInputBuffer
-
-```java
-public abstract void fillInputBuffer(List var1, int var2)
-```
-
-Fill input buffer data.
-
-- Parameters
-
-    - `var1`: The buffer to fill.
-    - `var2`: The index of the batch to fill.
-
-## shuffle
-
-```java
-public abstract void shuffle()
-```
-
-Shuffle data.
-
-## padding
-
-```java
-public abstract void padding()
-```
-
-Pad data.
-
-## dataPreprocess
-
-```java
-public abstract Status dataPreprocess(List var1)
-```
-
-Preprocess data.
-
-- Parameters
-
-    - `var1`: Data files.
-
-- Returns
-
-    Whether the execution is successful.
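-
-For illustration, a minimal sketch of a custom DataSet subclass is shown below. It assumes the raw-type signatures listed above and that Status provides SUCCESS/FAILED constants; the class name, fields, and logic are illustrative, not part of the framework.
-
-```java
-import com.mindspore.flclient.model.DataSet;
-import com.mindspore.flclient.model.Status;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-
-public class MyDataSet extends DataSet {
-    private final List samples = new ArrayList(); // filled by dataPreprocess
-
-    @Override
-    public void fillInputBuffer(List buffers, int batchIdx) {
-        // Copy the batchIdx-th batch of preprocessed samples into the model input buffers.
-    }
-
-    @Override
-    public void shuffle() {
-        Collections.shuffle(samples); // randomize the sample order between epochs
-    }
-
-    @Override
-    public void padding() {
-        // Pad the last batch so that every batch has the same size.
-    }
-
-    @Override
-    public Status dataPreprocess(List files) {
-        // Parse the raw data files into samples; return Status.FAILED on parse errors.
-        return Status.SUCCESS;
-    }
-}
-```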
diff --git a/docs/federated/docs/source_en/java_api_flparameter.md b/docs/federated/docs/source_en/java_api_flparameter.md
deleted file mode 100644
index fd9807412781e15a6ad854900edcb38a90660168..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/java_api_flparameter.md
+++ /dev/null
@@ -1,636 +0,0 @@
-# FLParameter
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/java_api_flparameter.md)
-
-```java
-import com.mindspore.flclient.FLParameter
-```
-
-FLParameter is used to define parameters related to federated learning.
-
-## Public Member Functions
-
-| **function** |
-| ------------------------------------------------------------ |
-| public static synchronized FLParameter getInstance() |
-| public String getDeployEnv() |
-| public void setDeployEnv(String env) |
-| public String getDomainName() |
-| public void setDomainName(String domainName) |
-| public String getClientID() |
-| public void setClientID(String clientID) |
-| public String getCertPath() |
-| public void setCertPath(String certPath) |
-| public SSLSocketFactory getSslSocketFactory() |
-| public void setSslSocketFactory(SSLSocketFactory sslSocketFactory) |
-| public X509TrustManager getX509TrustManager() |
-| public void setX509TrustManager(X509TrustManager x509TrustManager) |
-| public IFLJobResultCallback getIflJobResultCallback() |
-| public void setIflJobResultCallback(IFLJobResultCallback iflJobResultCallback) |
-| public String getFlName() |
-| public void setFlName(String flName) |
-| public String getTrainModelPath() |
-| public void setTrainModelPath(String trainModelPath) |
-| public String getInferModelPath() |
-| public void setInferModelPath(String inferModelPath) |
-| public String getSslProtocol() |
-| public void setSslProtocol(String sslProtocol) |
-| public int getTimeOut() |
-| public void setTimeOut(int timeOut) |
-| public int getSleepTime() |
-| public void setSleepTime(int sleepTime) |
-| public boolean isUseElb() |
-| public void setUseElb(boolean useElb) |
-| public int getServerNum() |
-| public void setServerNum(int serverNum) |
-| public boolean isPkiVerify() |
-| public void setPkiVerify(boolean ifPkiVerify) |
-| public String getEquipCrlPath() |
-| public void setEquipCrlPath(String certPath) |
-| public long getValidInterval() |
-| public void setValidInterval(long validInterval) |
-| public int getThreadNum() |
-| public void setThreadNum(int threadNum) |
-| public int getCpuBindMode() |
-| public void setCpuBindMode(BindMode cpuBindMode) |
-| public List getHybridWeightName(RunType runType) |
-| public void setHybridWeightName(List hybridWeightName, RunType runType) |
-| public Map> getDataMap() |
-| public void setDataMap(Map> dataMap) |
-| public ServerMod getServerMod() |
-| public void setServerMod(ServerMod serverMod) |
-| public int getBatchSize() |
-| public void setBatchSize(int batchSize) |
-
-## getInstance
-
-```java
-public static synchronized FLParameter getInstance()
-```
-
-Obtains a single FLParameter instance.
-
-- Return value
-
-    Single object of the FLParameter type.
-
-## getDeployEnv
-
-```java
-public String getDeployEnv()
-```
-
-Obtains the deployment environment for federated learning set by users.
-
-- Return value
-
-    The deployment environment for federated learning of the string type.
-
-## setDeployEnv
-
-```java
-public void setDeployEnv(String env)
-```
-
-Used to set the deployment environment for federated learning. A whitelist is enforced; currently only "x86" and "android" are supported.
-
-- Parameter
-
-    - `env`: the deployment environment for federated learning.
-
-## getDomainName
-
-```java
-public String getDomainName()
-```
-
-Obtains the domain name set by a user.
-
-- Return value
-
-    Domain name of the string type.
-
-## setDomainName
-
-```java
-public void setDomainName(String domainName)
-```
-
-Used to set the url for device-cloud communication. Currently, https and http communication are supported, with the corresponding formats https://...... and http://....... When `useElb` is set to true, the format must be https://127.0.0.0:6666 or http://127.0.0.0:6666, where `127.0.0.0` corresponds to the ip of the machine providing cloud-side services (corresponding to the cloud-side parameter `--scheduler_ip`), and `6666` corresponds to the cloud-side parameter `--fl_server_port`.
-
-- Parameter
-
-    - `domainName`: domain name.
-
-## getClientID
-
-```java
-public String getClientID()
-```
-
-A clientID that uniquely identifies the client is automatically generated before each federated learning task starts (if users need to set the clientID themselves, they can call setClientID before the federated learning training task starts). This method obtains that ID, which can be used to generate the relevant certificates in device-cloud security authentication scenarios.
-
-- Return value
-
-    Unique ID of the client, which is of the string type.
-
-## setClientID
-
-```java
-public void setClientID(String clientID)
-```
-
-Used to set the clientID that uniquely identifies the client.
-
-- Parameter
-
-    - `clientID`: unique ID of the client.
-
-## getCertPath
-
-```java
-public String getCertPath()
-```
-
-Obtains the self-signed root certificate path used for device-cloud HTTPS communication.
-
-- Return value
-
-    The self-signed root certificate path of the string type.
-
-## setCertPath
-
-```java
-public void setCertPath(String certPath)
-```
-
-Sets the self-signed root certificate path used for device-cloud HTTPS communication. This parameter needs to be set when the deployment environment is "x86" and the device and cloud use a self-signed certificate for HTTPS communication authentication. The certificate must be consistent with the CA root certificate used to generate the cloud-side self-signed certificate to pass the verification. This parameter is used for non-Android scenarios.
-
-- Parameter
-    - `certPath`: the self-signed root certificate path used for device-cloud HTTPS communication.
-
-## getSslSocketFactory
-
-```java
-public SSLSocketFactory getSslSocketFactory()
-```
-
-Obtains the ssl certificate authentication library `sslSocketFactory` set by the user.
-
-- Return value
-
-    The ssl certificate authentication library `sslSocketFactory`, which is of the SSLSocketFactory type.
-
-## setSslSocketFactory
-
-```java
-public void setSslSocketFactory(SSLSocketFactory sslSocketFactory)
-```
-
-Used to set the ssl certificate authentication library `sslSocketFactory`.
-
-- Parameter
-    - `sslSocketFactory`: the ssl certificate authentication library.
-
-## getX509TrustManager
-
-```java
-public X509TrustManager getX509TrustManager()
-```
-
-Obtains the ssl certificate authentication manager `x509TrustManager` set by the user.
-
-- Return value
-
-    The ssl certificate authentication manager `x509TrustManager`, which is of the X509TrustManager type.
-
-## setX509TrustManager
-
-```java
-public void setX509TrustManager(X509TrustManager x509TrustManager)
-```
-
-Used to set the ssl certificate authentication manager `x509TrustManager`.
-
-- Parameter
-    - `x509TrustManager`: the ssl certificate authentication manager.
-
-## getIflJobResultCallback
-
-```java
-public IFLJobResultCallback getIflJobResultCallback()
-```
-
-Obtains the federated learning callback function object `iflJobResultCallback` set by the user.
-
-- Return value
-
-    The federated learning callback function object `iflJobResultCallback`, which is of the IFLJobResultCallback type.
-
-## setIflJobResultCallback
-
-```java
-public void setIflJobResultCallback(IFLJobResultCallback iflJobResultCallback)
-```
-
-Used to set the federated learning callback function object `iflJobResultCallback`. Users can implement the specific methods of the interface class [IFLJobResultCallback.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/IFLJobResultCallback.java) in the project according to the needs of the actual scenario, and set it as the callback function object in the federated learning task.
-
-- Parameter
-    - `iflJobResultCallback`: the federated learning callback function object.
-
-## getFlName
-
-```java
-public String getFlName()
-```
-
-Obtains the package path of the model script set by a user.
-
-- Return value
-
-    The package path of the model script of the string type.
-
-## setFlName
-
-```java
-public void setFlName(String flName)
-```
-
-Sets the package path of the model script. We provide two types of model scripts for your reference ([Supervised sentiment classification task](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert), [LeNet image classification task](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet)). For supervised sentiment classification tasks, this parameter can be set to the package path of the provided script file [AlbertClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/AlbertClient.java), such as `com.mindspore.flclient.demo.albert.AlbertClient`; for LeNet image classification tasks, this parameter can be set to the package path of the provided script file [LenetClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java), such as `com.mindspore.flclient.demo.lenet.LenetClient`. Users can also refer to these two model scripts to define their own model script, and then set this parameter to the package path of the customized model file ModelClient.java (which needs to inherit from the class [Client.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/model/Client.java)).
-
-- Parameter
-    - `flName`: the package path of the model script.
-
-## getTrainModelPath
-
-```java
-public String getTrainModelPath()
-```
-
-Obtains the path of the training model set by a user.
-
-- Return value
-
-    Path of the training model of the string type.
-
-## setTrainModelPath
-
-```java
-public void setTrainModelPath(String trainModelPath)
-```
-
-Sets the path of the training model.
-
-- Parameter
-    - `trainModelPath`: training model path.
-
-## getInferModelPath
-
-```java
-public String getInferModelPath()
-```
-
-Obtains the path of the inference model set by a user.
-
-- Return value
-
-    Path of the inference model of the string type.
-
-## setInferModelPath
-
-```java
-public void setInferModelPath(String inferModelPath)
-```
-
-Sets the path of the inference model.
-
-- Parameter
-    - `inferModelPath`: path of the inference model.
-
-## getSslProtocol
-
-```java
-public String getSslProtocol()
-```
-
-Obtains the TLS protocol version used by the device-cloud HTTPS communication.
-
-- Return value
-
-    The TLS protocol version used by the device-cloud HTTPS communication of the string type.
-
-## setSslProtocol
-
-```java
-public void setSslProtocol(String sslProtocol)
-```
-
-Used to set the TLS protocol version used by the device-cloud HTTPS communication. A whitelist is enforced; currently only "TLSv1.3" or "TLSv1.2" is supported. It only needs to be set in HTTPS communication scenarios.
-
-- Parameter
-    - `sslProtocol`: the TLS protocol version used by the device-cloud HTTPS communication.
-
-## getTimeOut
-
-```java
-public int getTimeOut()
-```
-
-Obtains the timeout interval set by a user for device-side communication.
-
-- Return value
-
-    Timeout interval for communication on the device, which is an integer.
-
-## setTimeOut
-
-```java
-public void setTimeOut(int timeOut)
-```
-
-Sets the timeout interval for communication on the device.
-
-- Parameter
-    - `timeOut`: timeout interval for communication on the device.
-
-## getSleepTime
-
-```java
-public int getSleepTime()
-```
-
-Obtains the waiting time of repeated requests set by a user.
-
-- Return value
-
-    Waiting time of repeated requests, which is an integer.
-
-## setSleepTime
-
-```java
-public void setSleepTime(int sleepTime)
-```
-
-Sets the waiting time of repeated requests.
-
-- Parameter
-    - `sleepTime`: waiting time for repeated requests.
-
-## isUseElb
-
-```java
-public boolean isUseElb()
-```
-
-Determines whether elastic load balancing is simulated, that is, whether the client randomly sends requests to a server address within a specified range.
-
-- Return value
-
-    The value is of the boolean type. The value true indicates that the client sends requests to a random server address within a specified range. The value false indicates that the client sends requests to a fixed server address.
-
-## setUseElb
-
-```java
-public void setUseElb(boolean useElb)
-```
-
-Sets whether to simulate elastic load balancing, that is, whether the client randomly sends requests to a server address within a specified range.
-
-- Parameter
-    - `useElb`: whether to simulate elastic load balancing. The default value is false.
-
-## getServerNum
-
-```java
-public int getServerNum()
-```
-
-Obtains the number of servers that can receive requests when elastic load balancing is simulated.
-
-- Return value
-
-    Number of servers that can receive requests during elastic load balancing simulation, which is an integer.
-
-## setServerNum
-
-```java
-public void setServerNum(int serverNum)
-```
-
-Sets the number of servers that can receive requests during elastic load balancing simulation.
-
-- Parameter
-    - `serverNum`: number of servers that can receive requests during elastic load balancing simulation. The default value is 1.
-
-## isPkiVerify
-
-```java
-public boolean isPkiVerify()
-```
-
-Whether to perform device-cloud security authentication.
-
-- Return value
-
-    The value is of the boolean type. The value true indicates that device-cloud security authentication is performed, and the value false indicates that it is not.
-
-## setPkiVerify
-
-```java
-public void setPkiVerify(boolean pkiVerify)
-```
-
-Determines whether to perform device-cloud security authentication.
-
-- Parameter
-
-    - `pkiVerify`: whether to perform device-cloud security authentication.
-
-## getEquipCrlPath
-
-```java
-public String getEquipCrlPath()
-```
-
-Obtains the CRL certificate path `equipCrlPath` of the device certificate set by the user. This parameter is used in the Android environment.
-
-- Return value
-
-    The certificate path of the string type.
-
-## setEquipCrlPath
-
-```java
-public void setEquipCrlPath(String certPath)
-```
-
-Used to set the CRL certificate path of the device certificate, which is used to verify whether a digital certificate has been revoked. This parameter is used in the Android environment.
-
-- Parameter
-    - `certPath`: the certificate path.
-
-## getValidInterval
-
-```java
-public long getValidInterval()
-```
-
-Obtains the valid iteration interval validIterInterval set by the user. This parameter is used in the Android environment.
-
-- Return value
-
-    The valid iteration interval validIterInterval of the long type.
-
-## setValidInterval
-
-```java
-public void setValidInterval(long validInterval)
-```
-
-Used to set the valid iteration interval validIterInterval. The recommended value is the duration of one training epoch between the device and the cloud (unit: milliseconds). It is used to prevent replay attacks. This parameter is used in the Android environment.
-
-- Parameter
-    - `validInterval`: the valid iteration interval validIterInterval.
-
-## getThreadNum
-
-```java
-public int getThreadNum()
-```
-
-Obtains the number of threads used in federated learning training and inference. The default value is 1.
-
-- Return value
-
-    The number of threads used in federated learning training and inference, which is of the int type.
-
-## setThreadNum
-
-```java
-public void setThreadNum(int threadNum)
-```
-
-Used to set the number of threads used in federated learning training and inference.
-
-- Parameter
-    - `threadNum`: the number of threads used in federated learning training and inference.
-
-## getCpuBindMode
-
-```java
-public int getCpuBindMode()
-```
-
-Obtains the cpu cores that threads need to bind during federated learning training and inference.
-
-- Return value
-
-    The cpu core binding mode, converted from the enumerated type to the int type.
-
-## setCpuBindMode
-
-```java
-public void setCpuBindMode(BindMode cpuBindMode)
-```
-
-Used to set the cpu cores that threads need to bind during federated learning training and inference.
-
-- Parameter
-    - `cpuBindMode`: the enumeration type `BindMode`, where BindMode.NOT_BINDING_CORE represents not binding cores (cores are automatically assigned by the system), BindMode.BIND_LARGE_CORE represents binding large cores, and BindMode.BIND_MIDDLE_CORE represents binding middle cores.
-
-## getHybridWeightName
-
-```java
-public List getHybridWeightName(RunType runType)
-```
-
-Used in hybrid training mode. Gets the training weight names and inference weight names set by the user.
-
-- Parameter
-
-    - `runType`: RunType enumeration type, which can only be set to RunType.TRAINMODE (representing the training weight names) or RunType.INFERMODE (representing the inference weight names).
-
-- Return value
-
-    A list of the weight names corresponding to the parameter runType, which is of the List type.
-
-## setHybridWeightName
-
-```java
-public void setHybridWeightName(List hybridWeightName, RunType runType)
-```
-
-In hybrid training mode, part of the weights delivered by the server is imported into the training model and part into the inference model, but the framework itself cannot determine which is which, so users need to set the relevant training weight names and inference weight names themselves. This method is provided for that purpose.
-
-- Parameter
-    - `hybridWeightName`: a list of weight names of the List type.
-    - `runType`: RunType enumeration type, which can only be set to RunType.TRAINMODE (for setting training weight names) or RunType.INFERMODE (for setting inference weight names).
-
-## getDataMap
-
-```java
-public Map> getDataMap()
-```
-
-Obtains the federated learning datasets set by the user.
-
-- Return value
-
-    The federated learning dataset map of the Map> type.
-
-## setDataMap
-
-```java
-public void setDataMap(Map> dataMap)
-```
-
-Used to set the federated learning datasets.
-
-- Parameter
-    - `dataMap`: the dataset map of the Map> type. The key in the map is of the RunType enumeration type and the value is the corresponding dataset list: when the key is RunType.TRAINMODE, the value is the list of training-related datasets; when the key is RunType.EVALMODE, the value is the list of evaluation-related datasets; when the key is RunType.INFERMODE, the value is the list of inference-related datasets.
-
-## getServerMod
-
-```java
-public ServerMod getServerMod()
-```
-
-Obtains the federated learning training mode.
-
-- Return value
-
-    The federated learning training mode of the ServerMod enumeration type.
-
-## setServerMod
-
-```java
-public void setServerMod(ServerMod serverMod)
-```
-
-Used to set the federated learning training mode.
-
-- Parameter
-    - `serverMod`: the federated learning training mode of the ServerMod enumeration type, where ServerMod.FEDERATED_LEARNING represents the normal federated learning mode (training and inference use the same model), and ServerMod.HYBRID_TRAINING represents the hybrid learning mode (training and inference use different models, and the server side also includes a training process).
-
-## getBatchSize
-
-```java
-public int getBatchSize()
-```
-
-Obtains the number of single-step training samples used in federated learning training and inference, that is, the batch size.
-
-- Return value
-
-    BatchSize, the number of single-step training samples of the int type.
-
-## setBatchSize
-
-```java
-public void setBatchSize(int batchSize)
-```
-
-Used to set the number of single-step training samples used in federated learning training and inference, that is, the batch size. It needs to be consistent with the batch size of the input data of the model.
-
-- Parameter
-    - `batchSize`: the number of single-step training samples of the int type.
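-
-To tie several of these parameters together, a hedged sketch of a hybrid-mode configuration is shown below, in the fragment style of the earlier examples. The weight names and file paths are purely illustrative, and the generic types (String weight names and dataset paths) are assumptions based on the descriptions above; ServerMod and RunType ship with the federated client package.
-
-```java
-FLParameter flParameter = FLParameter.getInstance();
-flParameter.setServerMod(ServerMod.HYBRID_TRAINING);
-
-// Illustrative weight names: which weights feed the training model and which feed the inference model.
-flParameter.setHybridWeightName(Arrays.asList("albert.pooler.weight"), RunType.TRAINMODE);
-flParameter.setHybridWeightName(Arrays.asList("classifier.weight"), RunType.INFERMODE);
-
-// Illustrative dataset map: dataset file paths per run type.
-Map<RunType, List<String>> dataMap = new HashMap<>();
-dataMap.put(RunType.TRAINMODE, Arrays.asList("data/train.txt"));
-dataMap.put(RunType.EVALMODE, Arrays.asList("data/eval.txt"));
-flParameter.setDataMap(dataMap);
-```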
\ No newline at end of file
diff --git a/docs/federated/docs/source_en/java_api_syncfljob.md b/docs/federated/docs/source_en/java_api_syncfljob.md
deleted file mode 100644
index 67d79396551c9245ec389afb1d66ff9fcf9df1b0..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/java_api_syncfljob.md
+++ /dev/null
@@ -1,64 +0,0 @@
-# SyncFLJob
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/java_api_syncfljob.md)
-
-```java
-import com.mindspore.flclient.SyncFLJob
-```
-
-SyncFLJob defines the API flJobRun() for starting federated learning on the device, the API modelInference() for inference on the device, the API getModel() for obtaining the latest model on the cloud, and the API stopFLJob() for stopping federated learning training tasks.
-
-## Public Member Functions
-
-| **Function** |
-| -------------------------------- |
-| public FLClientStatus flJobRun() |
-| public int[] modelInference() |
-| public FLClientStatus getModel() |
-| public void stopFLJob() |
-
-## flJobRun
-
-```java
-public FLClientStatus flJobRun()
-```
-
-Starts a federated learning task on the device. For specific usage, please refer to the [interface introduction document](https://www.mindspore.cn/federated/docs/en/master/interface_description_federated_client.html).
-
-- Return value
-
-    The status code of the flJobRun request.
-
-## modelInference
-
-```java
-public int[] modelInference()
-```
-
-Starts an inference task on the device. For specific usage, please refer to the [interface introduction document](https://www.mindspore.cn/federated/docs/en/master/interface_description_federated_client.html).
-
-- Return value
-
-    int[] composed of the labels inferred from the input.
-
-## getModel
-
-```java
-public FLClientStatus getModel()
-```
-
-Obtains the latest model on the cloud. For specific usage, please refer to the [interface introduction document](https://www.mindspore.cn/federated/docs/en/master/interface_description_federated_client.html).
-
-- Return value
-
-    The status code of the getModel request.
-
-## stopFLJob
-
-```java
-public void stopFLJob()
-```
-
-The training task can be stopped by calling this interface during the federated learning training process: while one thread is running SyncFLJob.flJobRun(), another thread can call SyncFLJob.stopFLJob() to stop the ongoing federated learning training task.
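-
-A minimal sketch of this two-thread usage (the thread handling and the stop condition are illustrative):
-
-```java
-SyncFLJob syncFLJob = new SyncFLJob();
-
-// Run the federated learning task in a worker thread.
-Thread trainThread = new Thread(syncFLJob::flJobRun, "fl-train");
-trainThread.start();
-
-// Stop the task from another thread once some condition is met.
-Thread.sleep(60_000); // illustrative: stop after one minute
-syncFLJob.stopFLJob();
-trainThread.join();
-```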
\ No newline at end of file
diff --git a/docs/federated/docs/source_en/local_differential_privacy_eval_laplace.md b/docs/federated/docs/source_en/local_differential_privacy_eval_laplace.md
deleted file mode 100644
index 878f3bf693e404179a9a2123e9e5aba708b0e326..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/local_differential_privacy_eval_laplace.md
+++ /dev/null
@@ -1,236 +0,0 @@
-# Horizontal Federated-Local Differential Privacy Inference Result Protection
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/local_differential_privacy_eval_laplace.md)
-
-## Privacy Protection Background
-
-Federated unsupervised model training can be evaluated by the $loss$ fed back from the device side, or the device-side inference results can be combined with cloud-side clustering and clustering evaluation metrics to further monitor the federated unsupervised training progress. The latter involves device-side inference data reaching the cloud, so to meet privacy protection requirements, the device-side inference data must be processed for privacy protection in a way that still allows the cloud side to perform the clustering evaluation. This task is secondary to the training task, so we use lightweight algorithms and cannot introduce privacy protection algorithms with higher computational or communication overhead than the training phase. This paper presents a lightweight scheme for protecting inference results by using the local differential privacy Laplace noise mechanism.
-
-Effectively integrating privacy protection technology into product services will, on the one hand, help enhance the trust of users and the industry in the products and technology, and on the other hand, help to better carry out federated tasks under the current privacy compliance requirements and create full-lifecycle (training-inference-evaluation) privacy protection.
-
-## Algorithm Analysis
-
-### $L1$ and $L2$ Norms
-
-The $L1$-norm of a vector $V$ with length $k$ is $||V||_1=\sum^{k}_{i=1}{|V_i|}$, so the $L1$-norm of the difference between two vectors in two-dimensional space is the Manhattan distance.
-
-The $L2$-norm is $||V||_2=\sqrt{\sum^{k}_{i=1}{V^2_i}}$.
-
-The inference result is generally a $softmax$ result with a sum of $1$, and each dimension value of the vector indicates the probability of belonging to the corresponding category of that dimension.
-
-### $L1$ and $L2$ Sensitivity
-
-Local differential privacy introduces uncertainty into the data to be uploaded, and the sensitivity describes an upper bound on that uncertainty. Gaussian noise with $L2$ sensitivity can be added to the gradient in the optimizer and in federated training, since a clipping operation is performed on the gradient vector before the addition. Here the $softmax$ inference result satisfies a sum of $1$, so Laplace noise with $L1$ sensitivity is added. For applications where the $L2$ sensitivity is much lower than the $L1$ sensitivity, the Gaussian mechanism allows less noise to be added, but this scenario has no $L2$-related constraint and uses only the $L1$ sensitivity.
-
-The $L1$-sensitivity in local differential privacy is expressed as the maximum distance between any inputs in the defined domain:
-
-$\Delta f=max||X-Y||_1$
-
-In this scenario, $X=(x_1,x_2,...,x_k)$, $Y=(y_1,y_2,...,y_k)$, $\sum X = 1$, $\sum Y = 1$, and $|x_1-y_1|+|x_2-y_2|+...+|x_k-y_k|\leq1=\Delta f$.
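-
-For instance, for two softmax outputs $X=(0.5,0.3,0.2)$ and $Y=(0.4,0.4,0.2)$ (illustrative values), the distance is
-
-$$||X-Y||_1=|0.5-0.4|+|0.3-0.4|+|0.2-0.2|=0.2$$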
-
-### Laplace Distribution
-
-The Laplace distribution is continuous, and the probability density function of the Laplace distribution with mean value 0 is:
-
-$Lap(x|b)=\frac{1}{2b}exp(-\frac{|x|}{b})$
-
-### Laplace Mechanism
-
-$M(x,\epsilon)=X+Lap(\Delta f/\epsilon)$
-
-where $Lap(\Delta f/\epsilon)$ is a vector of independently and identically distributed random variables with the same shape as $X$.
-
-In this scenario, $b$ (also called $scale$, $lambda$, $beta$) is $1/\epsilon$.
-
-### Proving that the Laplace Mechanism Satisfies $\epsilon$-LDP
-
-Any two different clients, after being processed by the Laplace mechanism, may both output the same result, which makes them indistinguishable, and the ratio of the probabilities of outputting the same result has a tight upper bound. Substituting $b=\Delta f/\epsilon$ yields:
-
-$Lap(\Delta f/\epsilon)=\frac{\epsilon}{2\Delta f}exp(-\frac{\epsilon|x|}{\Delta f})$
-
-$\frac{P(Z|X)}{P(Z|Y)}$
-
-$=\prod^k_{i=1}(\frac{exp(-\frac{\epsilon|x_i-z_i|}{\Delta f})}{exp(-\frac{\epsilon |y_i-z_i|}{\Delta f})})$
-
-$=\prod^k_{i=1}exp(\epsilon\frac{|x_i-z_i|-|y_i-z_i|}{\Delta f})$
-
-$\leq\prod^k_{i=1}exp(\epsilon\frac{|x_i-y_i|}{\Delta f})$
-
-$=exp(\epsilon\frac{||X-Y||_1}{\Delta f})$
-
-$\leq exp(\epsilon)$
-
-where the first inequality follows from the triangle inequality $|x_i-z_i|-|y_i-z_i|\leq|x_i-y_i|$.
-
-#### Determining $\epsilon$ from the Corresponding Probability Density
-
-A privacy budget with high availability is calculated by combining the data characteristics, for example the requirement to output noise on the order of $1e-5$ with high probability, since larger noise would directly affect the clustering results. The privacy budget calculation corresponding to generating the specified amount of noise is given below.
-
-We require a $90\%$ probability that the noise magnitude is at most $1e-5$; by symmetry, the one-sided integral below must equal $0.45$, and the value of $\epsilon$ is obtained by integrating the probability density curve.
-
-$x>=0, Lap(x|b)=\frac{1}{2b}exp(-\frac{x}{b})$
-
-$\int^{E^{-5}}_0 {Lap(x|b)dx}$
-
-$=-\frac{1}{2}exp(-\frac{x}{b})|^{E^{-5}}_{0}$
-
-$=\frac{1}{2}(exp(0)-exp(-\frac{E^{-5}}{b}))$
-
-$=0.5(1-exp(-\frac{E^{-5}}{b})) = 0.45$
-
-i.e.
-
-$exp(-\frac{E^{-5}}{b})=0.1$
-
-$b=-E^{-5}/ln(0.1)=E^{-5}/2.3026=1/\epsilon$
-
-$\epsilon=2.3026E^5$
-
-When the privacy budget takes this value, the Laplace probability density function is as follows:
-
-![laplace](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/federated/docs/source_zh_cn/images/laplace_pdf.png)
-
-### Impact Analysis of Clustering Evaluation Indicators
-
-Using the **Calinski-Harabasz Index** assessment method as an example, the evaluation indicator is calculated in two steps:
-
-1. Each class calculates the sum of the squares of the distances from all `points` in the class to the `center of the class`;
-2. Calculate the weighted sum of the squares of the distances from each `class center` to the `center of all points`;
-
-Source code implementation and impact analysis after noise addition:
-
-```python
-import numpy as np
-
-def calinski_harabasz_score(X):
-    # 1. The cloud-side clustering algorithm gets the class ordinal number each sample belongs to, with impact
-    labels = np.argmax(X, axis=1)
-    n_samples, n_labels = X.shape[0], len(np.unique(labels))
-
-    extra_disp, intra_disp = 0.0, 0.0
-    # 2. Calculate the center of all points, without impact
-    mean = np.mean(X, axis=0)
-    for k in range(n_labels):
-        # 3. Get all points in class k, based on the effect of 1
-        cluster_k = X[labels == k]
-        # 4. Get the class center, based on the impact of 1
-        mean_k = np.mean(cluster_k, axis=0)
-        # 5. The distance between the class center and the center of all points, based on the impact of 1
-        extra_disp += len(cluster_k) * np.sum((mean_k - mean) ** 2)
-        # 6. The distance from the points to the center of the class, with impact
-        intra_disp += np.sum((cluster_k - mean_k) ** 2)
-
-    return (
-        1.0
-        if intra_disp == 0.0
-        else extra_disp * (n_samples - n_labels) / (intra_disp * (n_labels - 1.0))
-    )
-```
-
-In a comprehensive analysis, the main impact of noise addition is on the clustering algorithm and on the error in the distance calculation. When calculating the class center, the error introduced is small because the noise sum is expected to be $0$.
-
-Taking the **SILHOUETTE SCORE** as an example, the process of calculating this evaluation indicator is divided into two steps:
-
-1. Calculate the average distance of a sample point $i$ from all other sample points in the same cluster, which is denoted as $a_i$. The smaller the value is, the more the sample $i$ should be assigned to this cluster.
-
-2. Calculate the average distance $b_{ij}$ of sample $i$ to all samples of some other cluster $C_j$, which is called the dissimilarity of sample $i$ to cluster $C_j$. The inter-cluster dissimilarity of sample $i$ is defined as: $b_i = min(b_{i1}, b_{i2}, ..., b_{ik})$. The larger the value is, the less the sample $i$ should belong to this cluster.
-
-![flow](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/federated/docs/source_zh_cn/images/two_cluster.png)
-
-$s_i=(b_i-a_i) / max(a_i, b_i)$.
-
-When $a_i$ is smaller and $b_i$ is larger, the result is $1-a_i / b_i$; the closer it is to $1$, the better the clustering effect.
-
-Pseudocode implementation and impact analysis after noise addition:
-
-```c++
-// Calculate the distance matrix (trading space for time, upper-triangle storage); noise addition affects the distances
-euclidean_distance_matrix(&distance_matrix, group_ids);

-// Perform the same calculation for each point, and finally calculate the mean value
-for (size_t i = 0; i < n_samples; ++i) {
-    size_t label_i = labels[i];
-    float a_i = 0.0f;
-    float b_i = std::numeric_limits<float>::max();
-    std::vector<float> a_distances;
-    std::unordered_map<size_t, std::vector<float>> b_i_map;
-    for (size_t j = 0; j < n_samples; ++j) {
-        if (j == i) {
-            continue;  // skip the point itself
-        }
-        size_t label_j = labels[j];
-        float distance = distance_matrix[i][j];
-        // The same cluster contributes to ai
-        if (label_j == label_i) {
-            a_distances.push_back(distance);
-        } else {
-            // Different clusters contribute to bi
-            b_i_map[label_j].push_back(distance);
-        }
-    }
-    if (a_distances.size() > 0) {
-        // Calculate the average distance of the point from other points in the same cluster
-        a_i = std::accumulate(a_distances.begin(), a_distances.end(), 0.0) / a_distances.size();
-    }
-    for (auto &item : b_i_map) {
-        auto &b_i_distances = item.second;
-        float b_i_distance = std::accumulate(b_i_distances.begin(), b_i_distances.end(), 0.0) / b_i_distances.size();
-        b_i = std::min(b_i, b_i_distance);
-    }
-    if (a_i == 0) {
-        s_i[i] = 0;
-    } else {
-        s_i[i] = (b_i - a_i) / std::max(a_i, b_i);
-    }
-}
-return std::accumulate(s_i.begin(), s_i.end(), 0.0) / n_samples;
-```
-
-As above, the main impact of noise addition is on the clustering algorithm and on the error in the distance calculation.
-
-### Device-side Java Implementation
-
-The Java standard library has no function for generating Laplace-distributed random numbers, so they are generated with the following combination of uniform random numbers.
-
-The source code is as follows:
-
-```java
-float genLaplaceNoise(SecureRandom secureRandom, float beta) {
-    float u1 = secureRandom.nextFloat();
-    float u2 = secureRandom.nextFloat();
-    // With probability 0.5 draw from the positive exponential tail, otherwise from the negative one,
-    // which yields a Laplace(0, beta) sample.
-    if (u1 <= 0.5f) {
-        return (float) (-beta * log(1. - u2));
-    } else {
-        return (float) (beta * log(u2));
-    }
-}
-```
-
-After the device side obtains a new round of the model, the inference calculation is executed immediately. After training, the privacy-protected inference results are uploaded to the cloud side together with the new model, and the cloud side finally performs operations such as clustering and score calculation. The flow is shown in the following figure, where the red part is the output of the privacy protection processing:
-
-![flow](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/federated/docs/source_zh_cn/images/eval_flow.png)
-
-## Quick Start
-
-### Preparation
-
-To use this feature, one first needs to successfully complete the training aggregation process for a device-cloud federated scenario. [Implementing an Image Classification Application of Cross-device Federated Learning (x86)](https://www.mindspore.cn/federated/docs/en/master/image_classification_application.html) details the preparation of datasets and network models, as well as how to simulate the process of initiating multi-client participation in federated learning.
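-
-Before moving on to the configuration, the device-side perturbation step described above can be sketched as follows (illustrative only, not the framework implementation; it reuses genLaplaceNoise with $b=1/\epsilon$, i.e. $\Delta f=1$ as derived earlier):
-
-```java
-// Add Laplace noise element-wise to a softmax inference result before upload.
-float[] perturb(float[] softmax, float eps, SecureRandom secureRandom) {
-    float beta = 1.0f / eps; // scale b = Delta f / eps with Delta f = 1
-    float[] noisy = new float[softmax.length];
-    for (int i = 0; i < softmax.length; i++) {
-        noisy[i] = softmax[i] + genLaplaceNoise(secureRandom, beta);
-    }
-    return noisy;
-}
-```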
-
-### Configuration Items
-
-The [cloud-side yaml configuration file](https://gitee.com/mindspore/federated/blob/master/tests/st/cross_device_cloud/default_yaml_config.yaml) gives the complete configuration items for enabling device-cloud federated learning, and this feature involves the following additional configuration items:
-
-```c
-encrypt:
-  privacy_eval_type: LAPLACE
-  laplace_eval:
-    laplace_eval_eps: 230260
-```
-
-where `privacy_eval_type` currently supports only `NOT_ENCRYPT` and `LAPLACE`, indicating that the inference results are processed without a privacy protection method or with the `LAPLACE` mechanism, respectively.
-
-`laplace_eval_eps` indicates how much privacy budget is used if `LAPLACE` processing is used.
-
-## Experimental Results
-
-The basic configuration associated with the inference result evaluation function is as follows:
-
-```c
-unsupervised:
-  cluster_client_num: 1000
-  eval_type: SILHOUETTE_SCORE
-```
-
-The relationship between the $loss$ and the score when using `NOT_ENCRYPT` and when using the `LAPLACE` mechanism with `laplace_eval_eps=230260` is shown in the figure:
-
-![flow](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/federated/docs/source_zh_cn/images/SILHOUETTE.png)
-
-The red dashed line shows the SILHOUETTE scores after the Laplace mechanism is used to protect the inference results. Since the model contains $dropout$ and Gaussian input, the $loss$ values of the two trainings differ slightly, and the scores obtained from different models differ slightly as well. However, the overall trend remains consistent, and the score can be used together with the $loss$ to monitor the model training progress.
diff --git a/docs/federated/docs/source_en/local_differential_privacy_training_noise.md b/docs/federated/docs/source_en/local_differential_privacy_training_noise.md
deleted file mode 100644
index 457d3624c4ecfd9979db06aff3847e33243e746e..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/local_differential_privacy_training_noise.md
+++ /dev/null
@@ -1,45 +0,0 @@
-# Horizontal FL-Local Differential Privacy Perturbation Training
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/local_differential_privacy_training_noise.md)
-
-During federated learning, user data is used only for local device training and does not need to be uploaded to the central server. This prevents personal data leakage.
-However, in the conventional federated learning framework, models are migrated to the cloud in plaintext. There is still a risk of indirect disclosure of user privacy.
-After obtaining the plaintext model uploaded by a user, an attacker can restore the user's personal training data through attacks such as reconstruction and model inversion. As a result, user privacy is disclosed.
-
-As a federated learning framework, MindSpore Federated provides secure aggregation algorithms based on local differential privacy (LDP). Noise addition is performed on local models before they are migrated to the cloud. On the premise of ensuring model availability, the problem of privacy leakage in horizontal federated learning is solved.
-
-## Principles
-
-Differential privacy is a mechanism for protecting user data privacy. It is defined as follows:
-
-$$
-Pr[\mathcal{K}(D)\in S] \le e^{\epsilon} Pr[\mathcal{K}(D') \in S]+\delta
-$$
-
-For datasets $D$ and $D'$ that differ in only one record, the probability that the output of the random algorithm $\mathcal{K}$ falls in any subset $S$ meets the preceding formula. $\epsilon$ is the differential privacy budget, and $\delta$ is the relaxation probability. The smaller the values of $\epsilon$ and $\delta$, the closer the output distributions of $\mathcal{K}$ on $D$ and $D'$.
-
-In horizontal federated learning, if the model weight matrix after local training on the client is $W$, an attacker can use $W$ to restore the user's training dataset [1], because the model "remembers" the features of the training set during the training process.
-
-MindSpore Federated provides an LDP-based secure aggregation algorithm to prevent privacy data leakage when local models are migrated to the cloud.
-
-The MindSpore Federated client generates a differential noise matrix $G$ that has the same dimension as the local model $W$, and then adds the two to obtain a weight $W_p$ that meets the differential privacy definition:
-
-$$
-W_p=W+G
-$$
-
-The MindSpore Federated client uploads the noise-added model $W_p$ to the cloud server for federated aggregation. The noise matrix $G$ is equivalent to adding a layer of mask to the original model, which reduces the risk of sensitive data leakage from models but also affects the convergence of model training. How to achieve a better balance between model privacy and usability is still a question worth studying. Experiments show that when the number of participants $n$ is large enough (generally more than 1000), most of the noise cancels out, and the LDP mechanism has no obvious impact on the accuracy and convergence of the aggregated model.
-
-## Usage
-
-Local differential privacy training currently supports only cross-device scenarios. Enabling differential privacy training is simple: you only need to set the `encrypt_train_type` field to `DP_ENCRYPT` via the [yaml](https://www.mindspore.cn/federated/docs/en/master/horizontal/federated_server_yaml.html#) when starting the cloud-side service.
-
-In addition, to control the effect of privacy protection, three parameters are provided: `dp_eps`, `dp_delta`, and `dp_norm_clip`. They are also set through the yaml file.
-
-The valid range of `dp_eps` and `dp_norm_clip` is greater than 0, and the legal range of `dp_delta` is 0 < `dp_delta` < 1. In general, the smaller `dp_eps` and `dp_delta` are, the better the privacy protection, but the greater the impact on the convergence of the model. It is recommended that `dp_delta` be taken as the inverse of the number of clients and that `dp_eps` be greater than 50.
-
-`dp_norm_clip` is the adjustment coefficient applied to the model weights before the LDP mechanism adds noise to them. It affects the convergence of the model. The recommended value ranges from 0.5 to 2.
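-
-For reference, a minimal yaml sketch is shown below. The field nesting is an assumption that mirrors the SignDS example later in this document; consult the yaml configuration reference linked above for the authoritative layout. The values follow the recommendations given here, assuming about 1000 clients:
-
-```yaml
-encrypt:
-  encrypt_train_type: DP_ENCRYPT
-  dp_encrypt:
-    dp_eps: 50          # recommended: greater than 50
-    dp_delta: 0.001     # recommended: the inverse of the number of clients
-    dp_norm_clip: 1.0   # recommended range: 0.5 to 2
-```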
-
-## References
-
-[1] Ligeng Zhu, Zhijian Liu, and Song Han. [Deep Leakage from Gradients](http://arxiv.org/pdf/1906.08935.pdf). NeurIPS, 2019.
diff --git a/docs/federated/docs/source_en/local_differential_privacy_training_signds.md b/docs/federated/docs/source_en/local_differential_privacy_training_signds.md
deleted file mode 100644
index 5d471046081b4e8c91dc524c4716c5988114e472..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/local_differential_privacy_training_signds.md
+++ /dev/null
@@ -1,178 +0,0 @@
-# Horizontal FL-Local Differential Privacy SignDS Training
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/local_differential_privacy_training_signds.md)
-
-## Privacy Protection Background
-
-Federated learning enables client users to participate in global model training without uploading the original dataset, by allowing participants to upload only the new model obtained after local training, or only the model update information, breaking through data silos. This common federated learning scenario corresponds to the default scheme in the MindSpore federated learning framework, where the `encrypt_train_type` switch defaults to `not_encrypt` when starting the `server`. The `installation and deployment` and `application practices` in the federated learning tutorial both use this approach by default, which is a common federated averaging scheme without any privacy-protecting treatment such as encryption or perturbation. For convenience of description, `not_encrypt` is used below to refer specifically to this default scheme.
-
-Federated learning with the above `not_encrypt` scheme is not free from privacy leakage: the Server receiving the local training model uploaded by a Client can still reconstruct the user's training data through certain attack methods [1], leaking user privacy, so the `not_encrypt` scheme needs an additional user privacy protection mechanism.
-
-The global model `oldModel` received by the Client in each round of federated learning is issued by the Server and does not involve user privacy. However, the local model `newModel` obtained by each Client after several epochs of local training fits its local private data, so the privacy protection focuses on the weight difference between the two: `newModel`-`oldModel`=`update`.
-
-The `DP_ENCRYPT` differential noise scheme already implemented in the MindSpore Federated framework achieves privacy preservation by perturbing `update` with Gaussian random noise in each iteration. However, as the dimensionality of the model increases, the growing norm of `update` requires larger noise, and more Clients have to participate in the same round of aggregation to neutralize the noise impact; otherwise, the convergence and accuracy of the model are reduced. If the noise is set too small, although the convergence and accuracy are close to those of the `not_encrypt` scheme, the privacy protection is not strong enough. Moreover, each Client needs to send the perturbed model, and as the model grows, so does the communication overhead. We expect Clients, typically cell phones, to achieve convergence of the global model with as little communication overhead as possible.
-
-## Algorithm Flow Introduction
-
-SignDS [2] is the abbreviation of Sign Dimension Select, and its processing object is the `update` of a Client. Preparation: each layer Tensor of `update` is flattened into a one-dimensional vector, all the vectors are concatenated, and the dimension of the concatenated vector is denoted as $d$.
-
-To summarize the algorithm in one sentence: each participant uploads only information about the important dimensions, namely their gradient directions and a privacy-preserving step length, corresponding to the SignDS and MagRR (Magnitude Random Response) modules in the figure below, respectively.
-
-![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/federated/docs/source_zh_cn/images/signds_framework.png)
-
-Here is an example: there are 3 clients Client1, 2, 3, whose `update` is a $d=8$-dimensional vector after flattening, and the Server calculates the `avg` of these 3 Clients and updates the global model with it, which completes a round of federated learning.
-
-| Client | d_1 | d_2 | d_3 | d_4 | d_5 | d_6 | d_7 | d_8 |
-| :----: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :---: |
-| 1 | 0.4 | 0.1 | -0.2 | 0.3 | 0.5 | 0.1 | -0.2 | -0.3 |
-| 2 | 0.5 | 0.2 | 0 | 0.1 | 0.3 | 0.2 | -0.1 | -0.2 |
-| 3 | 0.3 | 0.1 | -0.1 | 0.5 | 0.2 | 0.3 | 0 | 0.1 |
-| avg | 0.4 | 0.13 | -0.1 | 0.3 | 0.33 | 0.2 | -0.1 | -0.13 |
-
-### SignDS
-
-Dimensions with higher importance should be selected, where importance is measured by the **magnitude of the value**, so `update` needs to be sorted. The values of `update` can be positive or negative, representing different update directions, so in each round of federated learning each Client's sign value takes `1` or `-1`, each with **0.5 probability**. If sign=1, the $k$ dimensions of `update` with the largest values are noted as the `topk` set and the remaining ones as the `non-topk` set. If sign=-1, the $k$ dimensions with the smallest values are noted as the `topk` set.
-
-If the Server specifies `h`, the total number of selected dimensions, the Client uses this value directly; otherwise, each Client locally calculates the optimal output dimension count `h`.
-
-The SignDS algorithm outputs the number of dimensions (denoted as $v$) that should be selected from the `topk` set and the `non-topk` set, as in the example in the table below, where a total of h=3 dimensions are picked from the two sets.
-
-Each Client selects dimensions uniformly at random according to the numbers output by the SignDS algorithm and sends the dimension indexes and the sign value to the Server. Since the dimensions are picked from `topk` first and then from `non-topk`, the index list `index` needs to be shuffled. The following table shows the information each Client finally transfers to the Server; a sketch of the selection step follows the table.
-
-| Client | index | sign |
-| :----: | :---: | :--: |
-| 1 | 1,5,8 | 1 |
-| 2 | 2,3,4 | -1 |
-| 3 | 3,6,7 | 1 |
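-
-The selection step can be sketched as follows (illustrative only, not the framework implementation; the sampling of $v$ via the exponential mechanism described later is taken here as an input):
-
-```java
-// Pick v dimensions from topk and h - v from non-topk, uniformly at random.
-List<Integer> selectDims(float[] update, int k, int h, int v, int sign, SecureRandom rnd) {
-    List<Integer> idx = new ArrayList<>();
-    for (int i = 0; i < update.length; i++) {
-        idx.add(i);
-    }
-    // Sort dimension indexes by value; sign decides which end forms the topk set.
-    idx.sort((a, b) -> sign == 1 ? Float.compare(update[b], update[a])
-                                 : Float.compare(update[a], update[b]));
-    List<Integer> topk = new ArrayList<>(idx.subList(0, k));
-    List<Integer> nonTopk = new ArrayList<>(idx.subList(k, idx.size()));
-    Collections.shuffle(topk, rnd);
-    Collections.shuffle(nonTopk, rnd);
-    List<Integer> out = new ArrayList<>(topk.subList(0, v));
-    out.addAll(nonTopk.subList(0, h - v));
-    Collections.shuffle(out, rnd); // hide which indexes came from topk
-    return out;
-}
-```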
-
-### MagRR
-
-The Server receives the dimension directions from the clients, but it is not clear what the step size to update in those directions should be. Generally speaking, the step length tends to be large at the beginning of training and shrinks as the training gradually converges. The general trend of the step length change is shown in the following figure:
-
-![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/federated/docs/source_zh_cn/images/signds_step_length.png)
-
-The Server wants to estimate a dynamic range $[0,2*r_{est}]$ for the actual step $r$, and thus compute the global learning rate $lr_{global}=2*r_{est}*num_{clients}$.
-
-The $r$ adjustment uses an idea similar to binary search. The specific process is as follows:
-
-1. The server initializes a smaller $r_{est}$ before the start of training (which does not affect the direction of model convergence too much);
-2. After each round of local training, the participant calculates the true magnitude $r$ (the mean of the topk dimensions) and converts $r$ to $b$ with certain rules based on the current $r_{est}$ issued from the cloud side;
-3. The participant performs local differential Binary Randomized Response (BRR) perturbation on $b$ and uploads the result.
-
-The whole training process is divided into two phases, namely the **fast growth** phase and the **contraction** phase. The rules for the $r \rightarrow b$ conversion and the server-side update of $r_{est}$ differ slightly between the two phases:
-
-- In the fast growth phase, a smaller initial $r_{est}$ is chosen, such as $e^{-5}$, and $r_{est}$ is expanded by a certain multiple when needed.
-  Therefore, we can define:
-
-  $$
-  b = \begin{cases}
-  0 & r \in [2*r_{est}, \infty) \\
-  1 & r \in [0,2*r_{est})
-  \end{cases}
-  $$
-
-  The server aggregates all device-side random response results for frequency statistics and calculates the plurality $B$.
-  If $B=0$, it is considered that $r_{est}$ has not yet reached the range of $r$ and needs to be increased further;
-  if $B=1$, $r_{est}$ is considered to have reached the range of $r$ and is kept unchanged.
-- In the contraction phase, it is necessary to fine-tune $r_{est}$ according to the changes in $r$. Therefore, we can define:
-
-  $$
-  b = \begin{cases}
-  0 & r \in [r_{est}, \infty) \\
-  1 & r \in [0,r_{est})
-  \end{cases}
-  $$
-
-  Calculate $B$; if $B=0$, $r_{est}$ and $r$ are considered to be close and $r_{est}$ is kept unchanged;
-  if $B=1$, $r$ is considered to be generally smaller than $r_{est}$, and $r_{est}$ is halved.
-
-The Server constructs a privacy-preserving `update` based on the dimension indexes and sign value uploaded by each Client together with $r_{est}$, aggregates and averages all `update`s, and updates the current `oldModel`, completing one round of federated learning. The following table shows the aggregation when $2*r_{est}*num_{clients}=1$; a sketch of the client-side $r \rightarrow b$ conversion follows the table.
-
-| Client | d_1 | d_2 | d_3 | d_4 | d_5 | d_6 | d_7 | d_8 |
-| :----: | :---: | :----: | :----: | :----: | :---: | :---: | :---: | :---: |
-| 1 | **1** | 0 | 0 | 0 | **1** | 0 | 0 | **1** |
-| 2 | 0 | **-1** | **-1** | **-1** | 0 | 0 | 0 | 0 |
-| 3 | 0 | 0 | **1** | 0 | 0 | **1** | **1** | 0 |
-| avg | 1/3 | -1/3 | 0 | -1/3 | 1/3 | 1/3 | 1/3 | 1/3 |
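-
-The client-side $r \rightarrow b$ conversion can be sketched as follows (illustrative only):
-
-```java
-// Convert the true step magnitude r into the bit b, following the two-phase rules above:
-// in the fast growth phase the threshold is 2 * rEst, in the contraction phase it is rEst.
-int magnitudeToBit(float r, float rEst, boolean fastGrowth) {
-    float threshold = fastGrowth ? 2 * rEst : rEst;
-    return r < threshold ? 1 : 0;
-}
-```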
With the SignDS scheme, the device-side client uploads to the cloud side only a list of int dimension indices output by the algorithm, a random boolean sign value, and the feedback on the step estimate. Compared with uploading tens of thousands of float-level complete model weights or gradients in the ordinary scenario, this significantly reduces the communication overhead. From the perspective of a practical reconstruction attack, the cloud side only obtains the dimension indices, a sign value representing the direction of the gradient update, and the privacy-protected step estimation feedback, so the attack is much harder to mount. The data flow of the overall scheme is shown in the following figure:

![img](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/federated/docs/source_zh_cn/images/signds_flow.png)

## Privacy Protection Proof

The differential privacy noise scheme achieves privacy protection by adding noise so that the attacker cannot determine the original information, while the differential privacy SignDS scheme activates only part of the dimensions and replaces the original values with the sign value, which largely protects user privacy. Furthermore, the use of the differential privacy exponential mechanism makes it impossible for an attacker to confirm whether the activated dimensions are significant (from the `topk` set), or whether the number of dimensions from `topk` among the output dimensions exceeds a given threshold.

### Dimension Selection Based on the Exponential Mechanism

For any two updates $\Delta$ and $\Delta'$ of a client, let the corresponding `topk` dimension sets be $S_{topk}$ and ${S'}_{topk}$, respectively, and let ${J}\in {\mathcal{J}}$ be any possible output dimension set of the algorithm. Denote by $\nu=|{S}_{topk}\cap {J}|$ and $\nu'=|{S'}_{topk}\cap {J}|$ the sizes of the intersections of ${J}$ with the two `topk` sets. The algorithm guarantees that the following inequality holds:

$$
\frac{{Pr}[{J}|\Delta]}{{Pr}[{J}|\Delta']}=\frac{{Pr}[{J}|{S}_{topk}]}{{Pr}[{J}|{S'}_{topk}]}=\frac{\frac{{exp}(\frac{\epsilon}{\phi_u}\cdot u({S}_{topk},{J}))}{\sum_{{J'}\in {\mathcal{J}}}{exp}(\frac{\epsilon}{\phi_u}\cdot u({S}_{topk},{J'}))}}{\frac{{exp}(\frac{\epsilon}{\phi_u}\cdot u({S'}_{topk},{J}))}{\sum_{{J'}\in {\mathcal{J}}}{exp}(\frac{\epsilon}{\phi_u}\cdot u({S'}_{topk},{J'}))}}=\frac{\frac{{exp}(\epsilon\cdot \unicode{x1D7D9}(\nu \geq \nu_{th}))}{\sum_{\tau=0}^{\tau=\nu_{th}-1}\omega_{\tau} + \sum_{\tau=\nu_{th}}^{\tau=h}\omega_{\tau}\cdot {exp}(\epsilon)}}{\frac{{exp}(\epsilon\cdot \unicode{x1D7D9}(\nu' \geq \nu_{th}))}{\sum_{\tau=0}^{\tau=\nu_{th}-1}\omega_{\tau}+\sum_{\tau=\nu_{th}}^{\tau=h}\omega_{\tau}\cdot {exp}(\epsilon)}}\\
= \frac{{exp}(\epsilon\cdot \unicode{x1D7D9} (\nu \geq \nu_{th}))}{{exp}(\epsilon\cdot \unicode{x1D7D9} (\nu' \geq \nu_{th}))} \leq \frac{{exp}(\epsilon\cdot 1)}{{exp}(\epsilon\cdot 0)} = {exp}(\epsilon),
$$

which proves that the algorithm satisfies local differential privacy.

### Local Differential Privacy - Random Response Mechanism

The participant receives the estimate sent by the server; after local training is completed, it computes the mean of the `topk` dimension weights of the real update and outputs 0 or 1 according to the MagRR strategy. Since this 0 or 1 still carries information about the range of the weight mean, it needs further protection.

The input of the random response mechanism is the data to be protected ($b\in \{0,1\}$) and the privacy parameter $\epsilon$. The mechanism flips the data with a certain probability and outputs $\hat{b} \in \{0,1\}$ according to the following rule:

$$
\hat{b} = \begin{cases}
b & \text{with probability } P \\
1-b & \text{with probability } 1-P
\end{cases}
$$

where $P=\frac{e^\epsilon}{1+e^\epsilon}$.
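As a quick illustration, the mechanism amounts to the following sketch (not the framework's implementation; `binary_randomized_response` is a hypothetical name):

```python
import math
import random

def binary_randomized_response(b, eps):
    """Keep b with probability P = e^eps / (1 + e^eps); flip it otherwise."""
    p_keep = math.exp(eps) / (1 + math.exp(eps))
    return b if random.random() < p_keep else 1 - b
```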
#### Frequency Statistics Based on the Random Response Mechanism

Random responses make it difficult for adversaries to distinguish real data from perturbed data, but they also affect the usability of cloud-side statistical tasks. The server side can approximate the true frequency statistics by de-noising, while it remains hard to infer any individual user's true input in reverse. Let $N$ be the total number of participants in a round, $N^T$ the true total number of 1s, and $N^C$ the total number of 1s collected by the server. Then:

$$
N^T \cdot P+(N-N^T) \cdot (1-P)=N^C \\
N^T=\frac{N^C-N+N \cdot P}{2P-1}
$$
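A matching server-side sketch of this de-biasing step (illustrative only; `estimate_true_ones` is a hypothetical name):

```python
import math

def estimate_true_ones(n_collected, n, eps):
    """Invert the randomized response: N^T = (N^C - N + N*P) / (2P - 1)."""
    p = math.exp(eps) / (1 + math.exp(eps))
    return (n_collected - n + n * p) / (2 * p - 1)
```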
## Preparation

To use this algorithm, you first need to successfully complete the training aggregation process of a cross-device federated scenario. [Implementing an Image Classification Application of Cross-device Federated Learning (x86)](https://www.mindspore.cn/federated/docs/en/master/image_classification_application.html) describes in detail the preparation work, such as datasets, network models, and simulating the start of multi-client federated learning.

## Algorithm Startup Script

Local differential privacy SignDS training currently supports only cross-device federated learning scenarios. To enable it, change the following parameter configuration in the yaml file when starting the cloud-side service. The complete cloud-side startup script can be found in the cloud-side deployment documentation; only the parameter configuration relevant to this algorithm is given here. Taking the LeNet task as an example, the related yaml configuration is as follows:

```yaml
encrypt:
  encrypt_train_type: SIGNDS
  ...
  signds:
    sign_k: 0.2
    sign_eps: 100
    sign_thr_ratio: 0.6
    sign_global_lr: 0.1
    sign_dim_out: 0
```

For a detailed example, refer to [Implementing an Image Classification Application of Cross-device Federated Learning (x86)](https://www.mindspore.cn/federated/docs/en/master/image_classification_application.html). The cloud-side code defines the valid domain of each parameter; if a value is outside its domain, the server reports an error stating the domain. The descriptions of the parameter changes below assume that the remaining four parameters are kept unchanged.

- `sign_k`: (0, 0.25], with k\*inputDim > 50, default=0.01. `inputDim` is the flattened length of the model or of `update`; if the condition is not satisfied, a device-side warning is raised. `update` is sorted, and the `topk` set consists of the first k (as a percentage) of it. Decreasing k means picking from more important dimensions with higher probability; the output then has fewer dimensions, but the dimensions are more important, so the change in convergence cannot be determined in advance. The user needs to observe the sparsity of the model update to choose the value: when the update is quite sparse (has many zeros), k should be smaller.
- `sign_eps`: (0, 100], default=100. The privacy budget, written $\epsilon$ and abbreviated eps. When eps decreases, the probability of picking unimportant dimensions increases and privacy protection is strengthened; the number of output dimensions decreases while their `topk` proportion stays the same, and accuracy decreases.
- `sign_thr_ratio`: [0.5, 1], default=0.6. The lower bound on the proportion of activated dimensions that come from `topk`. Increasing it reduces the number of output dimensions but increases the proportion of output dimensions from `topk`. When the value is increased excessively, more of the output is required to come from `topk`, so the total number of output dimensions can only be reduced to meet the requirement, and accuracy decreases when the number of clients is not large enough.
- `sign_global_lr`: (0, +∞), default=1. This value multiplied by the sign is used in place of `update`, so it directly affects convergence speed and accuracy. Moderately increasing it improves convergence speed, but may make the model oscillate or the gradients explode. If each client runs more local epochs, or the learning rate of local training is increased, this value needs to be increased accordingly. If the number of clients involved in aggregation increases, it also needs to be increased, because reconstruction aggregates the value and then divides it by the number of users; the result stays the same only if the value is increased accordingly. In the new version (r0.2), if the proportion of participants involved in aggregation is less than 5%, the $lr_{global}$ of the MagRR algorithm is directly set to this parameter.
- `sign_dim_out`: [0, 50], default=0. If a non-zero value is given, the client uses it directly; increasing the value outputs more dimensions, but the proportion of dimensions from `topk` decreases. If it is 0, the client computes the optimal output parameter itself. If eps is not large enough and this value is increased, many insignificant `non-topk` dimensions are output, which harms model convergence and reduces accuracy. When eps is large enough, increasing the value allows important dimension information of more users to leave the device and improves accuracy.

## LeNet Experiment Results

We use 100 client datasets of `3500_clients_bin` and 600 federated aggregation iterations; each client runs 20 local epochs, and the learning rate of device-side local training is 0.01. The SignDS parameters are `k=0.2, eps=100, ratio=0.6, lr=4, out=0`, and the Loss and AUC curves are shown in the following figure. In the unencrypted scenario, the length of the data uploaded to the cloud side by a device at the end of training is 266,084, while the length of the data uploaded with SignDS is only 656.

![loss](./images/lenet_signds_loss_auc.png)

## References

[1] Ligeng Zhu, Zhijian Liu, and Song Han. [Deep Leakage from Gradients](http://arxiv.org/pdf/1906.08935.pdf). NeurIPS, 2019.

[2] Xue Jiang, Xuebing Zhou, and Jens Grossklags. "SignDS-FL: Local Differentially-Private Federated Learning with Sign-based Dimension Selection." ACM Transactions on Intelligent Systems and Technology, 2022.
\ No newline at end of file
diff --git a/docs/federated/docs/source_en/object_detection_application_in_cross_silo.md b/docs/federated/docs/source_en/object_detection_application_in_cross_silo.md
deleted file mode 100644
index 6655d58eb152f31d2dfd4cd02f5ef35858c236c9..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/object_detection_application_in_cross_silo.md
+++ /dev/null
@@ -1,264 +0,0 @@
# Implementing a Cross-Silo Federated Object Detection Application (x86)

[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/object_detection_application_in_cross_silo.md)

Based on the type of participating clients, federated learning can be classified into cross-silo federated learning and cross-device federated learning. In the cross-silo scenario, the clients involved in federated learning are different organizations (e.g., healthcare or finance) or geographically distributed data centers, i.e., the model is trained on multiple data silos. In the cross-device scenario, the participating clients are a large number of mobile or IoT devices. This tutorial describes how to implement an object detection application with the Faster R-CNN network on the MindSpore Federated cross-silo federated framework.

The full script to launch the cross-silo federated object detection application can be found [here](https://gitee.com/mindspore/federated/tree/master/example/cross_silo_faster_rcnn).

## Preparation

This tutorial deploys the cross-silo federated object detection task based on the faster_rcnn network provided in MindSpore model_zoo. Please first follow the official [faster_rcnn tutorial and code](https://gitee.com/mindspore/models/tree/master/official/cv/FasterRCNN) to understand the COCO dataset, the faster_rcnn network structure, and the training and evaluation processes. Since the COCO dataset is open source, please download it yourself following its [official website](https://cocodataset.org/#home) guidelines and slice it (for example, with 100 clients, the dataset can be sliced into 100 copies, each representing the data held by one client).

Since the original COCO annotations are in json format and the object detection script provided by the cross-silo federated learning framework only supports input data in MindRecord format, convert the json files to MindRecord files according to the following steps.

- Configure the following parameters in the configuration file [default_config.yaml](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_faster_rcnn/default_config.yaml):

    - `mindrecord_dir`

        Used to set the save path of the generated MindRecord files. The folder name must be in mindrecord_{num} format, where the number num represents the client label number 0, 1, 2, 3, ......

        ```sh
        mindrecord_dir: "./datasets/coco_split/split_100/mindrecord_0"
        ```

    - `instance_set`

        Used to set the path of the original json file.

        ```sh
        instance_set: "./datasets/coco_split/split_100/train_0.json"
        ```

- Run the script [generate_mindrecord.py](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_faster_rcnn/generate_mindrecord.py) to generate the MindRecord file from `train_0.json`; it is saved in the `mindrecord_dir` path.
## Starting the Cross-Silo Federated Mission

### Installing MindSpore and MindSpore Federated

Both installing from source and installing a release version are supported, on CPU, GPU and Ascend hardware platforms; choose the installation that matches your hardware platform. For the installation steps, refer to [MindSpore installation](https://www.mindspore.cn/install) and [MindSpore Federated installation](https://www.mindspore.cn/federated/docs/en/master/index.html).

Currently, the federated learning framework only supports deployment in Linux environments, and the cross-silo federated learning framework requires MindSpore version >= 1.5.0.

### Starting the Mission

Refer to the [example](https://gitee.com/mindspore/federated/tree/master/example/cross_silo_faster_rcnn) to start the cluster. The directory structure of the reference example is as follows:

```text
cross_silo_faster_rcnn
├── src
│   ├── FasterRcnn
│   │   ├── __init__.py                  // init file
│   │   ├── anchor_generator.py          // Anchor generator
│   │   ├── bbox_assign_sample.py        // Phase I sampler
│   │   ├── bbox_assign_sample_stage2.py // Phase II sampler
│   │   ├── faster_rcnn_resnet.py        // Faster R-CNN network
│   │   ├── faster_rcnn_resnet50v1.py    // Faster R-CNN network with ResNet50v1.0 as backbone
│   │   ├── fpn_neck.py                  // Feature pyramid network
│   │   ├── proposal_generator.py        // Proposal generator
│   │   ├── rcnn.py                      // R-CNN network
│   │   ├── resnet.py                    // Backbone network
│   │   ├── resnet50v1.py                // ResNet50v1.0 backbone network
│   │   ├── roi_align.py                 // ROI align network
│   │   └── rpn.py                       // Region proposal network
│   ├── dataset.py                       // Create and process datasets
│   ├── lr_schedule.py                   // Learning rate generator
│   ├── network_define.py                // Faster R-CNN network definition
│   ├── util.py                          // Utility operations
│   └── model_utils
│       ├── __init__.py                  // init file
│       ├── config.py                    // Obtain .yaml configuration parameters
│       ├── device_adapter.py            // Obtain on-cloud id
│       ├── local_adapter.py             // Obtain local id
│       └── moxing_adapter.py            // On-cloud data preparation
├── requirements.txt
├── mindspore_hub_conf.py
├── generate_mindrecord.py               // Convert .json annotation files to MindRecord format for reading datasets
├── default_yaml_config.yaml             // Configuration file required for federated training
├── default_config.yaml                  // Configuration file for network structure, dataset path, and fl_plan
├── run_cross_silo_fasterrcnn_worker.py  // Start cloud federated worker script
├── run_cross_silo_fasterrcnn_worker_distribute.py // Start cloud federated distributed worker training script
├── test_fl_fasterrcnn.py                // Training script used by the client
├── run_cross_silo_fasterrcnn_sched.py   // Start cloud federated scheduler script
└── run_cross_silo_fasterrcnn_server.py  // Start cloud federated server script
```

1. Note that you can choose whether to record the loss value of every step by setting the parameter `dataset_sink_mode` in the `test_fl_fasterrcnn.py` file.

    ```python
    model.train(config.client_epoch_num, dataset, callbacks=cb, dataset_sink_mode=True)  # dataset_sink_mode=True records only the loss value of the last step in each epoch
    model.train(config.client_epoch_num, dataset, callbacks=cb, dataset_sink_mode=False) # dataset_sink_mode=False records the loss value of every step; this is the default mode in the code
    ```
2. Set the following parameters in the configuration file [default_config.yaml](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_faster_rcnn/default_config.yaml):

    - `pre_trained`

        Used to set the pre-trained model path (.ckpt format).

        The pre-trained model used in this tutorial is a ResNet-50 checkpoint trained on ImageNet 2012. You can train it with the [resnet50](https://gitee.com/mindspore/models/tree/master/official/cv/ResNet) script in ModelZoo, and then use src/convert_checkpoint.py to convert the trained resnet50 weight file into a loadable weight file.

3. Start redis

    ```sh
    redis-server --port 2345 --save ""
    ```

4. Start Scheduler

    `run_cross_silo_fasterrcnn_sched.py` is the Python script used to start the `Scheduler`, and it supports modifying the configuration via `argparse` arguments. Executing the following command starts the `Scheduler` of this federated learning task; `--yaml_config` sets the yaml file path, and the management address is `127.0.0.1:18019`.

    ```sh
    python run_cross_silo_fasterrcnn_sched.py --yaml_config="default_yaml_config.yaml" --scheduler_manage_address="127.0.0.1:18019"
    ```

    For the detailed implementation, see [run_cross_silo_fasterrcnn_sched.py](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_faster_rcnn/run_cross_silo_fasterrcnn_sched.py).

    The following log indicates a successful start:

    ```sh
    [INFO] FEDERATED(3944,2b280497ed00,python):2022-10-10-17:11:08.154.878 [mindspore_federated/fl_arch/ccsrc/scheduler/scheduler.cc:35] Run] Scheduler started successfully.
    [INFO] FEDERATED(3944,2b28c5ada700,python):2022-10-10-17:11:08.155.056 [mindspore_federated/fl_arch/ccsrc/common/communicator/http_request_handler.cc:90] Run] Start http server!
    ```

5. Start Server

    `run_cross_silo_fasterrcnn_server.py` is a Python script for starting a number of `Server`s, and it supports modifying the configuration via `argparse` arguments. Executing the following command starts the `Server`s of this federated learning task with TCP address `127.0.0.1`; the starting port of the federated learning HTTP service is `6668` and the number of `Server`s is `4`.

    ```sh
    python run_cross_silo_fasterrcnn_server.py --yaml_config="default_yaml_config.yaml" --tcp_server_ip="127.0.0.1" --checkpoint_dir="/path/to/fl_ckpt" --local_server_num=4 --http_server_address="127.0.0.1:6668"
    ```

    The above command is equivalent to starting four `Server` processes with federated learning service ports `6668`, `6669`, `6670` and `6671`; see [run_cross_silo_fasterrcnn_server.py](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_faster_rcnn/run_cross_silo_fasterrcnn_server.py) for details. `checkpoint_dir` must point to the directory where the checkpoints are located; the server reads the initialization weights from checkpoints in this path, and the checkpoint prefix must be in the format `{fl_name}_recovery_iteration_`.

    The following log indicates a successful start:

    ```sh
    [INFO] FEDERATED(3944,2b280497ed00,python):2022-10-10-17:11:08.154.645 [mindspore_federated/fl_arch/ccsrc/common/communicator/http_server.cc:122] Start] Start http server!
    [INFO] FEDERATED(3944,2b280497ed00,python):2022-10-10-17:11:08.154.725 [mindspore_federated/fl_arch/ccsrc/common/communicator/http_request_handler.cc:85] Initialize] Ev http register handle of: [/disableFLS, /enableFLS, /state, /queryInstance, /newInstance] success.
    [INFO] FEDERATED(3944,2b280497ed00,python):2022-10-10-17:11:08.154.878 [mindspore_federated/fl_arch/ccsrc/scheduler/scheduler.cc:35] Run] Scheduler started successfully.
    [INFO] FEDERATED(3944,2b28c5ada700,python):2022-10-10-17:11:08.155.056 [mindspore_federated/fl_arch/ccsrc/common/communicator/http_request_handler.cc:90] Run] Start http server!
    ```
6. Start Worker

    `run_cross_silo_fasterrcnn_worker.py` is a Python script for starting a number of `worker`s, and it supports modifying the configuration via `argparse` arguments. Executing the following command starts the `worker`s of this federated learning task; at least `2` `worker`s are needed for the task to proceed properly.

    ```sh
    python run_cross_silo_fasterrcnn_worker.py --local_worker_num=2 --yaml_config="default_yaml_config.yaml" --pre_trained="/path/to/pre_trained" --dataset_path=/path/to/datasets/coco_split/split_100 --http_server_address=127.0.0.1:6668
    ```

    For the detailed implementation, see [run_cross_silo_fasterrcnn_worker.py](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_faster_rcnn/run_cross_silo_fasterrcnn_worker.py). Note that in dataset sink mode, the synchronization frequency of cloud federated learning is measured in epochs; otherwise it is measured in steps.

    In the above command, `--local_worker_num=2` means starting two clients, whose datasets are `datasets/coco_split/split_100/mindrecord_0` and `datasets/coco_split/split_100/mindrecord_1`. Please prepare the required datasets for the corresponding clients according to the preparation section above.

    After executing the above three commands and waiting for a while, go to the `worker_0` folder in the current directory and check the `worker_0` log with the command `grep -rn "epoch:" *`; you will see a log message similar to the following:

    ```sh
    epoch: 1 step: 1 total_loss: 0.6060338
    ```

    This means that cross-silo federated learning has started successfully and `worker_0` is training. Other workers can be checked in the same way.

    At present, the `worker` node of cloud federated learning supports single-machine multi-card and multi-machine multi-card distributed training. `run_cross_silo_fasterrcnn_worker_distributed.py` is a Python script for starting distributed training of the worker node, and it supports modifying the configuration via `argparse` arguments. Executing the following command starts the distributed `worker` of this federated learning task, where `device_num` is the number of processes started by the `worker` cluster and `run_distribute` enables the cluster's distributed training; the HTTP port is `6668` and the number of `worker` processes is `4`:

    ```sh
    python run_cross_silo_fasterrcnn_worker_distributed.py --device_num=4 --run_distribute=True --dataset_path=/path/to/datasets/coco_split/split_100 --http_server_address=127.0.0.1:6668
    ```

    Enter the `worker_distributed/log_output/` folder in the current directory and run the `grep -rn "epoch" *` command to view the logs of the distributed `worker` cluster.
    You can see the following information:

    ```sh
    epoch: 1 step: 1 total_loss: 0.613467
    ```

Please refer to the [yaml configuration notes](https://www.mindspore.cn/federated/docs/en/master/horizontal/federated_server_yaml.html) for the description of the parameter configuration in the above scripts.

### Viewing the Log

After the task is started successfully, the corresponding log files are generated under the current directory `cross_silo_faster_rcnn`. The log directory structure is as follows:

```text
cross_silo_faster_rcnn
├── scheduler
│   └── scheduler.log     # Logs printed while running scheduler
├── server_0
│   └── server.log        # Logs printed while running server_0
├── server_1
│   └── server.log        # Logs printed while running server_1
├── server_2
│   └── server.log        # Logs printed while running server_2
├── server_3
│   └── server.log        # Logs printed while running server_3
├── worker_0
│   ├── ckpt              # Stores the aggregated model checkpoints obtained by worker_0 at the end of each federated learning iteration
│   │   └── mindrecord_0
│   │       ├── mindrecord_0-fast-rcnn-0epoch.ckpt
│   │       ├── mindrecord_0-fast-rcnn-1epoch.ckpt
│   │       │
│   │       │   ......
│   │       │
│   │       └── mindrecord_0-fast-rcnn-29epoch.ckpt
│   ├── loss_0.log        # Records the loss value of each step during the training of worker_0
│   └── worker.log        # Records the output logs of worker_0 during the federated learning task
└── worker_1
    ├── ckpt              # Stores the aggregated model checkpoints obtained by worker_1 at the end of each federated learning iteration
    │   └── mindrecord_1
    │       ├── mindrecord_1-fast-rcnn-0epoch.ckpt
    │       ├── mindrecord_1-fast-rcnn-1epoch.ckpt
    │       │
    │       │   ......
    │       │
    │       └── mindrecord_1-fast-rcnn-29epoch.ckpt
    ├── loss_0.log        # Records the loss value of each step during the training of worker_1
    └── worker.log        # Records the output logs of worker_1 during the federated learning task
```

### Closing the Mission

If you want to exit midway, the following command is available:

```sh
python finish_cross_silo_fasterrcnn.py --redis_port=2345
```

For the detailed implementation, see [finish_cloud.py](https://gitee.com/mindspore/federated/blob/master/tests/st/cross_device_cloud/finish_cloud.py).

Alternatively, when the training task is finished, the cluster exits automatically and does not need to be closed manually.

### Results

- Data used:

    The COCO dataset is split into 100 copies, and the first two copies are used as the datasets of the two workers.

- Number of client-side local training epochs: 1

- Total number of cross-silo federated learning iterations: 30

- Results (the loss values recorded during client-side local training):

    Go to the `worker_0` folder in the current directory and check the `worker_0` log with the command `grep -rn "epoch:" *` to see the loss values output in each step:

    ```sh
    epoch: 1 step: 1 total_loss: 5.249325
    epoch: 1 step: 2 total_loss: 4.0856013
    epoch: 1 step: 3 total_loss: 2.6916502
    epoch: 1 step: 4 total_loss: 1.3917351
    epoch: 1 step: 5 total_loss: 0.8109232
    epoch: 1 step: 6 total_loss: 0.99101084
    epoch: 1 step: 7 total_loss: 1.7741735
    epoch: 1 step: 8 total_loss: 0.9517553
    epoch: 1 step: 9 total_loss: 1.7988946
    epoch: 1 step: 10 total_loss: 1.0213892
    epoch: 1 step: 11 total_loss: 1.1700443
    .
    .
    .
    ```
The histograms of the per-step training loss of worker_1 and worker_2 during the 30 training iterations are shown as [1] and [2] below.

The line plots of the average loss (the sum of the losses of all steps in an epoch divided by the number of steps) of worker_1 and worker_2 during the 30 training iterations are shown as [3] and [4] below.

![cross-silo_fastrcnn-2workers-loss.png](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/federated/docs/source_zh_cn/images/cross-silo_fastrcnn-2workers-loss.png)

diff --git a/docs/federated/docs/source_en/pairwise_encryption_training.md b/docs/federated/docs/source_en/pairwise_encryption_training.md
deleted file mode 100644
index 830d531997708f5d1003f57b9c6e6413a1acc1f6..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/pairwise_encryption_training.md
+++ /dev/null
@@ -1,67 +0,0 @@
# Horizontal FL - Pairwise Encryption Training

[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/pairwise_encryption_training.md)

During federated learning, user data is used only for local device training and does not need to be uploaded to the central server, which prevents direct leakage of personal data. However, in the conventional federated learning framework, models are uploaded to the cloud in plaintext, and there is still a risk of indirect disclosure of user privacy: after obtaining the plaintext model uploaded by a user, an attacker can restore the user's personal training data through attacks such as reconstruction and model inversion, disclosing user privacy.

As a federated learning framework, MindSpore Federated provides secure aggregation algorithms based on local secure multi-party computation (MPC). Secret masks are added to local models before they are uploaded to the cloud. While preserving model availability, this solves the problems of privacy leakage and model theft in horizontal federated learning.

## Principles

Although the LDP technique can protect user data privacy reasonably well, when there are relatively few participating clients or the Gaussian noise amplitude is relatively large, the model accuracy is greatly affected. To meet both model protection and model convergence requirements, we provide the MPC-based secure aggregation solution.

In this training mode, assume the participating client set is $U$. Any pair of clients $u$ and $v$ negotiate a pair of random perturbations $p_{uv}$ and $p_{vu}$ that satisfy the following condition:

$$
p_{uv}=\begin{cases}
-p_{vu}, &u{\neq}v\\\\ 0, &u=v \end{cases}
$$

Each client $u$ then adds the perturbations negotiated with the other users to its original model weight $x_u$ before uploading the model to the server:

$$
x_{encrypt}=x_u+\sum\limits_{v{\in}U}p_{uv}
$$

The server aggregation result $\overline{x}$ is therefore:

$$
\begin{align}
\overline{x}&=\sum\limits_{u{\in}U}(x_{u}+\sum\limits_{v{\in}U}p_{uv})\\\\
&=\sum\limits_{u{\in}U}x_{u}+\sum\limits_{u{\in}U}\sum\limits_{v{\in}U}p_{uv}\\\\
&=\sum\limits_{u{\in}U}x_{u}
\end{align}
$$

The preceding process describes only the main idea of the aggregation algorithm. The MPC-based aggregation solution is accuracy-lossless but increases the number of communication rounds. If you are interested in the specific steps of the algorithm, refer to the paper [1].
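A toy numeric check of why the masks cancel (a sketch only: real deployments derive the pairwise perturbations from key agreement between clients, not from a shared random generator):

```python
import numpy as np

clients = [0, 1, 2]
rng = np.random.default_rng(42)
x = {u: rng.normal(size=4) for u in clients}        # each client's model weights

# Antisymmetric pairwise masks: p[u][v] = -p[v][u], p[u][u] = 0
p = {u: {v: np.zeros(4) for v in clients} for u in clients}
for u in clients:
    for v in clients:
        if u < v:
            mask = rng.normal(size=4)
            p[u][v], p[v][u] = mask, -mask

# Each client uploads its masked weights
uploads = [x[u] + sum(p[u][v] for v in clients) for u in clients]

# The server-side sum equals the sum of the raw weights: the masks cancel
assert np.allclose(sum(uploads), sum(x.values()))
```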
## Usage

### Cross-Device Scenario

Enabling pairwise encryption training is simple: just set the `encrypt_train_type` field to `PW_ENCRYPT` in the yaml file when starting the cloud-side service.

In addition, most of the workers participating in training in this scenario are unstable edge computing nodes such as mobile phones, so dropouts and secret key reconstruction must be considered. The related parameters are `share_secrets_ratio`, `reconstruct_secrets_threshold`, and `cipher_time_window`.

`share_secrets_ratio` indicates the ratio by which the client threshold decreases across the public key broadcast round, the secret sharing round, and the secret reconstruction round. The value must be less than or equal to 1.

`reconstruct_secrets_threshold` indicates the number of secret shares required to reconstruct a secret. The value must be less than the number of clients that participate in updateModel (start_fl_job_threshold * update_model_ratio).

To ensure system security, when the server and clients do not collude, the value of `reconstruct_secrets_threshold` must be greater than half the number of federated learning clients; when the server and clients may collude, it must be greater than two thirds of the number of federated learning clients.

`cipher_time_window` indicates the duration limit of each communication round of secure aggregation. It ensures that the server can start a new iteration when some clients are offline.

### Cross-Silo Scenario

In the cross-silo scenario, you likewise only need to set the `encrypt_train_type` field to `PW_ENCRYPT` in the yaml file of the cloud-side startup script.

Different from the cross-device scenario, all workers in the cross-silo scenario are stable computing nodes, so you only need to set the parameter `cipher_time_window`.

## References

[1] Keith Bonawitz, Vladimir Ivanov, Ben Kreuter, et al. [Practical Secure Aggregation for Privacy-Preserving Machine Learning](https://dl.acm.org/doi/pdf/10.1145/3133956.3133982). Proceedings of the 2017 ACM SIGSAC Conference on Computer and Communications Security. 2017.

diff --git a/docs/federated/docs/source_en/private_set_intersection.md b/docs/federated/docs/source_en/private_set_intersection.md
deleted file mode 100644
index fcb7bef84f70fb76fdc7320dc30f256e7a516daa..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/private_set_intersection.md
+++ /dev/null
@@ -1,129 +0,0 @@
# Vertical Federated - Private Set Intersection

[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/private_set_intersection.md)

## Privacy Protection Background

With the rising demand for digital transformation and the circulation of data elements, as well as the implementation of the Data Security Law, the Personal Information Protection Law and the EU General Data Protection Regulation (GDPR), data privacy is increasingly a necessary requirement in many scenarios. For example, when a dataset contains sensitive user information (medical diagnosis information, transaction records, identification codes, the device unique identifier OAID, etc.)
or confidential company information, cryptography or desensitization must be used to ensure data confidentiality before the data is used openly, achieving the goal that data is "usable but invisible" and preventing information leakage. Consider two participants who jointly train a machine learning model (e.g., vertical federated learning) using their respective data: the first step of this task is to align the sample sets of the two parties, a process known as entity resolution. Traditional plaintext intersection inevitably reveals the OAID of the entire database and damages the data privacy of both parties, so the Private Set Intersection (PSI) technique is needed to accomplish this task.

PSI is a type of secure multi-party computation (MPC) protocol that takes the data collections of two parties as input and, after a series of hashing, encryption and data exchange steps, outputs the intersection of the collections to an agreed output party, while ensuring that the parties cannot obtain any information about the data outside the intersection. Using the PSI protocol in vertical federated learning tasks complies with the GDPR requirement of data minimisation: no data is exposed unnecessarily beyond the part needed by the training process (the intersection). From the data controller's perspective, the service has to share data appropriately, but wants to share only the data necessary for the service and not expose additional data. It should be noted that although PSI can directly apply existing generic MPC protocols to its computation, this often results in large computation and communication overhead, which is not conducive to business use. This document introduces a technique that combines a Bloom filter with cancellable inverse scalar multiplication on an elliptic curve to implement an inverse ECDH-PSI (Elliptic Curve Diffie-Hellman PSI), to better support cloud services for privacy-preserving set intersection computation.

## Algorithm Process Introduction

The core idea of ECDH-PSI is that a piece of data encrypted first by Alice and then by Bob gives the same result as encrypting in the opposite order. Each party sends its data encrypted with its own private key without revealing its privacy, and the other party re-encrypts the received ciphertext with its own private key. If the doubly encrypted results are equal, the original data are the same.

The core optimization of the inverse ECDH-PSI is to minimize the encryption computation over the larger set when the two parties' data volumes are unbalanced (Bob is the party with less data; $a$ and $b$ are the private keys of Alice and Bob respectively; the original data of the two parties are mapped onto the elliptic curve as $P_1$ and $P_2$ respectively; point multiplication of the curve with private key $k$ is written $P^k$ or $kP$; and $k^{-1}$ is the inverse of $k$). After Alice computes $P_1^a$ and sends it to Bob, Bob does not perform encryption on top of it; instead, Bob sends $P_2^b$ to Alice. After Alice returns $P_2^{ba}$, Bob cancels his own key by point-multiplying with the inverse of his private key, i.e., computes $P_2^{bab^{-1}}=P_2^a$, and compares it with the $P_1^a$ sent by Alice. If the encryption results are equal, then $P_1=P_2$. The flowchart of the inverse ECDH-PSI is shown in the figure below, where red letters indicate data received from the other party.

![inverse_ecdh_psi_flow](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/federated/docs/source_zh_cn/images/inverse_ecdh_psi_flow.png)
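The commutativity the protocol relies on can be demonstrated with a toy stand-in for elliptic-curve point multiplication, namely modular exponentiation in a small prime group (a sketch only, with insecure toy parameters; it is not the MindSpore Federated implementation):

```python
# Toy inverse ECDH-PSI round: modular exponentiation stands in for
# elliptic-curve point multiplication. Insecure toy parameters.
p = 2_147_483_647                      # Mersenne prime 2^31 - 1
def h(x):                              # toy "hash onto a group element"
    return pow(7, hash(x) % (p - 1), p)

a, b = 123_457, 987_643                # Alice's and Bob's private keys
b_inv = pow(b, -1, p - 1)              # inverse of b modulo the group order

alice_item, bob_item = "oaid-42", "oaid-42"
P1, P2 = h(alice_item), h(bob_item)

msg_to_bob = pow(P1, a, p)             # Alice sends P1^a
msg_to_alice = pow(P2, b, p)           # Bob sends P2^b
doubly = pow(msg_to_alice, a, p)       # Alice returns P2^{ba}
unmasked = pow(doubly, b_inv, p)       # Bob computes P2^{bab^{-1}} = P2^a

assert unmasked == msg_to_bob          # equal ciphertexts => P1 == P2
```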
The $bf$ in the figure stands for Bloom filter (bf). To query whether an element exists in a collection, the basic method is to iterate through the collection, or to sort the collection and use binary search; but when the amount of data is large, sorting does not parallelize and is very time-consuming. With a Bloom filter, the elements of the set are mapped by several hash functions to a number of bits in an initially all-0 bit string, and all elements of the set share a single bit string. To query, simply process the data to be queried with the same hash functions and check whether all the corresponding bits are set to 1: if all of them are 1, the data exists (up to a small false-positive probability); otherwise it does not exist. The collision probability can be controlled by the number of hash functions. Compared with sending the entire set, sending the single bit string output by the Bloom filter has lower communication overhead, and both building the filter and using it for large-scale data queries can be accelerated in parallel.
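A minimal Bloom filter sketch illustrating the query logic described above (illustrative only; class and parameter names are hypothetical):

```python
import hashlib

class BloomFilter:
    """k hash functions map each element to k bit positions in one bit string."""
    def __init__(self, m_bits=1024, k_hashes=3):
        self.m, self.k = m_bits, k_hashes
        self.bits = bytearray(m_bits)

    def _positions(self, item):
        for i in range(self.k):
            digest = hashlib.sha256(f"{i}:{item}".encode()).digest()
            yield int.from_bytes(digest[:8], "big") % self.m

    def add(self, item):
        for pos in self._positions(item):
            self.bits[pos] = 1

    def might_contain(self, item):
        # True may be a false positive; False is always correct
        return all(self.bits[pos] for pos in self._positions(item))

bf = BloomFilter()
for x in ["51", "52", "53"]:
    bf.add(x)
assert bf.might_contain("52")          # present elements always hit
print(bf.might_contain("9999"))        # almost surely False here
```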
## Quick Experience

### Front-end Needs

Finish installing the `mindspore-federated` library in the Python environment.

### Starting the Script

You can get the PSI startup script for both parties from [MindSpore federated ST](https://gitee.com/mindspore/federated/blob/master/tests/st/psi/run_psi.py) and open two processes to simulate the two parties. The startup commands for local data and local communication are:

```sh
python run_psi.py --comm_role="server" --http_server_address="127.0.0.1:8004" --remote_server_address="127.0.0.1:8005" --input_begin=1 --input_end=100

python run_psi.py --comm_role="client" --http_server_address="127.0.0.1:8005" --remote_server_address="127.0.0.1:8004" --input_begin=50 --input_end=150
```

- `input_begin` is used together with `input_end` to generate the dataset for the intersection.
- `peer_input_begin` and `peer_input_end` indicate the start and end of the other party's data range. With `--need_check` set to `True`, the true result is computed with Python's set1.intersection(set2) function and used to check the correctness of the PSI result.
- `--bucket_size` (optional) indicates the number of for-loop rounds that serially perform multiple bucket intersections.
- `--thread_num` (optional) indicates the number of parallel threads used for the computation.
- To run plaintext intersection, add the parameter `--plain_intersection=True` to the commands.

At present, PSI supports intersecting large data sets with hundreds of millions of records. You can specify the size of the input data sets with the `input_begin`, `input_end`, `peer_input_begin` and `peer_input_end` parameters. In theory, as long as the memory resources of the machine and the system are sufficient, there is no upper limit on the amount of data PSI can compute on. The startup commands are as follows:

```sh
python run_psi.py --comm_role="server" --http_server_address="127.0.0.1:8004" --remote_server_address="127.0.0.1:8005" --input_begin=1 --input_end=100000000

python run_psi.py --comm_role="client" --http_server_address="127.0.0.1:8005" --remote_server_address="127.0.0.1:8004" --input_begin=1 --input_end=100000000
```

### Output Results

Before running the script, you can set the environment variable `export GLOG_v=1` to display `INFO`-level logs and observe the operation of each phase within the protocol. After the script runs, the intersection results are printed. Since the amount of intersection data may be large, the output is limited to the first 20 intersection results:

```bash
PSI result: ['50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69'] (display limit: 20)
```

## Deep Experience

### Import Module

Running private set intersection depends on the communication module and the intersection module of the federated library, which are imported as follows:

```python
from mindspore_federated.startup.vertical_federated_local import VerticalFederatedCommunicator, ServerConfig
from mindspore_federated._mindspore_federated import RunPSI
from mindspore_federated._mindspore_federated import PlainIntersection
```

### Data Preparation

Both `RunPSI` and `PlainIntersection` require input data in `list[string]` format. A method that generates a dataset either by reading a file or with a for loop is given here:

```python
def generate_input_data(input_begin_, input_end_, read_file_, file_name_):
    input_data_ = []
    if read_file_:
        with open(file_name_, 'r') as f:
            for line in f.readlines():
                input_data_.append(line.strip())
    else:
        input_data_ = [str(i) for i in range(input_begin_, input_end_)]
    return input_data_
```

The input parameters `input_begin_` and `input_end_` limit the data range of the for loop. `read_file_` and `file_name_` indicate whether to read from a file and the path of the file. The file can be constructed by yourself, with each line representing one piece of data.

### Constructing Communication

Before calling this interface, a vertical federated communication instance needs to be initialized, as follows:

```python
http_server_config = ServerConfig(server_name=comm_role, server_address=http_server_address)
remote_server_config = ServerConfig(server_name=peer_comm_role, server_address=remote_server_address)
vertical_communicator = VerticalFederatedCommunicator(http_server_config=http_server_config,
                                                      remote_server_config=remote_server_config)
vertical_communicator.launch()
```

- `server_name` is determined by whether the process acts as `server` or `client`: `comm_role` is assigned the corresponding `server` or `client`, and `peer_comm_role` indicates the role of the other party.
- The format of `server_address` is "IP:port". `http_server_address` is assigned the `IP` and `port` of this process, such as "127.0.0.1:8004", and `remote_server_address` is assigned the `IP` and `port` of the other party.

### Starting Intersection

The external interfaces for secure set intersection are `RunPSI` and `PlainIntersection`, for ciphertext and plaintext intersection respectively, with the same types and meanings of inputs and returned results.
Only the ciphertext intersection `RunPSI` is described here:

```python
result = RunPSI(input_data, comm_role, peer_comm_role, bucket_id, thread_num)
```

- `input_data`: (list[string]) the input data of one PSI party.
- `comm_role`: (string) communication-related parameter, "server" or "client".
- `peer_comm_role`: (string) communication-related parameter, "server" or "client"; must be different from `comm_role`.
- `bucket_id`: (int) the index of the bucket passed in when bucketing is performed externally. Passing a negative number, a decimal or another type raises a `TypeError`. If the value differs between the two processes, the server exits with an error and the client blocks and waits.
- `thread_num`: (int) the number of threads, a natural number. 0 is the default value and means the maximum number of threads available on the machine minus 5; other values are limited to between 1 and the machine's maximum. Passing a negative number, a decimal or another type raises a `TypeError`.

### Output Results

`result` is in `list[string]` format and represents the intersection result, which you can print yourself. The following method computes the reference intersection with Python sets:

```python
def compute_right_result(self_input, peer_input):
    self_input_set = set(self_input)
    peer_input_set = set(peer_input)
    return self_input_set.intersection(peer_input_set)
```

The result of the above method can be compared with `result` to check consistency and verify the correctness of the interface.

diff --git a/docs/federated/docs/source_en/secure_vertical_federated_learning_with_DP.md b/docs/federated/docs/source_en/secure_vertical_federated_learning_with_DP.md
deleted file mode 100644
index 6a302346788f1c971ffdee8805caf30c79e2ee07..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/secure_vertical_federated_learning_with_DP.md
+++ /dev/null
@@ -1,158 +0,0 @@
# Vertical Federated - Label Protection Based on Differential Privacy

[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/secure_vertical_federated_learning_with_DP.md)

## Background

Vertical federated learning (vFL) is a major branch of federated learning (FL). When different participants have data of the same batch of users but with different attributes, they can use vFL for collaborative training. In vFL, the participants holding user features (followers for short, participant A in the figure below) hold a bottom network (Bottom Model). They feed the features into the bottom network, compute the intermediate results (embeddings), and send them to the participant holding the labels (leader for short, participant B in the figure below). The leader uses these embeddings and its own labels to train the upper network, and then passes the computed gradients back to each participant to train the bottom networks. As can be seen, vFL allows the participants to train the model collaboratively without uploading their raw data.

![image.png](./images/vfl_1_en.png)

The vFL framework avoids the direct upload of raw data, protecting data privacy to a certain extent. However, a semi-honest or malicious party may still infer the label information from the gradients passed back by the leader, causing privacy disclosure.
Given the large number of vFL scenarios in which the labels are the most valuable and most important piece of information to protect, in this context we need to provide stronger privacy guarantees for vFL training to avoid such privacy disclosure.

Differential privacy (DP) is a definition of privacy based strictly on statistics and information theory, and is currently the gold standard of privacy-preserving data analysis. The core idea of DP is to introduce enough randomness to overwhelm each individual data point's influence on the algorithm's result, making it hard to invert the algorithm's results back to the individual data. The protection of DP holds under an extreme threat model, even when:

- the adversary knows all the details of the DP algorithm
- the adversary has infinite computing power
- the adversary has arbitrary auxiliary information about the raw data

For the background, theory and implementation of DP, refer to the excellent survey [1].

Our scheme is based on label differential privacy (label dp) [2], which provides a differential privacy guarantee for the labels of the leader participant during vertical federated learning training, so that an attacker cannot invert the label information of the data from the returned gradients. Under this protection, even if the follower is semi-honest or malicious, the leader's label information remains protected, mitigating vFL participants' concerns about data privacy risks.

## Algorithm Implementation

MindSpore Federated adopts a lightweight implementation of label dp: during training, before the label data of the leader participant is used, a certain proportion of the labels are randomly flipped. Because of the introduced randomness, an attacker who wants to invert the labels can at most recover the randomly flipped or perturbed labels, which increases the difficulty of inverting the original labels and satisfies the differential privacy guarantee. In practical applications, the privacy parameter `eps` (which can be interpreted as the proportion of randomly flipped labels) can be adjusted to meet the needs of different scenarios:

- a smaller `eps` (< 1.0) corresponds to high privacy and lower performance
- a larger `eps` (> 5.0) corresponds to high performance and lower privacy

![image.png](./images/label_dp_en.png)

The implementation of this scheme is divided into the binary case and the one-hot case. Whether the input labels are binary or one-hot is recognized automatically, and labels of the same type are output. The detailed algorithm is given below, followed by a short sketch:

### Binary Labels Protection

1. Compute the flip probability $p = \frac{1}{1 + e^{eps}}$ from the preset privacy parameter eps.
2. Flip each label with probability $p$.

### Onehot Labels Protection

1. For n label classes, compute $p_1 = \frac{e^{eps}}{n - 1 + e^{eps}}$ and $p_2 = \frac{1}{n - 1 + e^{eps}}$.
2. Randomly scramble the labels according to the following probabilities: keep the current label unchanged with probability $p_1$, and change it to any one of the other n - 1 classes with probability $p_2$.
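A minimal NumPy sketch of these two cases (an illustration of the algorithm above, with hypothetical function names; the actual `LabelDP` class provided by MindSpore Federated is shown later on this page):

```python
import numpy as np

def label_dp_binary(labels, eps, rng):
    """Flip each binary label with probability p = 1 / (1 + e^eps)."""
    p = 1.0 / (1.0 + np.exp(eps))
    flip = rng.random(labels.shape) < p
    return np.where(flip, 1 - labels, labels)

def label_dp_onehot(labels, eps, rng):
    """Keep a one-hot label with probability p1; otherwise move it uniformly
    to one of the other n - 1 classes (each with probability p2)."""
    n = labels.shape[1]
    p1 = np.exp(eps) / (n - 1 + np.exp(eps))
    out = labels.copy()
    for row in out:
        cur = int(np.argmax(row))
        if rng.random() >= p1:                  # leave the class with prob 1 - p1
            others = [c for c in range(n) if c != cur]
            row[:] = 0
            row[rng.choice(others)] = 1
    return out

rng = np.random.default_rng(0)
print(label_dp_binary(np.array([0, 1, 1, 0, 1]), eps=1.0, rng=rng))
```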
## Quick Experience

We use the local case in the [Wide&Deep Vertical Federated Learning Case](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo) as an example of how to add label dp protection to a vertical federated model.

### Front-End Needs

1. Install MindSpore 1.8.1 or a later version; refer to the [MindSpore official website installation guide](https://www.mindspore.cn/install).
2. Install MindSpore Federated and the Python libraries it depends on.

    ```shell
    cd federated
    python -m pip install -r requirements_test.txt
    ```

3. Prepare the criteo dataset; refer to the [Wide&Deep Vertical Federated Learning Case](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo).

### Starting the Script

1. Download federated

    ```bash
    git clone https://gitee.com/mindspore/federated.git
    ```

2. Go to the folder where the script is located

    ```bash
    cd federated/example/splitnn_criteo
    ```

3. Run the script

    ```bash
    sh run_vfl_train_local_label_dp.sh
    ```

### Viewing Results

Check the loss changes of model training in the training log `log_local_gpu.txt`.

```sh
INFO:root:epoch 0 step 100/2582 loss: 0.588637
INFO:root:epoch 0 step 200/2582 loss: 0.561055
INFO:root:epoch 0 step 300/2582 loss: 0.556246
INFO:root:epoch 0 step 400/2582 loss: 0.557931
INFO:root:epoch 0 step 500/2582 loss: 0.553283
INFO:root:epoch 0 step 600/2582 loss: 0.549618
INFO:root:epoch 0 step 700/2582 loss: 0.550243
INFO:root:epoch 0 step 800/2582 loss: 0.549496
INFO:root:epoch 0 step 900/2582 loss: 0.549224
INFO:root:epoch 0 step 1000/2582 loss: 0.547547
INFO:root:epoch 0 step 1100/2582 loss: 0.546989
INFO:root:epoch 0 step 1200/2582 loss: 0.552165
INFO:root:epoch 0 step 1300/2582 loss: 0.546926
INFO:root:epoch 0 step 1400/2582 loss: 0.558071
INFO:root:epoch 0 step 1500/2582 loss: 0.548258
INFO:root:epoch 0 step 1600/2582 loss: 0.546442
INFO:root:epoch 0 step 1700/2582 loss: 0.549062
INFO:root:epoch 0 step 1800/2582 loss: 0.546558
INFO:root:epoch 0 step 1900/2582 loss: 0.542755
INFO:root:epoch 0 step 2000/2582 loss: 0.543118
INFO:root:epoch 0 step 2100/2582 loss: 0.542587
INFO:root:epoch 0 step 2200/2582 loss: 0.545770
INFO:root:epoch 0 step 2300/2582 loss: 0.554520
INFO:root:epoch 0 step 2400/2582 loss: 0.551129
INFO:root:epoch 0 step 2500/2582 loss: 0.545622
...
```

## Deep Experience

We take the local case in the [Wide&Deep Vertical Federated Learning Case](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo) as an example to introduce how to add label dp protection to a vertical federated model.

### Front-End Needs

Same as in [Quick Experience](#quick-experience): install MindSpore, install MindSpore Federated, and prepare the dataset.

### Option 1: Call the integrated label dp function in the FLModel class

MindSpore Federated uses `FLModel` (see [Vertical Federated Learning Model Training Interface](https://www.mindspore.cn/federated/docs/en/master/vertical/vertical_federated_FLModel.html)) and yaml files (see [detailed configuration items of Vertical Federated Learning yaml](https://www.mindspore.cn/federated/docs/en/master/vertical/vertical_federated_yaml.html)) to model the training process of vertical federated learning.

We have integrated the label dp function into the `FLModel` class.
After modeling the whole vertical federated learning training process as usual (for details of vFL training, see [Vertical Federated Learning Model Training - Pangu Alpha Large Model Cross-Domain Training](https://www.mindspore.cn/federated/docs/en/master/split_pangu_alpha_application.html)), the user only needs to add the `label_dp` submodule under the `privacy` module in the yaml file of the label-side participant (adding the `privacy` module first if it does not exist) and set the `eps` parameter of the `label_dp` module (the differential privacy parameter $\epsilon$, whose value can be chosen according to actual needs) for the model to enjoy label dp protection:

```yaml
privacy:
  label_dp:
    eps: 1.0
```

### Option 2: Directly call the LabelDP class

Users can also call the `LabelDP` class directly to use the label dp function more flexibly. The `LabelDP` class is integrated in the `mindspore_federated.privacy` module. The user defines a `LabelDP` object by specifying the value of `eps` and then passes the label group to this object as an argument; the `__call__` function of the object automatically recognizes whether the current input consists of one-hot or binary labels and outputs a label group processed by label dp. Refer to the following example:

```python
# make private a batch of binary labels
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore_federated.privacy import LabelDP

label_dp = LabelDP(eps=0.0)
label = Tensor(np.zeros((5, 1)), dtype=mindspore.float32)
dp_label = label_dp(label)

# make private a batch of one-hot labels
label = Tensor(np.hstack((np.ones((5, 1)), np.zeros((5, 2)))), dtype=mindspore.float32)
dp_label = label_dp(label)
print(dp_label)
```

## References

[1] Dwork C, Roth A. The algorithmic foundations of differential privacy[J]. Foundations and Trends® in Theoretical Computer Science, 2014, 9(3–4): 211-407.

[2] Ghazi B, Golowich N, Kumar R, et al. Deep learning with label differential privacy[J]. Advances in Neural Information Processing Systems, 2021, 34: 27131-27145.

diff --git a/docs/federated/docs/source_en/secure_vertical_federated_learning_with_EmbeddingDP.md b/docs/federated/docs/source_en/secure_vertical_federated_learning_with_EmbeddingDP.md
deleted file mode 100644
index 57147fb4aa4098e2f855039ce67e98ffd10f7474..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/secure_vertical_federated_learning_with_EmbeddingDP.md
+++ /dev/null
@@ -1,135 +0,0 @@
# Vertical Federated - Feature Protection Based on Information Obfuscation

[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/secure_vertical_federated_learning_with_EmbeddingDP.md)

## Background

Vertical federated learning (vFL) is a mainstream and important joint learning paradigm. In vFL, n (n ≥ 2) participants share a large number of identical users, but their user features overlap only slightly. MindSpore Federated implements vFL with split learning (SL) technology. Taking the two-party split learning shown in the figure below as an example, the participants do not share their raw data directly, but share the intermediate features extracted by their local models for training and inference, which satisfies the privacy requirement that raw data is not leaked.
-
-## References
-
-[1] Dwork C, Roth A. The algorithmic foundations of differential privacy[J]. Foundations and Trends® in Theoretical Computer Science, 2014, 9(3–4): 211-407.
-
-[2] Ghazi B, Golowich N, Kumar R, et al. Deep learning with label differential privacy[J]. Advances in Neural Information Processing Systems, 2021, 34: 27131-27145.
-
diff --git a/docs/federated/docs/source_en/secure_vertical_federated_learning_with_EmbeddingDP.md b/docs/federated/docs/source_en/secure_vertical_federated_learning_with_EmbeddingDP.md
deleted file mode 100644
index 57147fb4aa4098e2f855039ce67e98ffd10f7474..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/secure_vertical_federated_learning_with_EmbeddingDP.md
+++ /dev/null
@@ -1,135 +0,0 @@
-# Vertical Federated-Feature Protection Based on Information Obfuscation
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/secure_vertical_federated_learning_with_EmbeddingDP.md)
-
-## Background
-
-Vertical Federated Learning (vFL) is a mainstream and important joint learning paradigm. In vFL, n (n ≥ 2) participants share a large number of common users, while their user features overlap only slightly. MindSpore Federated uses Split Learning (SL) technology to implement vFL. Taking the two-party split learning shown in the figure below as an example, each participant does not share its original data directly, but shares the intermediate features extracted by the local model for training and inference, satisfying the privacy requirement that the original data not be leaked.
-
-However, it has been shown [1] that an attacker (e.g., participant 2) can reconstruct the corresponding original data (features) from the intermediate features (E), resulting in privacy leakage. For such feature reconstruction attacks, this tutorial provides a lightweight feature protection scheme based on information obfuscation [2].
-
-![image.png](./images/vfl_feature_reconstruction_en.png)
-
-## Scheme Details
-
-The protection scheme is named EmbeddingDP, and the overall picture is shown below. For the generated intermediate features E, obfuscation operations such as quantization and differential privacy (DP) are applied sequentially to generate P, and P is sent to participant 2 as the intermediate feature. The obfuscation operations greatly reduce the correlation between the intermediate features and the original input, which makes the attack much more difficult.
-
-![image.png](./images/vfl_feature_reconstruction_defense_en.png)
-
-Currently, this tutorial supports single-bit quantization and differential privacy based on randomized response, and the details of the scheme are shown in the figure below.
-
-1. **Single-bit quantization**: For the input vector E, single-bit quantization sets each number greater than 0 to 1 and each number less than or equal to 0 to 0, generating the binary vector B.
-
-2. **Differential privacy based on randomized response (DP)**: Differential privacy requires configuring the key parameter `eps`. If `eps` is not configured, no differential privacy is performed and the binary vector B is directly used as the intermediate feature to be transmitted. If `eps` is correctly configured (i.e., `eps` is a non-negative real number), the larger `eps` is, the lower the flipping probability, the smaller the impact on the data, and the weaker the privacy protection. For any dimension i in the binary vector B, if B[i] = 1, the value is kept unchanged with probability p; if B[i] = 0, it is flipped to 1 with probability q. The probabilities p and q are calculated based on the following equations, where e denotes the base of the natural logarithm.
-
-$$p = \frac{e^{eps / 2}}{e^{eps / 2} + 1},\quad q = \frac{1}{e^{eps / 2} + 1}$$
-
-![image.png](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/federated/docs/source_zh_cn/images/vfl_mnist_detail.png)
-
-## Feature Experience
-
-This feature works with one-dimensional or two-dimensional tensor arrays. One-dimensional arrays may only consist of the numbers 0 and 1, and two-dimensional arrays need to consist of one-dimensional vectors in the one-hot encoded format.
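-
-As a quick sanity check of the formulas above, p and q can be computed for a given `eps` with a few lines of plain Python (illustration only, not part of the MindSpore Federated API):
-
-```python
-import math
-
-def rr_probs(eps):
-    """Return (p, q): the keep probability for B[i] = 1 and the flip probability for B[i] = 0."""
-    denom = math.exp(eps / 2) + 1
-    return math.exp(eps / 2) / denom, 1 / denom
-
-print(rr_probs(1.0))  # roughly (0.622, 0.378): larger eps pushes p toward 1 and q toward 0
-```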
After [installing MindSpore and Federated](https://mindspore.cn/federated/docs/en/master/federated_install.html#obtaining-mindspore-federated), this feature can be applied to process a tensor array that meets the requirements, as shown in the following sample program:
-
-```python
-import mindspore as ms
-from mindspore import Tensor
-from mindspore.common.initializer import Normal
-from mindspore_federated.privacy import EmbeddingDP
-
-ori_tensor = Tensor(shape=(2, 3), dtype=ms.float32, init=Normal())
-print(ori_tensor)
-dp_tensor = EmbeddingDP(eps=1)(ori_tensor)
-print(dp_tensor)
-```
-
-## Application Examples
-
-### Protecting the Pangu Alpha Large Model Cross-Domain Training
-
-#### Preparation
-
-Download the federated code repository, follow the tutorial [Vertical Federated Learning Model Training - Pangu Alpha Large Model Cross-Domain Training](https://mindspore.cn/federated/docs/en/master/split_pangu_alpha_application.html#environment-preparation) to configure the runtime environment and experimental dataset, and then run the single-process or multi-process example program as needed.
-
-```bash
-git clone https://gitee.com/mindspore/federated.git
-```
-
-#### Single-process Sample
-
-1. Go to the directory where the sample is located, and then perform steps 2 to 4 of [running a single-process example](https://mindspore.cn/federated/docs/en/master/split_pangu_alpha_application.html#running-a-single-process-example):
-
-   ```bash
-   cd federated/example/splitnn_pangu_alpha
-   ```
-
-2. Start the training script with EmbeddingDP configured:
-
-   ```bash
-   sh run_pangu_train_local_embedding_dp.sh
-   ```
-
-3. View the training loss in the training log `splitnn_pangu_local.txt`:
-
-   ```text
-   2023-02-07 01:34:00 INFO: The embedding is protected by EmbeddingDP with eps 5.000000.
-   2023-02-07 01:35:40 INFO: epoch 0 step 10/43391 loss: 10.653997
-   2023-02-07 01:36:25 INFO: epoch 0 step 20/43391 loss: 10.570406
-   2023-02-07 01:37:11 INFO: epoch 0 step 30/43391 loss: 10.470503
-   2023-02-07 01:37:58 INFO: epoch 0 step 40/43391 loss: 10.242296
-   2023-02-07 01:38:45 INFO: epoch 0 step 50/43391 loss: 9.970814
-   2023-02-07 01:39:31 INFO: epoch 0 step 60/43391 loss: 9.735226
-   2023-02-07 01:40:16 INFO: epoch 0 step 70/43391 loss: 9.594692
-   2023-02-07 01:41:01 INFO: epoch 0 step 80/43391 loss: 9.340107
-   2023-02-07 01:41:47 INFO: epoch 0 step 90/43391 loss: 9.356388
-   2023-02-07 01:42:34 INFO: epoch 0 step 100/43391 loss: 8.797981
-   ...
-   ```
-
-#### Multi-process Sample
-
-1. Go to the directory where the sample is located, install the dependency packages, and configure the dataset:
-
-   ```bash
-   cd federated/example/splitnn_pangu_alpha
-   python -m pip install -r requirements.txt
-   cp -r {dataset_dir}/wiki ./
-   ```
-
-2. Start the training script on Server 1 with EmbeddingDP configured:
-
-   ```bash
-   sh run_pangu_train_leader_embedding_dp.sh {ip1:port1} {ip2:port2} ./wiki/train ./wiki/test
-   ```
-
-   `ip1` and `port1` denote the IP address and port number of the local server (server 1). `ip2` and `port2` denote the IP address and port number of the peer server (server 2). `./wiki/train` is the training dataset file path, and `./wiki/test` is the evaluation dataset file path.
-
-3. Start the training script of the other participant on Server 2:
-
-   ```bash
-   sh run_pangu_train_follower.sh {ip2:port2} {ip1:port1}
-   ```
-
-4. View the training loss in the training log `leader_process.log`:
-
-   ```text
-   2023-02-07 01:39:15 INFO: config is:
-   2023-02-07 01:39:15 INFO: Namespace(ckpt_name_prefix='pangu', ...)
-   2023-02-07 01:39:21 INFO: The embedding is protected by EmbeddingDP with eps 5.000000.
-   2023-02-07 01:41:05 INFO: epoch 0 step 10/43391 loss: 10.669225
-   2023-02-07 01:41:38 INFO: epoch 0 step 20/43391 loss: 10.571924
-   2023-02-07 01:42:11 INFO: epoch 0 step 30/43391 loss: 10.440327
-   2023-02-07 01:42:44 INFO: epoch 0 step 40/43391 loss: 10.253876
-   2023-02-07 01:43:16 INFO: epoch 0 step 50/43391 loss: 9.958257
-   2023-02-07 01:43:49 INFO: epoch 0 step 60/43391 loss: 9.704673
-   2023-02-07 01:44:21 INFO: epoch 0 step 70/43391 loss: 9.543740
-   2023-02-07 01:44:54 INFO: epoch 0 step 80/43391 loss: 9.376131
-   2023-02-07 01:45:26 INFO: epoch 0 step 90/43391 loss: 9.376905
-   2023-02-07 01:45:58 INFO: epoch 0 step 100/43391 loss: 8.766671
-   ...
-   ```
-
-## Works Cited
-
-[1] Erdogan, Ege, Alptekin Kupcu, and A. Ercument Cicek. "Unsplit: Data-oblivious model inversion, model stealing, and label inference attacks against split learning." arXiv preprint arXiv:2108.09033 (2021).
-
-[2] Anonymous Author(s). "MistNet: Towards Private Neural Network Training with Local Differential Privacy". (https://github.com/TL-System/plato/blob/2e5290c1f3acf4f604dad223b62e801bbefea211/docs/papers/MistNet.pdf)
diff --git a/docs/federated/docs/source_en/secure_vertical_federated_learning_with_TEE.md b/docs/federated/docs/source_en/secure_vertical_federated_learning_with_TEE.md
deleted file mode 100644
index a2ba1b5df70f31390a4d92a8a87d031bb7fb5048..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/secure_vertical_federated_learning_with_TEE.md
+++ /dev/null
@@ -1,281 +0,0 @@
-# Vertical Federated - Feature Protection Based on Trusted Execution Environment
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/secure_vertical_federated_learning_with_TEE.md)
-
-Note: This is an experimental feature and may be modified or removed in the future.
-
-## Background
-
-Vertical federated learning (vFL) is a major branch of federated learning (FL). When different participants hold data of the same batch of users but with different attributes, they can use vFL for collaborative training. In vFL, each participant with attributes holds a bottom model and inputs the attributes into the bottom model to get the intermediate result (embedding), which is sent to the participant holding the labels (referred to as the leader participant, shown as participant B in the figure below; the participant without labels is called the follower participant, shown as participant A). The leader side uses the embeddings and labels to train the upper-layer network, and then passes the calculated gradients back to each participant to train the lower-layer networks. As this shows, vFL does not require any participant to upload its own raw data in order to collaboratively train the model.
-
-![image.png](./images/vfl_1_en.png)
-
-By avoiding direct uploading of raw data, vFL protects privacy to a certain extent, which is one of its core goals. However, it is still possible for an attacker to reconstruct user information from the uploaded embedding, causing privacy security risks.
In such a context, we need to provide stronger privacy guarantees for the embeddings and gradients transmitted during vFL training to circumvent privacy security risks.
-
-A trusted execution environment (TEE) is a hardware-based trusted computing solution that secures the computing process by making the entire computation a black box in hardware relative to the outside world. By shielding key layers of the vFL network with TEE, the computation of those layers becomes difficult to reverse, thus ensuring the data security of the vFL training and inference process.
-
-## Algorithm Introduction
-
-![image.png](./images/vfl_with_tee_en.png)
-
-As shown in the figure, if participant A sends the intermediate result $\alpha^{(A)}$ directly to participant B, it is easy for participant B to use the intermediate result to reconstruct the original data $X^{(A)}$ of participant A. To reduce such risk, participant A first encrypts the intermediate result $\alpha^{(A)}$ computed by its Bottom Model to get $E(\alpha^{(A)})$, and passes $E(\alpha^{(A)})$ to participant B. Participant B inputs $E(\alpha^{(A)})$ into the TEE-based Cut Layer, which decrypts it into $\alpha^{(A)}$ for forward propagation inside the TEE, and the whole process is a black box to B.
-
-Gradients are passed backward similarly: the Cut Layer computes the gradient $\nabla\alpha^{(A)}$, encrypts it into $E(\nabla\alpha^{(A)})$, and passes it back from participant B to participant A. Participant A then decrypts it into $\nabla\alpha^{(A)}$ and continues the backward propagation.
-
-## Quick Experience
-
-We use the local case in the [Wide&Deep Vertical Federated Learning Case](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo) as an example of configuring TEE protection.
-
-### Front-End Needs and Environment Configuration
-
-1. Environment requirements:
-
-   - Processor: Intel SGX (Intel Software Guard Extensions) support required
-   - OS: openEuler 20.03, openEuler 21.03 LTS SP2 or higher
-
-2. Install SGX and secGear (you can refer to the [secGear official website](https://gitee.com/openeuler/secGear)):
-
-   ```sh
-   sudo yum install -y cmake ocaml-dune linux-sgx-driver sgxsdk libsgx-launch libsgx-urts sgxssl
-   git clone https://gitee.com/openeuler/secGear.git
-   cd secGear
-   source /opt/intel/sgxsdk/environment && source environment
-   mkdir debug && cd debug && cmake .. && make && sudo make install
-   ```
-
-3. Install MindSpore 1.8.1 or a later version. For details, refer to the [MindSpore Official Site Installation Guide](https://www.mindspore.cn/install).
-
-4. Download federated:
-
-   ```sh
-   git clone https://gitee.com/mindspore/federated.git
-   ```
-
-5. Download four lib files as TEE dependencies: [libsgx_0.so](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/tutorials-develop/federated/libsgx_0.so), [libsecgear.so](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/tutorials-develop/federated/libsecgear.so), [enclave.signed.so](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/tutorials-develop/federated/enclave.signed.so) and [libcsecure_channel_static.a](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/tutorials-develop/federated/libcsecure_channel_static.a). Put them into `mindspore_federated/fl_arch/ccsrc/armour/lib` (create the directory if it does not exist).
-
-6. For the Python libraries that MindSpore Federated depends on, see the [Wide&Deep Vertical Federated Learning Case](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo).
-
-7. Install MindSpore Federated with TEE compilation (a compiler option must additionally be set to indicate whether to use SGX):
-
-   ```sh
-   sh federated/build.sh -s on
-   pip install federated/build/packages/mindspore_federated-XXXXX.whl
-   ```
-
-8. To prepare the Criteo dataset, refer to the [Wide&Deep Vertical Federated Learning Case](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo).
-
-### Starting the Script
-
-1. Go to the folder where the script is located:
-
-   ```sh
-   cd federated/example/splitnn_criteo
-   ```
-
-2. Run the script:
-
-   ```sh
-   sh run_vfl_train_local_tee.sh
-   ```
-
-### Viewing Results
-
-Check the loss changes of the model training in the training log `log_local_cpu_tee.txt`:
-
-```sh
-INFO:root:epoch 0 step 100/41322 wide_loss: 0.661822 deep_loss: 0.662018
-INFO:root:epoch 0 step 100/41322 wide_loss: 0.685003 deep_loss: 0.685198
-INFO:root:epoch 0 step 200/41322 wide_loss: 0.649380 deep_loss: 0.649381
-INFO:root:epoch 0 step 300/41322 wide_loss: 0.612189 deep_loss: 0.612189
-INFO:root:epoch 0 step 400/41322 wide_loss: 0.630079 deep_loss: 0.630079
-INFO:root:epoch 0 step 500/41322 wide_loss: 0.602897 deep_loss: 0.602897
-INFO:root:epoch 0 step 600/41322 wide_loss: 0.621647 deep_loss: 0.621647
-INFO:root:epoch 0 step 700/41322 wide_loss: 0.624762 deep_loss: 0.624762
-INFO:root:epoch 0 step 800/41322 wide_loss: 0.622042 deep_loss: 0.622042
-INFO:root:epoch 0 step 900/41322 wide_loss: 0.585274 deep_loss: 0.585274
-INFO:root:epoch 0 step 1000/41322 wide_loss: 0.590947 deep_loss: 0.590947
-INFO:root:epoch 0 step 1100/41322 wide_loss: 0.586775 deep_loss: 0.586775
-INFO:root:epoch 0 step 1200/41322 wide_loss: 0.597362 deep_loss: 0.597362
-INFO:root:epoch 0 step 1300/41322 wide_loss: 0.607390 deep_loss: 0.607390
-INFO:root:epoch 0 step 1400/41322 wide_loss: 0.584204 deep_loss: 0.584204
-INFO:root:epoch 0 step 1500/41322 wide_loss: 0.583618 deep_loss: 0.583618
-INFO:root:epoch 0 step 1600/41322 wide_loss: 0.573294 deep_loss: 0.573294
-INFO:root:epoch 0 step 1700/41322 wide_loss: 0.600686 deep_loss: 0.600686
-INFO:root:epoch 0 step 1800/41322 wide_loss: 0.585533 deep_loss: 0.585533
-INFO:root:epoch 0 step 1900/41322 wide_loss: 0.583466 deep_loss: 0.583466
-INFO:root:epoch 0 step 2000/41322 wide_loss: 0.560188 deep_loss: 0.560188
-INFO:root:epoch 0 step 2100/41322 wide_loss: 0.569232 deep_loss: 0.569232
-INFO:root:epoch 0 step 2200/41322 wide_loss: 0.591643 deep_loss: 0.591643
-INFO:root:epoch 0 step 2300/41322 wide_loss: 0.572473 deep_loss: 0.572473
-INFO:root:epoch 0 step 2400/41322 wide_loss: 0.582825 deep_loss: 0.582825
-INFO:root:epoch 0 step 2500/41322 wide_loss: 0.567196 deep_loss: 0.567196
-INFO:root:epoch 0 step 2600/41322 wide_loss: 0.602022 deep_loss: 0.602022
-```
-
-## Deep Experience
-
-The forward and backward propagation of the TEE layer calls the TEE's own functions rather than MindSpore operators, so its implementation differs from the usual vFL model.
-
-Usually, the Top Model and the Cut Layer are put together for the backward propagation of the vFL model during training, and are differentiated and updated in one step by Participant B through MindSpore.
When the network containing TEE performs backward propagation, however, the Top Model is updated by Participant B based on MindSpore, while the Cut Layer (TEE) is updated internally after receiving the gradients passed back from the Top Model. The gradients that need to be passed back to Participant A are encrypted inside the TEE before being passed out to Participant B, so the whole process is completed within the TEE.
-
-Currently, MindSpore Federated implements the above function by passing a `grad_network` into the `mindspore_federated.vfl_model.FLModel()` definition to customize the backward propagation process. Therefore, to implement a network containing TEE, the user can define the backward propagation process of the Top Model and the Cut Layer in a `grad_network` and simply pass it into `FLModel`; `FLModel` will then follow the user-defined training process during backward propagation.
-
-We use the local case in the [Wide&Deep Vertical Federated Learning Case](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo) as an example of how to configure TEE protection in a vertical federated model. The presentation focuses on how the configuration differs from the usual case when using TEE; identical parts are skipped (a detailed description of vFL training can be found in [Vertical Federated Learning Model Training - Pangu Alpha Large Model Cross-Domain Training](https://mindspore.cn/federated/docs/en/master/split_pangu_alpha_application.html)).
-
-### Front-End Needs and Environment Configuration
-
-Refer to [Quick Experience](#quick-experience).
-
-### Defining the Network Model
-
-#### Forward Propagation
-
-As in usual vFL training, users need to define a network model containing TEE based on the `nn.Cell` provided by MindSpore (see [mindspore.nn.Cell](https://mindspore.cn/docs/en/master/api_python/nn/mindspore.nn.Cell.html#mindspore-nn-cell)) to develop the training network. The difference is that, at the layer where the TEE is located, the user needs to call the TEE forward propagation function in the `construct` function of the class:
-
-```python
-import mindspore.nn as nn
-import mindspore.ops as ops
-from mindspore import Tensor
-from mindspore_federated._mindspore_federated import init_tee_cut_layer, backward_tee_cut_layer, \
-    encrypt_client_data, secure_forward_tee_cut_layer
-
-class TeeLayer(nn.Cell):
-    """
-    TEE layer of the leader net.
-    Args:
-        config (class): default config info.
-    """
-    def __init__(self, config):
-        super(TeeLayer, self).__init__()
-        init_tee_cut_layer(config.batch_size, 2, 2, 1, 3.5e-4, 1024.0)
-        self.concat = ops.Concat(axis=1)
-        self.reshape = ops.Reshape()
-
-    def construct(self, wide_out0, deep_out0, wide_embedding, deep_embedding):
-        """Convert and encrypt the intermediate data"""
-        local_emb = self.concat((wide_out0, deep_out0))
-        remote_emb = self.concat((wide_embedding, deep_embedding))
-        aa = remote_emb.flatten().asnumpy().tolist()
-        bb = local_emb.flatten().asnumpy().tolist()
-        enc_aa, enc_aa_len = encrypt_client_data(aa, len(aa))
-        enc_bb, enc_bb_len = encrypt_client_data(bb, len(bb))
-        tee_output = secure_forward_tee_cut_layer(remote_emb.shape[0], remote_emb.shape[1],
-                                                  local_emb.shape[1], enc_aa, enc_aa_len, enc_bb, enc_bb_len, 2)
-        tee_output = self.reshape(Tensor(tee_output), (remote_emb.shape[0], 2))
-        return tee_output
-```
-
-#### Backward Propagation
-
-In the usual vFL model, backward propagation is automatically configured by the `FLModel` class, but in models containing TEE, the user needs to develop a `grad_network` to define the backward propagation process.
`grad_network` is also based on `nn.Cell` and includes an `__init__` function and a `construct` function. When initializing it, you need to pass in the network used for training, and define the following in the `__init__` function: the gradient operators, the parameters of the networks outside the Cut Layer, the loss functions, and the optimizers of the networks outside the Cut Layer. An example is as follows:
-
-```python
-class LeaderGradNet(nn.Cell):
-    """
-    grad_network of the leader party.
-    Args:
-        net (class): LeaderNet, which is the net of leader party.
-        config (class): default config info.
-    """
-
-    def __init__(self, net: LeaderNet, config):
-        super().__init__()
-        self.net = net
-        self.sens = 1024.0
-        # operators reused in construct
-        self.concat = ops.Concat(axis=1)
-        self.reshape = ops.Reshape()
-
-        self.grad_op_param_sens = ops.GradOperation(get_by_list=True, sens_param=True)
-        self.grad_op_input_sens = ops.GradOperation(get_all=True, sens_param=True)
-
-        self.params_head = ParameterTuple(net.head_layer.trainable_params())
-        self.params_bottom_deep = vfl_utils.get_params_by_name(self.net.bottom_net, ['deep', 'dense'])
-        self.params_bottom_wide = vfl_utils.get_params_by_name(self.net.bottom_net, ['wide'])
-
-        self.loss_net = HeadLossNet(net.head_layer)
-        self.loss_net_l2 = L2LossNet(net.bottom_net, config)
-
-        self.optimizer_head = Adam(self.params_head, learning_rate=3.5e-4, eps=1e-8, loss_scale=self.sens)
-        self.optimizer_bottom_deep = Adam(self.params_bottom_deep, learning_rate=3.5e-4, eps=1e-8, loss_scale=self.sens)
-        self.optimizer_bottom_wide = FTRL(self.params_bottom_wide, learning_rate=5e-2, l1=1e-8, l2=1e-8,
-                                          initial_accum=1.0, loss_scale=self.sens)
-```
-
-The input to the `construct` function of `grad_network` consists of two dictionaries, `local_data_batch` and `remote_data_batch`. In the `construct` function, you first need to extract the corresponding data from the dictionaries. Next, for the layers other than TEE, call the MindSpore gradient operators on the parameters and the inputs and update them with the respective optimizers, while the TEE layer calls the built-in functions of TEE for differentiation and update. An example is as follows:
-
-```python
-def construct(self, local_data_batch, remote_data_batch):
-    """
-    The backward propagation of the leader net.
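-
-    Args:
-        local_data_batch (dict): the data batch loaded from the local dataset of the leader party.
-        remote_data_batch (dict): the feature tensors received from the remote (follower) party.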
-    """
-    # data processing
-    id_hldr = local_data_batch['id_hldr']
-    wt_hldr = local_data_batch['wt_hldr']
-    label = local_data_batch['label']
-    wide_embedding = remote_data_batch['wide_embedding']
-    deep_embedding = remote_data_batch['deep_embedding']
-
-    # forward
-    wide_out0, deep_out0 = self.net.bottom_net(id_hldr, wt_hldr)
-    local_emb = self.concat((wide_out0, deep_out0))
-    remote_emb = self.concat((wide_embedding, deep_embedding))
-    head_input = self.net.cut_layer(wide_out0, deep_out0, wide_embedding, deep_embedding)
-    loss = self.loss_net(head_input, label)
-
-    # update of head net
-    sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), 1024.0)
-    grad_head_input, _ = self.grad_op_input_sens(self.loss_net)(head_input, label, sens)
-    grad_head_param = self.grad_op_param_sens(self.loss_net, self.params_head)(head_input, label, sens)
-    self.optimizer_head(grad_head_param)
-
-    # update of cut layer
-    tmp = grad_head_input.flatten().asnumpy().tolist()
-    grad_input = backward_tee_cut_layer(remote_emb.shape[0], remote_emb.shape[1], local_emb.shape[1], 1, tmp)
-    grad_inputa = self.reshape(Tensor(grad_input[0]), remote_emb.shape)
-    grad_inputb = self.reshape(Tensor(grad_input[1]), local_emb.shape)
-    grad_cutlayer_input = (grad_inputb[:, :1], grad_inputb[:, 1:2], grad_inputa[:, :1], grad_inputa[:, 1:2])
-
-    # update of bottom net
-    grad_bottom_wide = self.grad_op_param_sens(self.net.bottom_net,
-                                               self.params_bottom_wide)(id_hldr, wt_hldr,
-                                                                        grad_cutlayer_input[0:2])
-    self.optimizer_bottom_wide(grad_bottom_wide)
-    grad_bottom_deep = self.grad_op_param_sens(self.net.bottom_net,
-                                               self.params_bottom_deep)(id_hldr, wt_hldr,
-                                                                        grad_cutlayer_input[0:2])
-    grad_bottom_l2 = self.grad_op_param_sens(self.loss_net_l2, self.params_bottom_deep)(sens)
-    zipped = zip(grad_bottom_deep, grad_bottom_l2)
-    grad_bottom_deep = tuple(map(sum, zipped))
-    self.optimizer_bottom_deep(grad_bottom_deep)
-
-    # output the gradients for the follower party
-    scales = {}
-    scales['wide_loss'] = OrderedDict(zip(['wide_embedding', 'deep_embedding'], grad_cutlayer_input[2:4]))
-    scales['deep_loss'] = scales['wide_loss']
-    return scales
-```
-
-#### Defining the Optimizer
-
-When defining the optimizer, the backward propagation parts already covered by `grad_network` do not need to be defined in the yaml file; in all other respects, the optimizer is defined in the same way as in the usual vFL model.
-
-### Constructing the Training Script
-
-#### Constructing the Network
-
-As in usual vFL training, users need to use the classes provided by MindSpore Federated to wrap their constructed networks into a vertical federated network. Detailed API documentation can be found in the [Vertical Federated Training Interface](https://gitee.com/mindspore/federated/blob/master/docs/api/api_python_en/vertical/vertical_federated_FLModel.rst). The difference is that when constructing the leader network, you need to add the `grad_network`.
-
-```python
-from mindspore_federated import FLModel, FLYamlData
-from network_config import config
-from wide_and_deep import LeaderNet, LeaderLossNet, LeaderGradNet
-
-
-leader_base_net = LeaderNet(config)
-leader_train_net = LeaderLossNet(leader_base_net, config)
-leader_grad_net = LeaderGradNet(leader_base_net, config)
-
-leader_yaml_data = FLYamlData(config.leader_yaml_path)
-leader_fl_model = FLModel(yaml_data=leader_yaml_data,
-                          network=leader_base_net,
-                          grad_network=leader_grad_net,
-                          train_network=leader_train_net)
-```
-
-Except for the above, the rest of TEE training is identical to usual vFL training, and the user enjoys the security of TEE once the configuration is completed.
diff --git a/docs/federated/docs/source_en/sentiment_classification_application.md b/docs/federated/docs/source_en/sentiment_classification_application.md
deleted file mode 100644
index 30ae58fe1f1ffe457c0afe5158b0744ff273d6f1..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/sentiment_classification_application.md
+++ /dev/null
@@ -1,563 +0,0 @@
-# Implementing a Sentiment Classification Application (Android)
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/sentiment_classification_application.md)
-
-Through the cross-device collaborative modeling approach of federated learning, the advantages of device-side data can be fully utilized without uploading sensitive user data directly to the cloud side. Users attach great importance to the privacy of the text they type in input methods, while the intelligent functions of input methods are important to the user experience, so federated learning is naturally applicable to input method application scenarios.
-
-MindSpore Federated has applied the Federated Language Model to the emoji image prediction feature of the input method. The Federated Language Model recommends emoji images that are appropriate for the current context based on chat text data. When modeling with federated learning, each emoji image is defined as a sentiment label category, and each chat phrase corresponds to an emoji image. MindSpore Federated defines the emoji image prediction task as a federated sentiment classification task.
-
-## Preparations
-
-### Environment
-
-For details, see [Server Environment Configuration](https://www.mindspore.cn/federated/docs/en/master/deploy_federated_server.html) and [Client Environment Configuration](https://www.mindspore.cn/federated/docs/en/master/deploy_federated_client.html).
-
-### Data
-
-The [training data](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/supervise/client.tar.gz) contains 20 user chat files. The directory structure is as follows:
-
-```text
-datasets/supervise/client/
-    ├── 0.txt  # Training data of user 0
-    ├── 1.txt  # Training data of user 1
-    │
-    │    ......
-    │
-    └── 19.txt  # Training data of user 19
-```
-
-The [validation data](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/supervise/eval.tar.gz) contains one chat file. The directory structure is as follows:
-
-```text
-datasets/supervise/eval/
-    └── eval.txt  # Validation data
-```
-
-The labels in the training data and validation data correspond to four types of emojis: `good`, `leimu`, `xiaoku`, `xin`.
-
-### Model-related Files
-
-The directory structures of the [dictionary](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vocab.txt) and the [mapping file of dictionary ID](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vocab_map_ids.txt) related to the model file are as follows:
-
-```text
-datasets/
-    ├── vocab.txt  # Dictionary
-    └── vocab_map_ids.txt  # Mapping file of dictionary ID
-```
-
-## Defining the Network
-
-The ALBERT language model[1] is used in federated learning. The ALBERT model on the client includes the embedding layer, encoder layer, and classifier layer.
-
-For details about the network definition, see the [source code](https://gitee.com/mindspore/federated/blob/master/tests/st/network/albert.py).
-
-### Generating a Device-Side Model File
-
-Users can generate a device-side model file as follows, or download the generated [ALBERT Device-Side Model File](https://gitee.com/link?target=https%3A%2F%2Fmindspore-website.obs.cn-north-4.myhuaweicloud.com%2Fnotebook%2Fmodels%2Falbert_supervise.mindir.ms).
-
-#### Exporting a Model as a MindIR File
-
-The sample code is as follows:
-
-```python
-import argparse
-import os
-import random
-from time import time
-import numpy as np
-import mindspore as ms
-from mindspore.nn import AdamWeightDecay
-from src.config import train_cfg, client_net_cfg
-from src.utils import restore_params
-from src.model import AlbertModelCLS
-from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell
-
-
-def parse_args():
-    """
-    parse args
-    """
-    parser = argparse.ArgumentParser(description='export task')
-    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'])
-    parser.add_argument('--device_id', type=str, default='0')
-    parser.add_argument('--init_model_path', type=str, default='none')
-    parser.add_argument('--output_dir', type=str, default='./models/mindir/')
-    parser.add_argument('--seed', type=int, default=0)
-    return parser.parse_args()
-
-
-def supervise_export(args_opt):
-    # set random seeds for reproducibility
-    ms.set_seed(args_opt.seed)
-    random.seed(args_opt.seed)
-    start = time()
-    # Parameter configuration
-    os.environ['CUDA_VISIBLE_DEVICES'] = args_opt.device_id
-    init_model_path = args_opt.init_model_path
-    output_dir = args_opt.output_dir
-    if not os.path.exists(output_dir):
-        os.makedirs(output_dir)
-    print('Parameters setting is done! Time cost: {}'.format(time() - start))
-    start = time()
-
-    # MindSpore configuration
-    ms.set_context(mode=ms.GRAPH_MODE, device_target=args_opt.device_target)
-    print('Context setting is done! Time cost: {}'.format(time() - start))
-    start = time()
-
-    # Build model
-    albert_model_cls = AlbertModelCLS(client_net_cfg)
-    network_with_cls_loss = NetworkWithCLSLoss(albert_model_cls)
-    network_with_cls_loss.set_train(True)
-    print('Model construction is done! 
Time cost: {}'.format(time() - start)) - start = time() - - # Build optimizer - client_params = [_ for _ in network_with_cls_loss.trainable_params()] - client_decay_params = list( - filter(train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter, client_params) - ) - client_other_params = list( - filter(lambda x: not train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter(x), client_params) - ) - client_group_params = [ - {'params': client_decay_params, 'weight_decay': train_cfg.optimizer_cfg.AdamWeightDecay.weight_decay}, - {'params': client_other_params, 'weight_decay': 0.0}, - {'order_params': client_params} - ] - client_optimizer = AdamWeightDecay(client_group_params, - learning_rate=train_cfg.client_cfg.learning_rate, - eps=train_cfg.optimizer_cfg.AdamWeightDecay.eps) - client_network_train_cell = NetworkTrainCell(network_with_cls_loss, optimizer=client_optimizer) - print('Optimizer construction is done! Time cost: {}'.format(time() - start)) - start = time() - - # Construct data - input_ids = ms.Tensor(np.zeros((train_cfg.batch_size, client_net_cfg.seq_length), np.int32)) - attention_mask = ms.Tensor(np.zeros((train_cfg.batch_size, client_net_cfg.seq_length), np.int32)) - token_type_ids = ms.Tensor(np.zeros((train_cfg.batch_size, client_net_cfg.seq_length), np.int32)) - label_ids = ms.Tensor(np.zeros((train_cfg.batch_size,), np.int32)) - print('Client data loading is done! Time cost: {}'.format(time() - start)) - start = time() - - # Read checkpoint - if init_model_path != 'none': - init_param_dict = ms.load_checkpoint(init_model_path) - restore_params(client_network_train_cell, init_param_dict) - print('Checkpoint loading is done! Time cost: {}'.format(time() - start)) - start = time() - - # Export - ms.export(client_network_train_cell, input_ids, attention_mask, token_type_ids, label_ids, - file_name=os.path.join(output_dir, 'albert_supervise'), file_format='MINDIR') - print('Supervise model export process is done! Time cost: {}'.format(time() - start)) - - -if __name__ == '__main__': - total_time_start = time() - args = parse_args() - supervise_export(args) - print('All is done! Time cost: {}'.format(time() - total_time_start)) - -``` - -#### Converting the MindIR File into an MS File that Can be Used by the Federated Learning Framework on the Device - -For details about how to generate a model file on the device, see [Implementing an Image Classification Application](https://www.mindspore.cn/federated/docs/en/master/image_classification_application.html). - -## Starting the Federated Learning Process - -Start the script on the server. For details, see [Cloud-based Deployment](https://www.mindspore.cn/federated/docs/en/master/deploy_federated_server.html). -For corresponding cloud-side configuration and model weights document, refer to [albert example](https://gitee.com/mindspore/federated/tree/master/example/cross_device_albert). - -Based on the training and inference tasks of the ALBERT model, the overall process is as follows: - -1. Create an Android project. - -2. Build the MindSpore Lite AAR package. - -3. Describe the Android instance program structure. - -4. Write code. - -5. Configure Android project dependencies. - -6. Build and run on Android. - -### Creating an Android Project - -Create a project in Android Studio and install the corresponding SDK. (After the SDK version is specified, Android Studio automatically installs the SDK.) - -![New project](./images/create_android_project.png) - -### Obtaining a Related Package - -1. 
Obtain MindSpore Lite AAR package - - For details, see [MindSpore Lite](https://www.mindspore.cn/lite/docs/en/master/use/downloads.html). - - ```text - mindspore-lite-full-{version}.aar - ``` - -2. Obtain MindSpore Federated device-side jar package - - For details, see [On-Device Deployment](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_client.html). - - ```text - mindspore_federated/device_client/build/libs/jarAAR/mindspore-lite-java-flclient.jar - ``` - -3. Place the AAR package in the app/libs/ directory of the Android project. - -### Android Instance Program Structure - -```text -app -│ ├── libs # Binary archive file of the Android library project -| | ├── mindspore-lite-full-{version}.aar # MindSpore Lite archive file of the Android version -| | └── mindspore-lite-java-flclient.jar # MindSpore Federate archive file of the Android version -├── src/main -│ ├── assets # Resource directory -| | └── model # Model directory -| | └── albert_supervise.mindir.ms # Pre-trained model file -│ | └── albert_inference.mindir.ms # Inference model file -│ | └── data # Data directory -| | └── 0.txt # training data file -| | └── vocab.txt # Dictionary file -| | └── vocab_map_ids.txt # Dictionary ID mapping file -| | └── eval.txt # Training result evaluation file -| | └── eval_no_label.txt # Inference data file -│ | -│ ├── java # Application code at the Java layer -│ │ └── ... Storing Android code files. Related directories can be customized. -│ │ -│ ├── res # Resource files related to Android -│ └── AndroidManifest.xml # Android configuration file -│ -│ -├── build.gradle # Android project build file -├── download.gradle # Downloading the project dependency files -└── ... -``` - -### Writing Code - -1. AssetCopyer.java: This code file is used to store the resource files in the app/src/main/assets directory of the Android project to the disk of the Android system. In this way, the federated learning framework API can read the resource files based on the absolute path during model training and inference. - - ```java - import android.content.Context; - import java.io.File; - import java.io.FileOutputStream; - import java.io.InputStream; - import java.util.logging.Logger; - public class AssetCopyer { - private static final Logger LOGGER = Logger.getLogger(AssetCopyer.class.toString()); - public static void copyAllAssets(Context context,String destination) { - LOGGER.info("destination: " + destination); - copyAssetsToDst(context,"",destination); - } - // Copy the resource files in the assets directory to the disk of the Android system. You can view the specific path by printing destination. - private static void copyAssetsToDst(Context context,String srcPath, String dstPath) { - try { - // Recursively obtain all file names in the assets directory. - String[] fileNames =context.getAssets().list(srcPath); - if (fileNames.length > 0) { - // Build the destination file object. - File file = new File(dstPath); - // Create a destination directory. - file.mkdirs(); - for (String fileName : fileNames) { - // Copy the file to the specified disk. - if(!srcPath.equals("")) { - copyAssetsToDst(context,srcPath + "/" + fileName,dstPath+"/"+fileName); - }else{ - copyAssetsToDst(context, fileName,dstPath+"/"+fileName); - } - } - } else { - // Build the input stream of the source file. - InputStream is = context.getAssets().open(srcPath); - // Build the output stream of the destination file. - FileOutputStream fos = new FileOutputStream(new File(dstPath)); - // Define a 1024-byte buffer array. 
-                    byte[] buffer = new byte[1024];
-                    int byteCount = 0;
-                    // Write the source file to the destination file.
-                    while ((byteCount = is.read(buffer)) != -1) {
-                        fos.write(buffer, 0, byteCount);
-                    }
-                    // Refresh the output stream.
-                    fos.flush();
-                    // Close the input stream.
-                    is.close();
-                    // Close the output stream.
-                    fos.close();
-                }
-            } catch (Exception e) {
-                e.printStackTrace();
-            }
-        }
-    }
-    ```
-
-2. FlJob.java: This code file is used to define the training and inference tasks. For details about the federated learning APIs, see [Federated Learning APIs](https://www.mindspore.cn/federated/docs/en/master/interface_description_federated_client.html).
-
-    ```java
-    import android.annotation.SuppressLint;
-    import android.os.Build;
-    import androidx.annotation.RequiresApi;
-    import com.mindspore.flAndroid.utils.AssetCopyer;
-    import com.mindspore.flclient.FLParameter;
-    import com.mindspore.flclient.SyncFLJob;
-    import com.mindspore.flclient.model.RunType;
-    import java.util.ArrayList;
-    import java.util.Arrays;
-    import java.util.HashMap;
-    import java.util.List;
-    import java.util.Map;
-    import java.util.UUID;
-    import java.util.logging.Logger;
-    public class FlJob {
-        private static final Logger LOGGER = Logger.getLogger(AssetCopyer.class.toString());
-        private final String parentPath;
-        public FlJob(String parentPath) {
-            this.parentPath = parentPath;
-        }
-        // Android federated learning training task
-        @SuppressLint("NewApi")
-        @RequiresApi(api = Build.VERSION_CODES.M)
-        public void syncJobTrain() {
-            // create dataMap
-            String trainTxtPath = "data/albert/supervise/client/1.txt";
-            String evalTxtPath = "data/albert/supervise/eval/eval.txt";  // Not necessary: only needed to verify model accuracy after getModel
-            String vocabFile = "data/albert/supervise/vocab.txt";  // Path of the dictionary file for data preprocessing
-            String idsFile = "data/albert/supervise/vocab_map_ids.txt";  // Path of the mapping ID file of the dictionary
-            Map<RunType, List<String>> dataMap = new HashMap<>();
-            List<String> trainPath = new ArrayList<>();
-            trainPath.add(trainTxtPath);
-            trainPath.add(vocabFile);
-            trainPath.add(idsFile);
-            List<String> evalPath = new ArrayList<>();  // Not necessary: only needed to verify model accuracy after getModel
-            evalPath.add(evalTxtPath);
-            evalPath.add(vocabFile);
-            evalPath.add(idsFile);
-            dataMap.put(RunType.TRAINMODE, trainPath);
-            dataMap.put(RunType.EVALMODE, evalPath);  // Not necessary: only needed to verify model accuracy after getModel
-
-            String flName = "com.mindspore.flclient.demo.albert.AlbertClient";  // The package path of AlbertClient.java
-            String trainModelPath = "ms/albert/train/albert_ad_train.mindir0.ms";  // Absolute path
-            String inferModelPath = "ms/albert/train/albert_ad_train.mindir0.ms";  // Absolute path, consistent with trainModelPath
-            String sslProtocol = "TLSv1.2";
-            String deployEnv = "android";
-
-            // The url for device-cloud communication. Ensure that the Android device can access the server. Otherwise, the message "connection failed" is displayed.
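-            // ifUseElb and serverNum (set below) simulate elastic load balancing:
-            // client requests are randomly distributed across serverNum server addresses.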
-            String domainName = "http://10.*.*.*:6668";
-            boolean ifUseElb = true;
-            int serverNum = 4;
-            int threadNum = 4;
-            BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
-            int batchSize = 32;
-
-            FLParameter flParameter = FLParameter.getInstance();
-            flParameter.setFlName(flName);
-            flParameter.setDataMap(dataMap);
-            flParameter.setTrainModelPath(trainModelPath);
-            flParameter.setInferModelPath(inferModelPath);
-            flParameter.setSslProtocol(sslProtocol);
-            flParameter.setDeployEnv(deployEnv);
-            flParameter.setDomainName(domainName);
-            flParameter.setUseElb(ifUseElb);
-            flParameter.setServerNum(serverNum);
-            flParameter.setThreadNum(threadNum);
-            flParameter.setCpuBindMode(cpuBindMode);
-            flParameter.setBatchSize(batchSize);
-
-            // start FLJob
-            SyncFLJob syncFLJob = new SyncFLJob();
-            syncFLJob.flJobRun();
-        }
-        // Android federated learning inference task
-        public void syncJobPredict() {
-            // create dataMap
-            String inferTxtPath = "data/albert/supervise/eval/eval.txt";
-            String vocabFile = "data/albert/supervise/vocab.txt";
-            String idsFile = "data/albert/supervise/vocab_map_ids.txt";
-            Map<RunType, List<String>> dataMap = new HashMap<>();
-            List<String> inferPath = new ArrayList<>();
-            inferPath.add(inferTxtPath);
-            inferPath.add(vocabFile);
-            inferPath.add(idsFile);
-            dataMap.put(RunType.INFERMODE, inferPath);
-
-            String flName = "com.mindspore.flclient.demo.albert.AlbertClient";  // The package path of AlbertClient.java
-            String inferModelPath = "ms/albert/train/albert_ad_train.mindir0.ms";  // Absolute path, consistent with trainModelPath
-            int threadNum = 4;
-            BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
-            int batchSize = 32;
-
-            FLParameter flParameter = FLParameter.getInstance();
-            flParameter.setFlName(flName);
-            flParameter.setDataMap(dataMap);
-            flParameter.setInferModelPath(inferModelPath);
-            flParameter.setThreadNum(threadNum);
-            flParameter.setCpuBindMode(cpuBindMode);
-            flParameter.setBatchSize(batchSize);
-
-            // inference
-            SyncFLJob syncFLJob = new SyncFLJob();
-            int[] labels = syncFLJob.modelInference();
-            LOGGER.info("labels = " + Arrays.toString(labels));
-        }
-    }
-    ```
-
-    The eval_no_label.txt listed in the program structure above refers to a file without labels, containing one statement per line in a format that users are free to set, for example:
-
-    ```text
-    愿以吾辈之青春 护卫这盛世之中华🇨🇳
-    girls help girls
-    太美了,祝祖国繁荣昌盛!
-    中国人民站起来了
-    难道就我一个人觉得这个是plus版本?
-    被安利到啦!明天起来就看!早点睡觉莲莲
-    ```
-
-3. MainActivity.java: This code file is used to start the federated learning training and inference tasks.
-
-    ```java
-    import android.os.Build;
-    import android.os.Bundle;
-    import androidx.annotation.RequiresApi;
-    import androidx.appcompat.app.AppCompatActivity;
-    import com.mindspore.flAndroid.job.FlJob;
-    import com.mindspore.flAndroid.utils.AssetCopyer;
-    @RequiresApi(api = Build.VERSION_CODES.P)
-    public class MainActivity extends AppCompatActivity {
-        private String parentPath;
-        @Override
-        protected void onCreate(Bundle savedInstanceState) {
-            super.onCreate(savedInstanceState);
-            // Obtain the disk path of the application in the Android system.
-            this.parentPath = this.getExternalFilesDir(null).getAbsolutePath();
-            // Copy the resource files in the assets directory to the disk of the Android system.
-            AssetCopyer.copyAllAssets(this.getApplicationContext(), parentPath);
-            // Create a thread and start the federated learning training and inference tasks.
-            new Thread(() -> {
-                FlJob flJob = new FlJob(parentPath);
-                flJob.syncJobTrain();
-                flJob.syncJobPredict();
-            }).start();
-        }
-    }
-    ```
-
-### Configuring Android Project Dependencies
-
-1. AndroidManifest.xml
-
-    The manifest needs to declare the network permission required for device-cloud communication and register `MainActivity` as the launcher activity. A minimal sketch is shown below; adapt the package name and labels to your own project:
-
-    ```xml
-    <?xml version="1.0" encoding="utf-8"?>
-    <manifest xmlns:android="http://schemas.android.com/apk/res/android"
-        package="com.mindspore.flAndroid">
-        <!-- Required for device-cloud communication -->
-        <uses-permission android:name="android.permission.INTERNET" />
-        <application android:label="flAndroid">
-            <activity android:name=".MainActivity">
-                <intent-filter>
-                    <action android:name="android.intent.action.MAIN" />
-                    <category android:name="android.intent.category.LAUNCHER" />
-                </intent-filter>
-            </activity>
-        </application>
-    </manifest>
-    ```
-
-2. app/build.gradle
-
-    ```text
-    plugins {
-        id 'com.android.application'
-    }
-    android {
-        // Android SDK build version. It is recommended that the version be later than 27.
-        compileSdkVersion 30
-        buildToolsVersion "30.0.3"
-        defaultConfig {
-            applicationId "com.mindspore.flAndroid"
-            minSdkVersion 27
-            targetSdkVersion 30
-            versionCode 1
-            versionName "1.0"
-            multiDexEnabled true
-            testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
-            ndk {
-                // Different mobile phone models correspond to different NDKs. Mate 20 corresponds to 'armeabi-v7a'.
-                abiFilters 'armeabi-v7a'
-            }
-        }
-        // Specified NDK version
-        ndkVersion '21.3.6528147'
-        sourceSets{
-            main {
-                // Specified JNI directory
-                jniLibs.srcDirs = ['libs']
-                jni.srcDirs = []
-            }
-        }
-        compileOptions {
-            sourceCompatibility JavaVersion.VERSION_1_8
-            targetCompatibility JavaVersion.VERSION_1_8
-        }
-    }
-    dependencies {
-        // AAR package to be scanned in the libs directory
-        implementation fileTree(dir:'libs',include:['*.aar', '*.jar'])
-        implementation 'androidx.appcompat:appcompat:1.1.0'
-        implementation 'com.google.android.material:material:1.1.0'
-        implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
-        androidTestImplementation 'androidx.test.ext:junit:1.1.1'
-        androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0'
-        implementation 'com.android.support:multidex:1.0.3'
-
-        // Add third-party open source software that federated learning relies on
-        implementation group: 'com.squareup.okhttp3', name: 'okhttp', version: '3.14.9'
-        implementation group: 'com.google.flatbuffers', name: 'flatbuffers-java', version: '2.0.0'
-        implementation(group: 'org.bouncycastle',name: 'bcprov-jdk15on', version: '1.68')
-    }
-    ```
-
-### Building and Running on Android
-
-1. Connect the Android device through a USB cable for debugging, and click `Run 'app'` to run the federated learning training and inference application on the device.
-
-   ![run_app](./images/start_android_project.png)
-
-2. For details about how to connect Android Studio to a device for debugging, see the Android Studio documentation. Android Studio can identify the phone only when USB debugging mode is enabled on it. For Huawei phones, enable USB debugging mode by choosing `Settings > System & updates > Developer options > USB debugging`.
-
-3. Continue the installation on the Android device. After the installation is complete, you can start the app to train and infer the ALBERT model for federated learning.
-
-4. The program running result is as follows:
-
-   ```text
-   I/SyncFLJob: [model inference] inference finish
-   I/SyncFLJob: labels = [2, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 4]
-   ```
-
-## Results
-
-The total number of federated learning iterations is 10, the number of client-side local training epochs is 1, and the batchSize is set to 16.
-
-```text
- total acc:0.44488978
- total acc:0.583166333
- total acc:0.609218437
- total acc:0.645290581
- total acc:0.667334669
- total acc:0.685370741
- total acc:0.70741483
- total acc:0.711422846
- total acc:0.719438878
- total acc:0.733466934
-```
-
-## References
-
-[1] Lan Z, Chen M, Goodman S, et al. ALBERT: A Lite BERT for Self-supervised Learning of Language Representations[J]. 2019.
diff --git a/docs/federated/docs/source_en/split_pangu_alpha_application.md b/docs/federated/docs/source_en/split_pangu_alpha_application.md
deleted file mode 100644
index 3f9594025a94b9094251407dad3504a116f04468..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/split_pangu_alpha_application.md
+++ /dev/null
@@ -1,338 +0,0 @@
-# Vertical Federated Learning Model Training - Pangu Alpha Large Model Cross-Domain Training
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/split_pangu_alpha_application.md)
-
-## Overview
-
-With the advancement of hardware computing power and the continuous expansion of network data size, pre-training large models has increasingly become an important research direction in fields such as natural language processing and image-text multimodality. Taking Pangu Alpha, a large Chinese NLP pre-trained model released in 2021, as an example: the number of model parameters reaches 200 billion, and the training process relies on massive data and advanced computing centers, which limits its application landing and technology evolution. A feasible solution is to integrate the computing power and data resources of multiple participants based on vertical federated learning or split learning techniques, to achieve cross-domain collaborative training of pre-trained large models while ensuring security and privacy.
-
-MindSpore Federated provides vertical federated learning base functional components based on split learning. This sample provides a federated learning training example for large NLP models, taking the Pangu Alpha model as an example.
-
-![Implementing cross-domain training for the Pangu Alpha large model](./images/splitnn_pangu_alpha_en.png)
-
-As shown in the figure above, in this case, the Pangu Alpha model is sliced into three subnetworks: Embedding, Backbone, and Head. The front-level subnetwork Embedding and the end-level subnetwork Head are deployed in the network domain of participant A, and the Backbone subnetwork containing the multi-level Transformer modules is deployed in the network domain of participant B. The Embedding subnetwork and the Head subnetwork read the data held by participant A and lead the training and inference tasks of the Pangu Alpha model.
-
-* In the forward inference stage, participant A uses the Embedding subnetwork to process the original data, and transmits the output Embedding Feature tensor and Attention Mask Feature tensor to participant B as the input of participant B's Backbone subnetwork. Then, participant A reads the Hidden State Feature tensor output by the Backbone subnetwork as the input of participant A's Head subnetwork, and finally the Head subnetwork outputs the predicted result or loss value.
-
-* In the backward propagation phase, after completing the gradient calculation and parameter update of the Head subnetwork, participant A transmits the gradient tensor associated with the Hidden State Feature tensor to participant B for the gradient calculation and parameter update of the Backbone subnetwork. Then, after completing the gradient calculation and parameter update of the Backbone subnetwork, participant B transmits the gradient tensor associated with the Embedding Feature tensor to participant A for the gradient calculation and parameter update of the Embedding subnetwork.
-The feature tensors and gradient tensors exchanged between participant A and participant B during the above forward inference and backward propagation are processed with privacy-preserving mechanisms and encryption algorithms, so that the two participants can collaboratively train the network model without transmitting the data held by participant A to participant B. Because the Embedding and Head subnetworks have a small number of parameters while the Backbone subnetwork has a huge number of parameters, this sample application is suitable for large-model collaborative training or deployment between a service side (corresponding to participant A) and a computing center (corresponding to participant B).
-
-For a detailed introduction to the Pangu Alpha model principles, please refer to [MindSpore ModelZoo - pangu_alpha](https://gitee.com/mindspore/models/tree/master/official/nlp/Pangu_alpha), [Introduction to Pengcheng-pangu α](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha), and its [research paper](https://arxiv.org/pdf/2104.12369.pdf).
-
-## Preparation
-
-### Environment Preparation
-
-1. Refer to [Obtaining MindSpore Federated](https://mindspore.cn/federated/docs/en/master/federated_install.html) to install MindSpore 1.8.1 or a later version and MindSpore Federated.
-
-2. Download the MindSpore Federated code and install the Python packages that this sample application depends on:
-
-   ```bash
-   git clone https://gitee.com/mindspore/federated.git
-   cd federated/example/splitnn_pangu_alpha/
-   python -m pip install -r requirements.txt
-   ```
-
-### Dataset Preparation
-
-Before running the sample, refer to [MindSpore ModelZoo - pangu_alpha - Dataset Generation](https://gitee.com/mindspore/models/tree/master/official/nlp/Pangu_alpha#dataset-generation) and use the preprocess.py script to convert the raw text corpus for training into a dataset that can be used for model training.
-
-## Defining the Vertical Federated Learning Training Process
-
-The MindSpore Federated vertical federated learning framework uses `FLModel` (see [Vertical Federated Learning Model Training Interface](https://mindspore.cn/federated/docs/en/master/vertical/vertical_federated_FLModel.html)) and yaml files (see [Yaml Configuration file for model training of vertical federated learning](https://mindspore.cn/federated/docs/en/master/vertical/vertical_federated_yaml.html)) to model the vertical federated learning training process.
-
-### Defining the Network Model
-
-1. Call the functional components provided by MindSpore, with nn.Cell (see [mindspore.nn.Cell](https://mindspore.cn/docs/en/master/api_python/nn/mindspore.nn.Cell.html#mindspore-nn-cell)) as the base class, to program the training network in which this participant is to be involved in vertical federated learning. Taking the Embedding subnetwork of participant A in this application practice as an example, the [sample code](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/src/split_pangu_alpha.py) is as follows:
-
-   ```python
-   class EmbeddingLossNet(nn.Cell):
-       """
-       Train net of the embedding party, i.e., the front-level sub-network.
-       Args:
-           net (class): EmbeddingLayer, which is the 1st sub-network.
-           config (class): default config info.
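-
-       Outputs:
-           embedding_table, word_table, position_id and attention_mask, which are
-           sent to the remote participant as declared under `outputs` in embedding.yaml below.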
- """ - - def __init__(self, net: EmbeddingLayer, config): - super(EmbeddingLossNet, self).__init__(auto_prefix=False) - - self.batch_size = config.batch_size - self.seq_length = config.seq_length - dp = config.parallel_config.data_parallel - self.eod_token = config.eod_token - self.net = net - self.slice = P.StridedSlice().shard(((dp, 1),)) - self.not_equal = P.NotEqual().shard(((dp, 1), ())) - self.batch_size = config.batch_size - self.len = config.seq_length - self.slice2 = P.StridedSlice().shard(((dp, 1, 1),)) - - def construct(self, input_ids, position_id, attention_mask): - """forward process of FollowerLossNet""" - tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1)) - embedding_table, word_table = self.net(tokens, position_id, batch_valid_length=None) - return embedding_table, word_table, position_id, attention_mask - ``` - -2. In the yaml configuration file, describe the corresponding name, input, output and other information of the training network. Taking the Embedding subnetwork of Participant A in this application practice, [example code](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/embedding.yaml) is as follows: - - ```yaml - train_net: - name: follower_loss_net - inputs: - - name: input_ids - source: local - - name: position_id - source: local - - name: attention_mask - source: local - outputs: - - name: embedding_table - destination: remote - - name: word_table - destination: remote - - name: position_id - destination: remote - - name: attention_mask - destination: remote - ``` - - The `name` field is the name of the training network and will be used to name the checkpoints file saved during the training process. The `inputs` field is the list of input tensor in the training network, and the `outputs` field is the list of output tensor in the training network. - - The `name` fields under the `inputs` and `outputs` fields are the input/output tensor names. The names and order of the input/output tensors need to correspond strictly to the inputs/outputs of the `construct` method in the corresponding Python code of the training network. - - `source` under the `inputs` field identifies the data source of the input tensor, with `local` representing that the input tensor is loaded from local data and `remote` representing that the input tensor is from network transmission of other participants. - - `destination` under the `outputs` field identifies the destination of the output tensor, with `local` representing the output tensor for local use only, and `remote` representing that the output tensor is transferred to other participants via networks. - -3. Optionally, a similar approach is used to model the assessment network of vertical federated learning that this participant is to be involved. - -### Defining the Optimizer - -1. Call the functional components provided by MindSpore, to program the optimizer for parameter updates of this participant training network. As an example of a custom optimizer used by Participant A for Embedding subnetwork training in this application practice, [sample code](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/src/pangu_optim.py) is as follows: - - ```python - class PanguAlphaAdam(TrainOneStepWithLossScaleCell): - """ - Customized Adam optimizer for training of pangu_alpha in the splitnn demo system. - """ - def __init__(self, net, optim_inst, scale_update_cell, config, yaml_data) -> None: - # Custom optimizer-related operators - ... 

        def __call__(self, *inputs, sens=None):
            # Define the gradient calculation and parameter update process
            ...
    ```

    Developers can customize the inputs and outputs of the `__init__` method of the optimizer class, but the `__call__` method must take only `inputs` and `sens`. `inputs` is of type `list` and corresponds to the input tensor list of the training network; its elements are of type `mindspore.Tensor`. `sens` is a `dict` that stores the weighting coefficients used to calculate the gradients of the training network parameters: its keys are gradient weighting coefficient identifiers of type `str`, and its values are `dict`s whose keys (`str`) are the names of output tensors of the training network and whose values (`mindspore.Tensor`) are the weighting coefficients applied to the parameter gradients associated with that output tensor.

2. In the yaml configuration file, describe the gradient calculation, parameter update, and other information of the optimizer. The [sample code](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/embedding.yaml) is as follows:

    ```yaml
    opts:
      - type: CustomizedAdam
        grads:
          - inputs:
              - name: input_ids
              - name: position_id
              - name: attention_mask
            output:
              name: embedding_table
            sens: hidden_states
          - inputs:
              - name: input_ids
              - name: position_id
              - name: attention_mask
            output:
              name: word_table
            sens: word_table
        params:
          - name: word_embedding
          - name: position_embedding
        hyper_parameters:
          learning_rate: 5.e-6
          eps: 1.e-8
          loss_scale: 1024.0
    ```

    The `type` field specifies the optimizer type; here it is the developer-defined optimizer.

    The `grads` field is the list of `GradOperation` operators associated with the optimizer; the optimizer uses the operators in this list to compute gradient values and then updates the training network parameters. The `inputs` and `output` fields are the input and output tensor lists of a `GradOperation` operator, whose elements are input/output tensor names. The `sens` field is the gradient weighting coefficient or sensitivity identifier of the `GradOperation` operator (refer to [mindspore.ops.GradOperation](https://mindspore.cn/docs/en/master/api_python/ops/mindspore.ops.GradOperation.html?highlight=gradoperation)).

    The `params` field is the list of names of the training network parameters to be updated by the optimizer, one parameter name per element. In this example, the custom optimizer updates all network parameters whose names contain the string `word_embedding` or `position_embedding`.

    The `hyper_parameters` field is the list of hyperparameters of the optimizer.

### Defining Gradient Weighting Coefficient Calculation

According to the chain rule of gradient calculation, a subnetwork located downstream in the global network needs to calculate the gradient values of its output tensors relative to its input tensors, i.e., the gradient weighting coefficients or sensitivities, and pass them to the subnetwork located upstream in the global network so that the latter can update its training parameters.

MindSpore Federated uses the `GradOperation` operator to complete the above gradient weighting coefficient or sensitivity calculation process.
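To see what role the sensitivity plays, here is a minimal standalone MindSpore sketch (not part of the sample code; the network is illustrative only): the `sens` argument of the gradient function is exactly the weighting coefficient received from the downstream participant.

```python
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops

class Square(nn.Cell):
    def construct(self, x):
        return x * x

# sens_param=True makes the gradient function accept an external
# weighting coefficient (sensitivity) for the network output
grad_op = ops.GradOperation(sens_param=True)

x = ms.Tensor(2.0, ms.float32)
sens = ms.Tensor(3.0, ms.float32)   # e.g. received from the downstream sub-network
print(grad_op(Square())(x, sens))   # d(x*x)/dx * sens = 2 * 2 * 3 = 12
```

In the yaml configuration described next, each `grad_scalers` entry declares one such `GradOperation`, including where its `sens` comes from.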
The developer needs to describe the `GradOperation` operator used to calculate the gradient weighting coefficients in the yaml configuration file. Taking the Head subnetwork of participant A in this application practice as an example, the [sample code](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/head.yaml) is as follows:

```yaml
grad_scalers:
  - inputs:
      - name: hidden_states
      - name: input_ids
      - name: word_table
      - name: position_id
      - name: attention_mask
    output:
      name: output
    sens: 1024.0
```

The `inputs` and `output` fields are the input and output tensor lists of the `GradOperation` operator, whose elements are input/output tensor names. The `sens` field is the gradient weighting coefficient or sensitivity of this `GradOperation` operator (refer to [mindspore.ops.GradOperation](https://mindspore.cn/docs/en/master/api_python/ops/mindspore.ops.GradOperation.html?highlight=gradoperation)). If it is a `float` or `int` value, a constant tensor is constructed as the gradient weighting coefficient. If it is a `str` string, the tensor with that name is parsed as the weighting coefficient from the weighting coefficients transmitted by the other participants via the network.

### Executing the Training

1. After completing the Python development and yaml configuration above, use the `FLModel` and `FLYamlData` classes provided by MindSpore Federated to build the vertical federated learning process. Taking the Embedding subnetwork of participant A in this application practice as an example, the [sample code](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/run_pangu_train_local.py) is as follows:

    ```python
    embedding_yaml = FLYamlData('./embedding.yaml')
    embedding_base_net = EmbeddingLayer(config)
    embedding_eval_net = embedding_train_net = EmbeddingLossNet(embedding_base_net, config)
    embedding_with_loss = _VirtualDatasetCell(embedding_eval_net)
    embedding_params = embedding_with_loss.trainable_params()
    embedding_group_params = set_embedding_weight_decay(embedding_params)
    embedding_optim_inst = FP32StateAdamWeightDecay(embedding_group_params, lr, eps=1e-8, beta1=0.9, beta2=0.95)
    embedding_optim = PanguAlphaAdam(embedding_train_net, embedding_optim_inst, update_cell, config, embedding_yaml)

    embedding_fl_model = FLModel(yaml_data=embedding_yaml,
                                 network=embedding_train_net,
                                 eval_network=embedding_eval_net,
                                 optimizers=embedding_optim)
    ```

    The `FLYamlData` class mainly completes the parsing and verification of the yaml configuration file, and the `FLModel` class mainly provides the control interface for vertical federated learning training, inference, and other processes.

2. Call the interface methods of the `FLModel` class to perform vertical federated learning training. Taking the Embedding subnetwork of participant A in this application practice as an example, the [sample code](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/run_pangu_train_local.py) is as follows:

    ```python
    if opt.resume:
        embedding_fl_model.load_ckpt()
    ...
    for epoch in range(50):
        for step, item in enumerate(train_iter, start=1):
            # forward process
            step = epoch * train_size + step
            embedding_out = embedding_fl_model.forward_one_step(item)
            ...
            # backward process
            embedding_fl_model.backward_one_step(item, sens=backbone_scale)
            ...
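            # (elided) the Head and Backbone FLModels run their own
            # forward_one_step / backward_one_step calls in this same loop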
            if step % 1000 == 0:
                embedding_fl_model.save_ckpt()
    ```

    The `forward_one_step` and `backward_one_step` methods perform the forward inference and backward propagation of one data batch, respectively. The `load_ckpt` and `save_ckpt` methods load and save checkpoint files, respectively.

## Running the Example

This example provides two sample programs, both run as shell scripts that launch Python programs.

1. `run_pangu_train_local.sh`: a single-process example program. Participant A and participant B are trained in the same process, which passes the feature tensors and gradient tensors directly to the other participant as intra-program variables.

2. `run_pangu_train_leader.sh` and `run_pangu_train_follower.sh`: a multi-process example program. Participant A and participant B each run in a separate process, which encapsulates the feature tensors and gradient tensors as protobuf messages and transmits them to the other participant via the https communication interface. `run_pangu_train_leader.sh` and `run_pangu_train_follower.sh` can be run on two servers separately to achieve cross-domain collaborative training.

3. The current vertical federated distributed training supports https cross-domain encrypted communication. The startup commands are as follows:

    ```bash
    # Start the leader process in https encrypted communication mode:
    bash run_pangu_train_leader.sh 127.0.0.1:10087 127.0.0.1:10086 /path/to/train/data_set /path/to/eval/data_set True server_cert_password client_cert_password /path/to/server_cert /path/to/client_cert /path/to/ca_cert

    # Start the follower process in https encrypted communication mode:
    bash run_pangu_train_follower.sh 127.0.0.1:10086 127.0.0.1:10087 True server_cert_password client_cert_password /path/to/server_cert /path/to/client_cert /path/to/ca_cert
    ```

### Running a Single-Process Example

Taking `run_pangu_train_local.sh` as an example, run the sample program as follows:

1. Go to the sample program directory:

    ```bash
    cd federated/example/splitnn_pangu_alpha/
    ```

2. Taking the wiki dataset as an example, copy the dataset to the sample program directory:

    ```bash
    cp -r {dataset_dir}/wiki ./
    ```

3. Install the dependent Python packages:

    ```bash
    python -m pip install -r requirements.txt
    ```

4. Modify `src/utils.py` to configure parameters such as the checkpoint file load path, the training dataset path, and the evaluation dataset path. Examples are as follows:

    ```python
    parser.add_argument("--load_ckpt_path", type=str, default='./checkpoints', help="predict file path.")
    parser.add_argument('--data_url', required=False, default='./wiki/train/', help='Location of data.')
    parser.add_argument('--eval_data_url', required=False, default='./wiki/eval/', help='Location of eval data.')
    ```

5. Execute the training script:

    ```bash
    ./run_pangu_train_local.sh
    ```

6. View the training loss information recorded in the training log `splitnn_pangu_local.txt`.
    ```text
    INFO:root:epoch 0 step 10/43391 loss: 10.616087
    INFO:root:epoch 0 step 20/43391 loss: 10.424824
    INFO:root:epoch 0 step 30/43391 loss: 10.209235
    INFO:root:epoch 0 step 40/43391 loss: 9.950026
    INFO:root:epoch 0 step 50/43391 loss: 9.712448
    INFO:root:epoch 0 step 60/43391 loss: 9.557744
    INFO:root:epoch 0 step 70/43391 loss: 9.501564
    INFO:root:epoch 0 step 80/43391 loss: 9.326054
    INFO:root:epoch 0 step 90/43391 loss: 9.387547
    INFO:root:epoch 0 step 100/43391 loss: 8.795234
    ...
    ```

    The corresponding visualization results are shown below, where the horizontal axis is the number of training steps and the vertical axis is the loss value. The red curve is the training loss of PanGu-α, and the blue curve is the training loss of the split-learning-based PanGu-α in this example. The two loss values decrease with essentially the same trend; given that the initialization of the network parameter values is random, this verifies the correctness of the training process.

    ![Cross-domain training results of the Pangu alpha large model](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/docs/federated/docs/source_zh_cn/images/splitnn_pangu_alpha_result.png)

### Running a Multi-Process Example

1. Similar to the single-process example, go to the sample program directory and install the dependent Python packages:

    ```bash
    cd federated/example/splitnn_pangu_alpha/
    python -m pip install -r requirements.txt
    ```

2. Copy the dataset to the sample program directory on Server 1:

    ```bash
    cp -r {dataset_dir}/wiki ./
    ```

3. Start the training script for participant A on Server 1:

    ```bash
    ./run_pangu_train_leader.sh {ip_address_server1} {ip_address_server2} ./wiki/train ./wiki/train
    ```

    The first parameter of the training script is the IP address and port number of the local server (Server 1), and the second parameter is the IP address and port number of the peer server (Server 2). The third parameter is the training dataset file path, the fourth parameter is the evaluation dataset file path, and the fifth parameter identifies whether to load an existing checkpoint file.

4. Start the training script for participant B on Server 2:

    ```bash
    ./run_pangu_train_follower.sh {ip_address_server2} {ip_address_server1}
    ```

    The first parameter of the training script is the IP address and port number of the local server (Server 2), and the second parameter is the IP address and port number of the peer server (Server 1). The third parameter identifies whether to load an existing checkpoint file.

5. Check the training loss information recorded in the training log `leader_processs.log` of Server 1. If the trend of its loss values is consistent with that of the centralized training loss values of PanGu-α, the correctness of the training process is verified.
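    To compare the loss trend quantitatively, something like the following can extract (step, loss) pairs from the log for plotting; this is a hedged sketch that assumes the multi-process log uses the same line format as the single-process example above.

    ```python
    # parse "global step -> loss" pairs out of the training log for plotting;
    # assumed line format: "INFO:root:epoch 0 step 10/43391 loss: 10.616087"
    import re

    pairs = []
    with open("leader_processs.log") as f:
        for line in f:
            m = re.search(r"step (\d+)/\d+ loss: ([\d.]+)", line)
            if m:
                pairs.append((int(m.group(1)), float(m.group(2))))
    print(pairs[:5])
    ```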
\ No newline at end of file
diff --git a/docs/federated/docs/source_en/split_wnd_application.md b/docs/federated/docs/source_en/split_wnd_application.md
deleted file mode 100644
index 13b8f996eeaa37de6d07d60160f3a1bb8eea3f7f..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/split_wnd_application.md
+++ /dev/null
@@ -1,280 +0,0 @@
# Vertical Federated Learning Model Training - Wide&Deep Recommendation Application

[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/split_wnd_application.md)

## Overview

MindSpore Federated provides a vertical federated learning infrastructure component based on Split Learning.

The vertical FL model training scenario includes two stages: forward propagation and backward propagation/parameter update.

Forward propagation: after the data intersection module processes the two participants' data and aligns the feature and label information, the Follower participant inputs its local feature information into the precursor network model, and the feature tensor output by the precursor network model is encrypted/scrambled by the privacy security module and transmitted to the Leader participant by the communication module. The Leader participant inputs the received feature tensor into the post-level network model, and the predicted values output by the post-level network model, together with the local label information, are used as input to the loss function to calculate the loss value.

![](./images/vfl_forward_en.png)

Backward propagation: the Leader participant calculates the parameter gradients of the post-level network model from the loss value, trains and updates the parameters of the post-level network model, and transmits the gradient tensor associated with the feature tensor, after it is encrypted and scrambled by the privacy security module, to the Follower participant by the communication module. The Follower participant uses the received gradient tensor to train and update the parameters of the precursor network model.

![](./images/vfl_backward_en.png)

Vertical FL model inference scenario: similar to the forward propagation phase of the training scenario, except that the predicted values of the post-level network model are taken directly as the output, without calculating the loss value.

## Network and Data

![](./images/splitnn_wide_and_deep_en.png)

This sample provides a federated learning training example for recommendation-oriented tasks, using the Wide&Deep network and the Criteo dataset. As shown above, the vertical federated learning system in this case consists of a Leader participant and a Follower participant. The Leader participant holds 20×2-dimensional feature information and the label information, and the Follower participant holds 19×2-dimensional feature information. The Leader and Follower participants each deploy one Wide&Deep network and realize collaborative training of the network model by exchanging embedding vectors and gradient vectors, without disclosing their original feature and label information.
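To make the exchanged quantities concrete, here is a toy single-process sketch of one split-learning step written in plain MindSpore (this is not the MindSpore Federated API; the layer sizes and nets are illustrative only): the follower's bottom net produces an embedding, the leader's top net computes the loss, and only the gradient with respect to the embedding flows back.

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops

bottom = nn.Dense(19, 8)   # follower-side net over its 19 local features
top = nn.Dense(8, 1)       # leader-side net producing the prediction
loss_fn = nn.MSELoss()

x = ms.Tensor(np.random.randn(4, 19), ms.float32)  # follower's features
y = ms.Tensor(np.random.randn(4, 1), ms.float32)   # leader's labels

def top_loss(embedding):
    return loss_fn(top(embedding), y)

embedding = bottom(x)                                           # forward on the follower
grad_wrt_embedding = ops.GradOperation()(top_loss)(embedding)   # leader computes and sends back
# the follower backpropagates the received gradient through its own net:
weights = ms.ParameterTuple(bottom.trainable_params())
grads = ops.GradOperation(get_by_list=True, sens_param=True)(
    bottom, weights)(x, grad_wrt_embedding)
```

In the real system, `grad_wrt_embedding` is what crosses the network boundary (after privacy processing), and each side updates only its own parameters.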
For a detailed description of the principles and characteristics of Wide&Deep networks, see [MindSpore ModelZoo - Wide&Deep - Wide&Deep Overview](https://gitee.com/mindspore/models/blob/master/official/recommend/Wide_and_Deep/README.md#widedeep-description) and the [research paper](https://arxiv.org/pdf/1606.07792.pdf).

## Dataset Preparation

This sample uses the Criteo dataset for training and testing. Before running the sample, refer to [MindSpore ModelZoo - Wide&Deep - Quick Start](https://gitee.com/mindspore/models/blob/master/official/recommend/Wide_and_Deep/README.md#quick-start) to pre-process the Criteo dataset.

1. Clone the MindSpore ModelZoo code.

    ```shell
    git clone https://gitee.com/mindspore/models.git
    cd models/official/recommend/Wide_and_Deep
    ```

2. Download the dataset.

    ```shell
    mkdir -p data/origin_data && cd data/origin_data
    wget http://go.criteo.net/criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz
    tar -zxvf criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz
    ```

3. Use the following script to pre-process the data. Preprocessing may take up to an hour, and the generated MindRecord data is stored under the data/mindrecord path. Preprocessing consumes a lot of memory, so a server is recommended.

    ```shell
    cd ../..
    python src/preprocess_data.py --data_path=./data/ --dense_dim=13 --slot_dim=26 --threshold=100 --train_line_count=45840617 --skip_id_convert=0
    ```

## Quick Experience

This sample runs as shell scripts that launch Python programs.

1. Refer to the [MindSpore website guidance](https://www.mindspore.cn/install) to install MindSpore 1.8.1 or later.

2. Run the following commands to install the Python libraries that MindSpore Federated depends on.

    ```shell
    cd federated
    python -m pip install -r requirements_test.txt
    ```

3. Copy the Criteo dataset obtained after [preprocessing](#dataset-preparation) to this directory.

    ```shell
    cd tests/example/splitnn_criteo
    cp -rf ${DATA_ROOT_PATH}/data/mindrecord/ ./
    ```

4. Run the startup scripts of the sample program.

    ```shell
    # start leader:
    bash run_vfl_train_leader.sh 127.0.0.1:10087 127.0.0.1:10086 /path/to/data_set False

    # start follower:
    bash run_vfl_train_follower.sh 127.0.0.1:10086 127.0.0.1:10087 /path/to/data_set False
    ```

    or

    ```shell
    # Start the leader process with https encrypted communication:
    bash run_vfl_train_leader.sh 127.0.0.1:10087 127.0.0.1:10086 /path/to/data_set True server_cert_password client_cert_password /path/to/server_cert /path/to/client_cert /path/to/ca_cert

    # Start the follower process with https encrypted communication:
    bash run_vfl_train_follower.sh 127.0.0.1:10086 127.0.0.1:10087 /path/to/data_set True server_cert_password client_cert_password /path/to/server_cert /path/to/client_cert /path/to/ca_cert
    ```

5. View the training log `log_local_gpu.txt`.
    ```text
    INFO:root:epoch 0 step 100/2582 wide_loss: 0.528141 deep_loss: 0.528339
    INFO:root:epoch 0 step 200/2582 wide_loss: 0.499408 deep_loss: 0.499410
    INFO:root:epoch 0 step 300/2582 wide_loss: 0.477544 deep_loss: 0.477882
    INFO:root:epoch 0 step 400/2582 wide_loss: 0.474377 deep_loss: 0.476771
    INFO:root:epoch 0 step 500/2582 wide_loss: 0.472926 deep_loss: 0.475157
    INFO:root:epoch 0 step 600/2582 wide_loss: 0.464844 deep_loss: 0.467011
    INFO:root:epoch 0 step 700/2582 wide_loss: 0.464496 deep_loss: 0.466615
    INFO:root:epoch 0 step 800/2582 wide_loss: 0.466895 deep_loss: 0.468971
    INFO:root:epoch 0 step 900/2582 wide_loss: 0.463155 deep_loss: 0.465299
    INFO:root:epoch 0 step 1000/2582 wide_loss: 0.457914 deep_loss: 0.460132
    INFO:root:epoch 0 step 1100/2582 wide_loss: 0.453361 deep_loss: 0.455767
    INFO:root:epoch 0 step 1200/2582 wide_loss: 0.457566 deep_loss: 0.459997
    INFO:root:epoch 0 step 1300/2582 wide_loss: 0.460841 deep_loss: 0.463281
    INFO:root:epoch 0 step 1400/2582 wide_loss: 0.460973 deep_loss: 0.463365
    INFO:root:epoch 0 step 1500/2582 wide_loss: 0.459204 deep_loss: 0.461563
    INFO:root:epoch 0 step 1600/2582 wide_loss: 0.456771 deep_loss: 0.459200
    INFO:root:epoch 0 step 1700/2582 wide_loss: 0.458479 deep_loss: 0.460963
    INFO:root:epoch 0 step 1800/2582 wide_loss: 0.449609 deep_loss: 0.452122
    INFO:root:epoch 0 step 1900/2582 wide_loss: 0.451775 deep_loss: 0.454225
    INFO:root:epoch 0 step 2000/2582 wide_loss: 0.460343 deep_loss: 0.462826
    INFO:root:epoch 0 step 2100/2582 wide_loss: 0.456814 deep_loss: 0.459201
    INFO:root:epoch 0 step 2200/2582 wide_loss: 0.452091 deep_loss: 0.454555
    INFO:root:epoch 0 step 2300/2582 wide_loss: 0.461522 deep_loss: 0.464001
    INFO:root:epoch 0 step 2400/2582 wide_loss: 0.442355 deep_loss: 0.444790
    INFO:root:epoch 0 step 2500/2582 wide_loss: 0.450675 deep_loss: 0.453242
    ...
    ```

6. Close the training processes.

    ```shell
    pid=`ps -ef|grep run_vfl_train_socket |grep -v "grep" | grep -v "finish" |awk '{print $2}'` && for id in $pid; do kill -9 $id && echo "killed $id"; done
    ```

## Deep Experience

Before starting vertical federated learning training, users need to construct the dataset iterator and the network structure, just as for normal deep learning training with MindSpore.

### Building the Dataset

A simulation process is currently used, i.e., both participants read the same data source, but each uses only part of the feature or label data for training, as shown in [Network and Data](#network-and-data). Later, the [Data Access](https://www.mindspore.cn/federated/docs/en/master/data_join/data_join.html) method can be used so that each participant imports its data individually.

```python
from run_vfl_train_local import construct_local_dataset


ds_train, _ = construct_local_dataset()
train_iter = ds_train.create_dict_iterator()
```

### Building the Network

Leader participant network:

```python
from wide_and_deep import WideDeepModel, BottomLossNet, LeaderTopNet, LeaderTopLossNet, LeaderTopEvalNet, \
    LeaderTeeNet, LeaderTeeLossNet, LeaderTopAfterTeeNet, LeaderTopAfterTeeLossNet, LeaderTopAfterTeeEvalNet, \
    AUCMetric
from network_config import config


# Leader Top Net
leader_top_base_net = LeaderTopNet()
leader_top_train_net = LeaderTopLossNet(leader_top_base_net)
...
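# (elided) the matching evaluation net, e.g. LeaderTopEvalNet, is built here
# in the same way before being wrapped into FLModel below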
# Leader Bottom Net
leader_bottom_eval_net = leader_bottom_base_net = WideDeepModel(config, config.leader_field_size)
leader_bottom_train_net = BottomLossNet(leader_bottom_base_net, config)
```

Follower participant network:

```python
from wide_and_deep import WideDeepModel, BottomLossNet
from network_config import config


follower_bottom_eval_net = follower_base_net = WideDeepModel(config, config.follower_field_size)
follower_bottom_train_net = BottomLossNet(follower_base_net, config)
```

### Vertical Federated Communication Base

Before training, the communication base must be started so that the Leader and Follower participants can establish networking with each other. Detailed API documentation can be found in [Vertical Federated Communicator](https://gitee.com/mindspore/federated/blob/master/docs/api/api_python_en/vertical/vertical_communicator.rst).

Both parties need to import the vertical federated communicator:

```python
from mindspore_federated.startup.vertical_federated_local import VerticalFederatedCommunicator, ServerConfig
```

Leader participant communication base:

```python
http_server_config = ServerConfig(server_name='leader', server_address=config.http_server_address)
remote_server_config = ServerConfig(server_name='follower', server_address=config.remote_server_address)
self.vertical_communicator = VerticalFederatedCommunicator(http_server_config=http_server_config,
                                                           remote_server_config=remote_server_config,
                                                           compress_configs=compress_configs)
self.vertical_communicator.launch()
```

Follower participant communication base:

```python
http_server_config = ServerConfig(server_name='follower', server_address=config.http_server_address)
remote_server_config = ServerConfig(server_name='leader', server_address=config.remote_server_address)
self.vertical_communicator = VerticalFederatedCommunicator(http_server_config=http_server_config,
                                                           remote_server_config=remote_server_config,
                                                           compress_configs=compress_configs)
self.vertical_communicator.launch()
```

### Building a Vertical Federated Network

Users need to use the classes provided by MindSpore Federated to wrap their constructed networks into a vertical federated network. Detailed API documentation can be found in [Vertical Federated Training Interface](https://gitee.com/mindspore/federated/blob/master/docs/api/api_python_en/vertical/vertical_federated_FLModel.rst).

Both parties need to import the vertical federated training interface:

```python
from mindspore_federated import FLModel, FLYamlData
```

Leader participant vertical federated network:

```python
leader_bottom_yaml_data = FLYamlData(config.leader_bottom_yaml_path)
leader_top_yaml_data = FLYamlData(config.leader_top_yaml_path)
...
self.leader_top_fl_model = FLModel(yaml_data=leader_top_yaml_data,
                                   network=leader_top_train_net,
                                   metrics=self.eval_metric,
                                   eval_network=leader_top_eval_net)
...
self.leader_bottom_fl_model = FLModel(yaml_data=leader_bottom_yaml_data,
                                      network=leader_bottom_train_net,
                                      eval_network=leader_bottom_eval_net)
```

Follower participant vertical federated network:

```python
follower_bottom_yaml_data = FLYamlData(config.follower_bottom_yaml_path)
...
self.follower_bottom_fl_model = FLModel(yaml_data=follower_bottom_yaml_data,
                                        network=follower_bottom_train_net,
                                        eval_network=follower_bottom_eval_net)
```

### Vertical Training

For the process of vertical training, refer to the [overview](#overview).
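The two loops below together implement one synchronized step per batch. As a reading aid, here is the per-step message flow interleaved across the two processes (a comment-form sketch; the names follow the loops below):

```python
# follower: embedding = follower_bottom_fl_model.forward_one_step(item)
#           communicator.send_tensors("leader", embedding)        # -> leader
# leader:   follower_embedding = communicator.receive("follower") # <- follower
#           leader_top_fl_model.forward_one_step(item, follower_embedding)
#           grad_scale = leader_top_fl_model.backward_one_step(item, follower_embedding)
#           communicator.send_tensors("follower", grad_scale_follower)  # -> follower
# follower: scale = communicator.receive("leader")                # <- leader
#           follower_bottom_fl_model.backward_one_step(item, sens=scale)
```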
Leader participant training process:

```python
for epoch in range(config.epochs):
    for step, item in enumerate(train_iter):
        leader_embedding = self.leader_bottom_fl_model.forward_one_step(item)
        item.update(leader_embedding)
        follower_embedding = self.vertical_communicator.receive("follower")
        ...
        leader_out = self.leader_top_fl_model.forward_one_step(item, follower_embedding)
        grad_scale = self.leader_top_fl_model.backward_one_step(item, follower_embedding)
        scale_name = 'loss'
        ...
        grad_scale_follower = {scale_name: OrderedDict(list(grad_scale[scale_name].items())[2:])}
        self.vertical_communicator.send_tensors("follower", grad_scale_follower)
        grad_scale_leader = {scale_name: OrderedDict(list(grad_scale[scale_name].items())[:2])}
        self.leader_bottom_fl_model.backward_one_step(item, sens=grad_scale_leader)
```

Follower participant training process:

```python
for _ in range(config.epochs):
    for _, item in enumerate(train_iter):
        follower_embedding = self.follower_bottom_fl_model.forward_one_step(item)
        self.vertical_communicator.send_tensors("leader", follower_embedding)
        scale = self.vertical_communicator.receive("leader")
        self.follower_bottom_fl_model.backward_one_step(item, sens=scale)
```

diff --git a/docs/federated/docs/source_en/vertical_federated_trainer.rst b/docs/federated/docs/source_en/vertical_federated_trainer.rst
deleted file mode 100644
index 388322d850ce57804d645fec7dfc8216d612940f..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/vertical_federated_trainer.rst
+++ /dev/null
@@ -1,12 +0,0 @@
Vertical Federated Trainer
==========================

.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg
   :target: https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/vertical_federated_trainer.rst
   :alt: View Source on Gitee

.. toctree::
   :maxdepth: 1

   vertical/vertical_federated_FLModel
   vertical/vertical_federated_yaml
\ No newline at end of file
diff --git a/docs/federated/docs/source_en/vfl_communication_compress.md b/docs/federated/docs/source_en/vfl_communication_compress.md
deleted file mode 100644
index ae13e0c55dea254ed994295f1c3e19f83f0a71cf..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_en/vfl_communication_compress.md
+++ /dev/null
@@ -1,231 +0,0 @@
# Vertical Federated Learning Communication Compression

[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_en/vfl_communication_compress.md)

The communication volume of vertical federated learning affects user experience (user traffic, communication latency, federated learning training efficiency) and is limited by performance constraints (memory, bandwidth, CPU usage). Reducing the amount of communication helps greatly to improve user experience and relieve performance bottlenecks, so the traffic needs to be compressed. MindSpore Federated implements bi-directional communication compression between Leader and Follower in the vertical federated application scenario.

## Overall Process

![image1](./images/vfl_normal_communication_compress_en.png)

Figure 1 Framework diagram of the general vertical federated learning communication compression process

The Follower first performs the Embedding DP (EDP) encryption operation, then enters the bit packing process.
The bit packing process automatically determines whether the input data can be packed: the packing operation is performed only when the input data can be cast to the specified bit-width storage format with no loss of precision. The Follower sends the packed data to the Leader, and the Leader decides whether the data needs to be unpacked based on the reported data information. Before the Leader passes data to the Follower, the data is quantized and compressed; the Follower receives the data and decompresses the quantized data.

![image2](./images/vfl_pangu_communication_compress_en.png)

Figure 2 Framework diagram of the Pangu vertical federated learning communication compression process

The overall process is the same as the general vertical federated learning communication compression process. Compared with the general case, each iteration of Pangu vertical federated learning involves one extra round of communication, so one extra quantization compression and decompression pass is required.

## Compression Method

### Bit Packing Compression Method

Bit packing compression converts a sequence of data into a compact binary representation. Bit packing itself is a lossless compression method, but the data fed into bit packing has usually already undergone lossy compression.

Taking 3-bit packing as an example:

Quantization bit width: bit_num = 3

The data stored in float32 format before compression is:

data = [3, -4, 3, -2, 3, -2, -4, 0, 1, 3]

First determine whether bit packing compression is possible:

data_int = int(data)

If any element of data - data_int is nonzero, exit the bit packing process.

Convert the source data to binary format according to bit_num:

data_bin = [011, 100, 011, 110, 011, 110, 100, 000, 001, 011]

Note: before conversion, determine whether the current data is within the range that bit_num can represent. If it exceeds the range, exit the bit packing process.

Since native C++ has no dedicated bit-level storage format, the binary values need to be concatenated and stored as int8 data; if the bits do not fill the last byte, it is zero-padded. The combined data is as follows:

data_int8 = [01110001, 11100111, 10100000, 00101100]

The binary data is then interpreted as integers between -128 and 127, and the data type is cast to int8:

data_packed = [113, -25, -96, 44]

Finally, data_packed and bit_num are passed to the receiver.

When unpacking, the receiver simply reverses the above process.

### Quantization Compression Method

Quantization compression approximates floating-point communication data by a finite number of discrete fixed-point values. The currently supported quantization compression method is minimum-maximum compression (min_max).

Taking 8-bit quantization as an example:

Quantization bit width: bit_num = 8

The float data before compression is:

data = [0.03356021, -0.01842778, -0.009684053, 0.025363436, -0.027571501, 0.0077043395, 0.016391572, -0.03598478, -0.0009508357]

Compute the minimum and maximum values:

min_val = -0.03598478

max_val = 0.03356021

Compute the scaling coefficient:

scale = (max_val - min_val) / (2 ^ bit_num - 1) = 0.000272725450980392

Convert the pre-compression data to integers between -128 and 127 with the formula quant_data = round((data - min_val) / scale) - 2 ^ (bit_num - 1),
and cast the data type directly to int8:

quant_data = [127, -64, -32, 97, -97, 32, 64, -128, 0]

After quantization encoding, the sender uploads quant_data, bit_num, and the minimum and maximum values min_val and max_val.

After receiving quant_data, min_val, and max_val, the receiver uses the inverse quantization formula (quant_data + 2 ^ (bit_num - 1)) * (max_val - min_val) / (2 ^ bit_num - 1) + min_val to restore the weights.

## Fast Experience

To use the bit packing or quantization compression method, first successfully complete the training process of any vertical federated scenario, such as [Vertical Federated Learning Model Training - Wide&Deep Recommendation Application](https://www.mindspore.cn/federated/docs/en/master/split_wnd_application.html). The preparation work, including datasets and network models, and the process of simulating the startup of federated learning are described in detail in that document.

1. For MindSpore and MindSpore Federated installation and data preprocessing, refer to [Vertical Federated Learning Model Training - Wide&Deep Recommendation Application](https://www.mindspore.cn/federated/docs/en/master/split_wnd_application.html).

2. Set the compression-related configuration in the [related yaml files](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo/yaml_files).

    The configuration of [leader_top.yaml](https://gitee.com/mindspore/federated/blob/master/example/splitnn_criteo/yaml_files/leader_top.yaml) is as follows:

    ```yaml
    role: leader
    model: # define the net of vFL party
      train_net:
        name: leader_loss_net
        inputs:
          - name: leader_wide_embedding
            source: local
          - name: leader_deep_embedding
            source: local
          - name: follower_wide_embedding
            source: remote
            compress_type: min_max
            bit_num: 6
          - name: follower_deep_embedding
            source: remote
            compress_type: min_max
            bit_num: 6
    ...
    ```

    The configuration of [follower_bottom.yaml](https://gitee.com/mindspore/federated/blob/master/example/splitnn_criteo/yaml_files/follower_bottom.yaml) is as follows:

    ```yaml
    role: follower
    model: # define the net of vFL party
      train_net:
        name: follower_loss_net
        inputs:
          - name: id_hldr0
            source: local
          - name: wt_hldr0
            source: local
        outputs:
          - name: follower_wide_embedding
            destination: remote
            compress_type: min_max
            bit_num: 6
          - name: follower_deep_embedding
            destination: remote
            compress_type: min_max
            bit_num: 6
          - name: follower_l2_regu
            destination: local
    ...
    ```

3. Users can modify the hyperparameters according to the actual situation.

    - compress_type: the compression type, string type; options: "min_max", "bit_pack".
    - bit_num: the number of quantization bits, int type, with a domain of [1, 8].

4. Run the startup scripts of the sample program.

    ```shell
    # Start leader process:
    bash run_vfl_train_leader.sh 127.0.0.1:1984 127.0.0.1:1230 ./mindrecord/ False
    # Start follower process:
    bash run_vfl_train_follower.sh 127.0.0.1:1230 127.0.0.1:1984 ./mindrecord/ False
    ```

5. Check the training log `vfl_train_leader.log`; the loss converges normally.
- - ```text - epoch 0 step 0 loss: 0.693124 - epoch 0 step 100 loss: 0.512151 - epoch 0 step 200 loss: 0.493524 - epoch 0 step 300 loss: 0.473054 - epoch 0 step 400 loss: 0.466222 - epoch 0 step 500 loss: 0.464252 - epoch 0 step 600 loss: 0.469296 - epoch 0 step 700 loss: 0.451647 - epoch 0 step 800 loss: 0.457797 - epoch 0 step 900 loss: 0.457930 - epoch 0 step 1000 loss: 0.461664 - epoch 0 step 1100 loss: 0.460415 - epoch 0 step 1200 loss: 0.466883 - epoch 0 step 1300 loss: 0.455919 - epoch 0 step 1400 loss: 0.466984 - epoch 0 step 1500 loss: 0.454486 - epoch 0 step 1600 loss: 0.458730 - epoch 0 step 1700 loss: 0.451275 - epoch 0 step 1800 loss: 0.445938 - epoch 0 step 1900 loss: 0.458323 - epoch 0 step 2000 loss: 0.446709 - ... - ``` - -6. Close the training process - - ```shell - pid=`ps -ef|grep run_vfl_train_ |grep -v "grep" | grep -v "finish" |awk '{print $2}'` && for id in $pid; do kill -9 $id && echo "killed $id"; done - ``` - -## Deep Experience - -### Obtaining the Compression Configuration - -The user can use the encapsulated interface to get the configuration related to communication compression. The [Yaml Configuration file for model training of vertical federated learning](https://www.mindspore.cn/federated/docs/en/master/vertical/vertical_federated_yaml.html) gives the configuration description of the parameters related to the startup. The [Model Training Interface](https://www.mindspore.cn/federated/docs/en/master/vertical/vertical_federated_FLModel.html) provides the interface to get the compression configuration. The example method is as follows: - -```python -# parse yaml files -leader_top_yaml_data = FLYamlData(config.leader_top_yaml_path) - -# Leader Top Net -leader_top_base_net = LeaderTopNet() -leader_top_train_net = LeaderTopLossNet(leader_top_base_net) -leader_top_fl_model = FLModel( - yaml_data=leader_top_yaml_data, - network=leader_top_train_net -) - -# get compress config -compress_configs = leader_top_fl_model.get_compress_configs() -``` - -### Setting Compression Configuration - -Users can use the already encapsulated [Vertical Federated Learning Communicator](https://www.mindspore.cn/federated/docs/en/master/vertical/vertical_communicator.html) interface to set the configuration related to communication compression to the communicator device by the following method: - -```python -# build vertical communicator -http_server_config = ServerConfig(server_name='leader', server_address=config.http_server_address) -remote_server_config = ServerConfig(server_name='follower', server_address=config.remote_server_address) -vertical_communicator = VerticalFederatedCommunicator( - http_server_config=http_server_config, - remote_server_config=remote_server_config, - compress_configs=compress_configs -) -vertical_communicator.launch() -``` - -After setting the communication compression configuration, the vertical federated framework will automatically compress the communication content in the backend. diff --git a/docs/federated/docs/source_zh_cn/Data_Join.rst b/docs/federated/docs/source_zh_cn/Data_Join.rst deleted file mode 100644 index e1a33a58b0b8602530dac4750cdf1d98f275aee4..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/Data_Join.rst +++ /dev/null @@ -1,12 +0,0 @@ -数据求交 -===================== - -.. 
image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg - :target: https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/Data_Join.rst - :alt: 查看源文件 - -.. toctree:: - :maxdepth: 1 - - data_join/data_join - data_join/private_set_intersection \ No newline at end of file diff --git a/docs/federated/docs/source_zh_cn/_ext/__pycache__/my_signature.cpython-37.pyc b/docs/federated/docs/source_zh_cn/_ext/__pycache__/my_signature.cpython-37.pyc deleted file mode 100644 index 516af30c9a197f89a41eb556f5f38d22e72d6668..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/_ext/__pycache__/my_signature.cpython-37.pyc and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/communication_compression.md b/docs/federated/docs/source_zh_cn/communication_compression.md deleted file mode 100644 index c64fd7bd5c6d0725d8fce2ab8d2c39b4d0b3416a..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/communication_compression.md +++ /dev/null @@ -1,137 +0,0 @@ -# 端云联邦学习通信压缩 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/communication_compression.md) - -在横向的端云联邦学习训练过程中,通信量会影响端侧用户体验(用户流量、通信时延、FL-Client 参与数量),并受云侧性能约束(内存、带宽、CPU 占用率)限制。为了提高用户体验和减少性能瓶颈,MindSpore联邦学习框架在端云联邦场景中,提供上传和下载的通信量压缩功能。 - -## 压缩方法 - -### 上传压缩方法 - -上传压缩方法可以分为三个主要部分:权重差编解码、稀疏编解码和量化编解码,下面给出了FL-Client和FL-Server上的流程图。 - -![上传压缩client执行顺序](./images/upload_compression_client.png) - -图1 上传压缩方法在FL-Client上的流程图 - -![上传压缩server执行顺序](./images/upload_compression_server.png) - -图2 上传压缩方法在FL-Server上的流程图 - -### 权重差编解码 - -权重差即为端侧训练前后的权重矩阵的向量差。相较于原始权重而言,权重差的分布更符合高斯分布,因此更适合被压缩。FL-Client对权重差进行编码操作,FL-Server对权重差进行解码操作。值得注意的是,为了在FL-Server聚合权重前就将权重差还原为权重,FL-Client在上传权重时,不将权重乘以数据量。FL-Server解码时,需要将权重乘以数据量。 - -![权重差编码](./images/weight_diff_encode.png) - -图3 权重差编码在FL-Client上的流程图 - -![权重差解码](./images/weight_diff_decode.png) - -图4 权重差解码在FL-Server上的流程图 - -### 稀疏编解码 - -端云都遵循同样的随机算法生成稀疏掩码矩阵,该掩码矩阵和原本需要上传的权重形状相同。掩码矩阵只包含0或1两个值。每个FL-Client只上传和掩码矩阵非零值位置相同的权重的数据到FL-Server上。 - -以稀疏率sparse_rate=0.08的稀疏方法为例。FL-Client原本需要上传的参数: - -| 参数名 | 长度 | -| -------------------- | ----- | -| albert.pooler.weight | 97344 | -| albert.pooler.bias | 312 | -| classifier.weight | 1560 | -| classifier.bias | 5 | - -将所有参数接为一维向量: - -| 参数名 | 长度 | -| ----------- | ---------------------- | -| merged_data | 97344+312+1560+5=99221 | - -生成和拼接后参数长度一样的mask向量,其中,有7937 = int(sparse_rate*拼接后参数长度)个值为1,其余值为0(即mask_vector = (1,1,1,...,0,0,0,...)): - -| 参数名 | 长度 | -| ----------- | ----- | -| mask_vector | 99221 | - -使用伪随机算法,将mask_vector随机打乱。随机种子为当前的iteration数。取出mask_vector中值为1的索引indexes。取出merged_data[indexes]的值,即压缩后的向量: - -| 参数名 | 长度 | -| ----------------- | ---- | -| compressed_vector | 7937 | - -稀疏压缩后,FL-Client需要上传的参数即为compressed_vector。 - -FL-Server在收到compressed_vector后,首先会用和FL-Client同样的伪随机算法和随机种子,构造出掩码向量mask_vector。然后取出mask_vector中值为1的索引indexes。再然后,生成和模型相同shape的全零矩阵weight_vector。依次将compressed_vector中的值放入weight_vector[indexes]中。weight_vector即为稀疏解码后的向量。 - -### 量化编解码 - -量化压缩方法即将浮点型的通信数据定点近似为有限多个离散值。 - -以8-bit量化举例来讲: - -量化位数num_bits = 8 - -压缩前的浮点型数据为: - -data = [0.03356021, -0.01842778, -0.009684053, 0.025363436, -0.027571501, 0.0077043395, 0.016391572, -0.03598478, -0.0009508357] - -计算最大和最小值: - -min_val = -0.03598478 - -max_val = 0.03356021 - -计算缩放系数: - -scale = (max_val - min_val ) / (2 ^ num_bits - 1) = 
0.000272725450980392 - -将压缩前数据转换为-128到127之间的整数,转换公式为quant_data = round((data - min_val) / scale) - 2 ^ (num_bits - 1)。并强转数据类型到int8: - -quant_data = [127, -64, -32, 97, -97, 32, 64, -128, 0] - -量化编码后,FL-Client需要上传的参数即为quant_data以及最小和最大值min_val和max_val。 - -FL-Server在收到quant_data、min_val和max_val后,使用反量化公式(quant_data + 2 ^ (num_bits - 1)) * (max_val - min_val) / (2 ^ num_bits - 1) + min_val,还原出权重。 - -## 下载压缩方法 - -下载压缩方法主要为量化编解码操作,下面给出了FL-Server和FL-Client上的流程图。 - -![下载压缩server执行顺序](./images/download_compression_server.png) - -图5 下载压缩方法在FL-Server上的流程图 - -![下载压缩client执行顺序](./images/download_compression_client.png) - -图6 下载压缩方法在FL-Client上的流程图 - -### 量化编解码 - -量化的编解码方法和上传压缩中一样。 - -## 代码实现准备工作 - -若要使用上传和下载压缩方法,首先需要成功完成任一端云联邦场景的训练聚合过程,如[实现一个情感分类应用(Android)](https://www.mindspore.cn/federated/docs/zh-CN/master/sentiment_classification_application.html)。在该文档中详细介绍了包括数据集和网络模型等准备工作和模拟启动多客户端参与联邦学习的流程。 - -## 算法开启脚本 - -上传和下载压缩方法目前只支持端云联邦学习场景。开启方式需要在启动云侧服务时,在server启动脚本中,在对应的yaml中设置`upload_compress_type='DIFF_SPARSE_QUANT'`和`download_compress_type='QUANT'`。上述两个超参数即可分别控制上传和下载压缩方法的开启和关闭。云侧[完整启动脚本](https://gitee.com/mindspore/federated/tree/master/tests/st/cross_device_cloud/),这里给出启动该算法的相关参数配置。确定参数配置后,用户需要在执行训练前配置对应参数,具体如下: - -```yaml -compression: - upload_compress_type: NO_COMPRESS - upload_sparse_rate: 0.4 - download_compress_type: NO_COMPRESS -``` - -| 超参名称&参考值 | 超参描述 | -| ---------------------- | ------------------------------------------------------------ | -| upload_compress_type | 上传压缩类型,string类型,包括:"NO_COMPRESS"、"DIFF_SPARSE_QUANT" | -| upload_sparse_rate | 稀疏率,即权重保留率,float类型,定义域在(0, 1]内 | -| download_compress_type | 下载压缩类型,string类型,包括:"NO_COMPRESS"、"QUANT" | - -## ALBERT实验结果 - -联邦学习总迭代数为100,客户端本地训练epoch数为1,客户端数量为20,batchSize设置为16,学习率为1e-5,同时开启上传和下载压缩方法,上传稀疏率为0.4。最终在验证集上的准确率为72.5%,不压缩的普通联邦场景为72.3%。 \ No newline at end of file diff --git a/docs/federated/docs/source_zh_cn/conf.py b/docs/federated/docs/source_zh_cn/conf.py deleted file mode 100644 index ab4cf042cbba087fd74ebcd200ceb09689244ded..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/conf.py +++ /dev/null @@ -1,232 +0,0 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import os -import shutil -import sys -import IPython -import re -from sphinx.ext import autodoc as sphinx_autodoc - -# -- Project information ----------------------------------------------------- - -project = 'MindSpore' -copyright = 'MindSpore' -author = 'MindSpore' - -# The full version, including alpha/beta/rc tags -release = 'master' - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. 
-myst_enable_extensions = ["dollarmath", "amsmath"] - - -myst_heading_anchors = 5 -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'myst_parser', - 'sphinx.ext.mathjax', - 'IPython.sphinxext.ipython_console_highlighting' -] - -source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', -} - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -mathjax_path = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/mathjax/MathJax-3.2.2/es5/tex-mml-chtml.js' - -mathjax_options = { - 'async':'async' -} - -smartquotes_action = 'De' - -exclude_patterns = [] - -pygments_style = 'sphinx' - -# -- Options for HTML output ------------------------------------------------- - -# Reconstruction of sphinx auto generated document translation. -language = 'zh_CN' -locale_dirs = ['../../../../resource/locale/'] -gettext_compact = False - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'sphinx_rtd_theme' - -html_search_language = 'zh' - -html_search_options = {'dict': '../../../resource/jieba.txt'} - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = { - 'python': ('https://docs.python.org/', '../../../../resource/python_objects.inv'), - 'numpy': ('https://docs.scipy.org/doc/numpy/', '../../../../resource/numpy_objects.inv'), -} - -from sphinx import directives -with open('../_ext/overwriteobjectiondirective.txt', 'r', encoding="utf8") as f: - exec(f.read(), directives.__dict__) - -from sphinx.ext import viewcode -with open('../_ext/overwriteviewcode.txt', 'r', encoding="utf8") as f: - exec(f.read(), viewcode.__dict__) - -# Modify default signatures for autodoc. -autodoc_source_path = os.path.abspath(sphinx_autodoc.__file__) -autodoc_source_re = re.compile(r'stringify_signature\(.*?\)') -get_param_func_str = r"""\ -import re -import inspect as inspect_ - -def get_param_func(func): - try: - source_code = inspect_.getsource(func) - if func.__doc__: - source_code = source_code.replace(func.__doc__, '') - all_params_str = re.findall(r"def [\w_\d\-]+\(([\S\s]*?)(\):|\) ->.*?:)", source_code) - all_params = re.sub("(self|cls)(,|, )?", '', all_params_str[0][0].replace("\n", "").replace("'", "\"")) - return all_params - except: - return '' - -def get_obj(obj): - if isinstance(obj, type): - return obj.__init__ - - return obj -""" - -with open(autodoc_source_path, "r+", encoding="utf8") as f: - code_str = f.read() - code_str = autodoc_source_re.sub('"(" + get_param_func(get_obj(self.object)) + ")"', code_str, count=0) - exec(get_param_func_str, sphinx_autodoc.__dict__) - exec(code_str, sphinx_autodoc.__dict__) - -# Copy source files of chinese python api from federated repository. 
-from sphinx.util import logging -logger = logging.getLogger(__name__) - -copy_path = 'docs/api/api_python' -src_dir = os.path.join(os.getenv("MF_PATH"), copy_path) - -copy_list = [] - -present_path = os.path.dirname(__file__) - -for i in os.listdir(src_dir): - if os.path.isfile(os.path.join(src_dir,i)): - if os.path.exists('./'+i): - os.remove('./'+i) - shutil.copy(os.path.join(src_dir,i),'./'+i) - copy_list.append(os.path.join(present_path,i)) - else: - if os.path.exists('./'+i): - shutil.rmtree('./'+i) - shutil.copytree(os.path.join(src_dir,i),'./'+i) - copy_list.append(os.path.join(present_path,i)) - -# add view -import json - -if os.path.exists('../../../../tools/generate_html/version.json'): - with open('../../../../tools/generate_html/version.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) -elif os.path.exists('../../../../tools/generate_html/daily_dev.json'): - with open('../../../../tools/generate_html/daily_dev.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) -elif os.path.exists('../../../../tools/generate_html/daily.json'): - with open('../../../../tools/generate_html/daily.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) - -if os.getenv("MF_PATH").split('/')[-1]: - copy_repo = os.getenv("MF_PATH").split('/')[-1] -else: - copy_repo = os.getenv("MF_PATH").split('/')[-2] - -branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if version_inf[i]['name'] == copy_repo][0] -docs_branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if version_inf[i]['name'] == 'tutorials'][0] - -re_view = f"\n.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/{docs_branch}/" + \ - f"resource/_static/logo_source.svg\n :target: https://gitee.com/mindspore/{copy_repo}/blob/{branch}/" - -for cur, _, files in os.walk(present_path): - for i in files: - flag_copy = 0 - if i.endswith('.rst'): - for j in copy_list: - if j in cur: - flag_copy = 1 - break - if os.path.join(cur, i) in copy_list or flag_copy: - try: - with open(os.path.join(cur, i), 'r+', encoding='utf-8') as f: - content = f.read() - new_content = content - if '.. include::' in content and '.. 
automodule::' in content: - continue - if 'autosummary::' not in content and "\n=====" in content: - re_view_ = re_view + copy_path + cur.split(present_path)[-1] + '/' + i + \ - '\n :alt: 查看源文件\n\n' - new_content = re.sub('([=]{5,})\n', r'\1\n' + re_view_, content, 1) - if new_content != content: - f.seek(0) - f.truncate() - f.write(new_content) - except Exception: - print(f'打开{i}文件失败') - -import mindspore_federated - -sys.path.append(os.path.abspath('../../../../resource/sphinx_ext')) -# import anchor_mod -import nbsphinx_mod - - -sys.path.append(os.path.abspath('../../../../resource/search')) -import search_code - -sys.path.append(os.path.abspath('../../../../resource/custom_directives')) -from custom_directives import IncludeCodeDirective - -def setup(app): - app.add_directive('includecode', IncludeCodeDirective) - -src_release = os.path.join(os.getenv("MF_PATH"), 'RELEASE_CN.md') -des_release = "./RELEASE.md" -with open(src_release, "r", encoding="utf-8") as f: - data = f.read() -if len(re.findall("\n## (.*?)\n",data)) > 1: - content = re.findall("(## [\s\S\n]*?)\n## ", data) -else: - content = re.findall("(## [\s\S\n]*)", data) -#result = content[0].replace('# MindSpore', '#', 1) -with open(des_release, "w", encoding="utf-8") as p: - p.write("# Release Notes"+"\n\n") - p.write(content[0]) diff --git a/docs/federated/docs/source_zh_cn/cross_device.rst b/docs/federated/docs/source_zh_cn/cross_device.rst deleted file mode 100644 index 813aa7f74d4b795049d780a8e018ed247b3f17ea..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/cross_device.rst +++ /dev/null @@ -1,17 +0,0 @@ -端侧客户端 -============ - -.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg - :target: https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/cross_device.rst - :alt: 查看源文件 - -.. 
toctree:: - :maxdepth: 1 - - java_api_callback - java_api_client - java_api_clientmanager - java_api_dataset - java_api_flparameter - java_api_syncfljob - interface_description_federated_client \ No newline at end of file diff --git a/docs/federated/docs/source_zh_cn/data_join.md b/docs/federated/docs/source_zh_cn/data_join.md deleted file mode 100644 index 8d9e3ba5e03a60676f61964ead8c949b1132a032..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/data_join.md +++ /dev/null @@ -1,242 +0,0 @@ -# 纵向联邦学习数据接入 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/data_join.md) - -和横向联邦学习不同,纵向联邦学习训练或推理时,两个参与方(leader和follower)拥有相同样本空间。因此,在纵向联邦学习的双方发起训练或推理之前,必须协同完成数据求交。双方必须读取各自的原始数据,并提取出每条数据对应的ID(每条数据的唯一标识符,且都不相同)进行求交(即求取交集)。然后,双方根据求交后的ID从原始数据中获得特征或标签等数据。最后各自导出持久化文件,并在后续训练或推理之前保序地读取数据。 - -## 总体流程 - -数据接入可以分为数据导出和数据读取两个部分。 - -### 数据导出 - -MindSpore Federated纵向联邦学习数据导出流程框架如图1所示: - -![](./images/data_join.png) - -图 1 纵向联邦学习数据接入流程框架图 - -在数据导出流程中,Leader Worker和 Follower Worker为纵向联邦学习的两个参与方。Leader Worker常驻并保持对Follower Worker的监听,Follower Worker可以在任意时刻进入数据接入流程中。 - -在Leader Worker收到 Follower Worker的注册请求后,会对注册内容进行校验。若注册成功,则给Follower Worker发送任务相关的超参(PSI 相关超参、分桶规则、ID字段名称等)。 - -然后Leader Worker 和 Follower Worker 分别读取各自的原始数据,再从各自的原始数据中提取出 ID 列表并实现分桶。 - -Leader Worker 和 Follower Worker 的每个桶都启动隐私求交方法获得两方的ID交集。 - -最后,两方根据ID交集提取原始数据中相应的数据并导出成MindRecord格式的文件。 - -### 数据读取 - -纵向联邦要求两个参与方在训练或推理的每一个批次的数据ID的值和顺序都一样的。MindSpore Federated通过在两方读取各自数据时,使用相同的随机种子和对导出的文件集合使用字典排序的方法,保证数据读取的顺序一致。 - -## 快速体验 - -### 数据准备样例 - -若要使用数据接入方法,首先需要准备好原始数据。用户可以使用[随机数据生成脚本](https://gitee.com/mindspore/federated/blob/master/tests/st/data_join/generate_random_data.py)生成出各参与方的伪造数据作为样例。 - -```shell -python generate_random_data.py \ - --seed=0 \ - --total_output_path=vfl/input/total_data.csv \ - --intersection_output_path=vfl/input/intersection_data.csv \ - --leader_output_path=vfl/input/leader_data_*.csv \ - --follower_output_path=vfl/input/follower_data_*.csv \ - --leader_file_num=4 \ - --follower_file_num=2 \ - --leader_data_num=300 \ - --follower_data_num=200 \ - --overlap_num=100 \ - --id_len=20 \ - --feature_num=30 -``` - -用户可根据实际情况进行超参设置: - -| 超参名称 | 超参描述 | -| -------------------- | ------------------------------------------------------------ | -| seed | 随机种子,int类型。 | -| total_output_path | 所有数据的输出路径,str类型。 | -| intersection_output_path | 交集数据的输出路径,str类型。 | -| leader_output_path | leader方数据的输出路径。若配置的内容包括`*`号,则会在导出多个文件时将`*`号依次替换为0、1、2……的序号。str类型。 | -| follower_output_path | follower方数据的输出路径。若配置的内容包括`*`号,则会在导出多个文件时将`*`号依次替换为0、1、2……的序号。str类型。 | -| leader_file_num | leader方数据的输出文件数目,int类型。 | -| follower_file_num | follower方数据的输出文件数目,int类型。 | -| leader_data_num | leader方数据总量,int类型。 | -| follower_data_num | follower方数据总量,int类型。 | -| overlap_num | 两方重叠的数据总量,int类型。 | -| id_len | 数据ID为字符串类型。该超参为字符串的长度,int类型。 | -| feature_num | 导出的数据的列数。 | - -运行数据准备后生成多个csv文件: - -```text -follower_data_0.csv -follower_data_1.csv -intersection_data.csv -leader_data_0.csv -leader_data_1.csv -leader_data_2.csv -leader_data_3.csv -``` - -### 数据导出样例 - -用户可以使用[数据求交脚本](https://gitee.com/mindspore/federated/blob/master/tests/st/data_join/run_data_join.py)实现两方数据求交并导出成MindRecord格式文件。用户需要分别启动Leader和Follower两个进程。 - -启动Leader: - -```shell -python run_data_join.py \ - --role="leader" \ - --main_table_files="vfl/input/leader/" \ - --output_dir="vfl/output/leader/" \ - 
--data_schema_path="vfl/leader_schema.yaml" \ - --server_name=leader_node \ - --http_server_address="127.0.0.1:1086" \ - --remote_server_name=follower_node \ - --remote_server_address="127.0.0.1:1087" \ - --primary_key="oaid" \ - --bucket_num=5 \ - --store_type="csv" \ - --shard_num=1 \ - --join_type="psi" \ - --thread_num=0 -``` - -启动Follower: - -```shell -python run_data_join.py \ - --role="follower" \ - --main_table_files="vfl/input/follower/" \ - --output_dir="vfl/output/follower/" \ - --data_schema_path="vfl/follower_schema.yaml" \ - --server_name=follower_node \ - --http_server_address="127.0.0.1:1087" \ - --remote_server_name=leader_node \ - --remote_server_address="127.0.0.1:1086" \ - --store_type="csv" \ - --thread_num=0 -``` - -用户可根据实际情况进行超参设置: - -| 超参名称 | 超参描述 | -| ------------------- | ------------------------------------------------------- | -| role | worker的角色类型,str类型,包括:"leader"、"follower"。 | -| main_table_files | 原始数据路径,可以配置单个或多个文件路径、数据目录路径,list或str类型。 | -| output_dir | 导出的MindRecord相关文件的目录路径,str类型。 | -| data_schema_path | 导出时所需要配置的超参文件存放的路径,str类型。 | -| server_name |本地用于通信的http服务名字,str类型。 | -| http_server_address | 本机IP和端口地址,str类型。 | -| remote_server_name | 对端用于通信的http服务名字,str类型。 | -| remote_server_address | 对端IP和端口地址,str类型。 | -| primary_key(Follower不需要配置) | 数据ID的名称,str类型。 | -| bucket_num(Follower不需要配置) | 求交和导出时,分桶的数目,int类型。 | -| store_type | 原始数据存储类型,str类型。包括:"csv"。 | -| shard_num(Follower不需要配置) | 单个桶导出的文件数量,int类型。 | -| join_type(Follower不需要配置) | 求交算法,str类型。包括:"psi"。 | -| thread_num | 使用PSI求交算法时,计算所需线程数,int类型。 | - -在上述样例中,data_schema_path对应的文件可以参考[leader_schema.yaml](https://gitee.com/mindspore/federated/blob/master/tests/st/data_join/vfl/leader_schema.yaml)和[follower_schema.yaml](https://gitee.com/mindspore/federated/blob/master/tests/st/data_join/vfl/follower_schema.yaml)中的相应文件配置。用户需要在该文件中提供要导出的数据的列名和类型。 - -运行数据导出后生成多个MindRecord相关文件: - -```text -mindrecord_0 -mindrecord_0.db -mindrecord_1 -mindrecord_1.db -mindrecord_2 -mindrecord_2.db -mindrecord_3 -mindrecord_3.db -mindrecord_4 -mindrecord_4.db -``` - -### 数据读取样例 - -用户可以使用[读取数据脚本](https://gitee.com/mindspore/federated/blob/master/tests/st/data_join/load_joined_data.py)实现求交后的数据读取: - -```shell -python load_joined_data.py \ - --seed=0 \ - --input_dir=vfl/output/leader/ \ - --shuffle=True -``` - -用户可根据实际情况进行超参设置: - -| 超参名称 | 超参描述 | -| --------- | ----------------------------------------- | -| seed | 随机种子,int类型。 | -| input_dir | 输入的MindRecord相关文件的目录,str类型。 | -| shuffle | 数据是否需要打乱,bool类型。 | - -如果求交结果正确,两方各自读取数据时,两方的每条数据的OAID顺序一致,而每条数据中的其他列的数据可以为不同值。运行数据读取后打印交集数据: - -```text -Leader数据导出运行结果: -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'uMbgxIMMwWhMGrVMVtM7')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'IwoGP08kWVtT4WHL2PLu')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'MSRe6mURtxgyEgWzDn0b')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'y7X0WcMKnTLrhxVcWfGF')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'DicKRIVvbOYSiv63TvcL')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'TCHgtynOhH3z11QYemsH')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'OWmhgIfC3k8UTteGUhni')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'NTV3qEYXBHqKBWyHGc7s')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'wuinSeN1bzYgXy4XmSlR')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'SSsCU0Pb46XGzUIa3Erg')} -…… - -Follower数据导出运行结果: -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'uMbgxIMMwWhMGrVMVtM7')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'IwoGP08kWVtT4WHL2PLu')} 
-{……, 'oaid': Tensor(shape=[], dtype=String, value= 'MSRe6mURtxgyEgWzDn0b')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'y7X0WcMKnTLrhxVcWfGF')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'DicKRIVvbOYSiv63TvcL')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'TCHgtynOhH3z11QYemsH')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'OWmhgIfC3k8UTteGUhni')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'NTV3qEYXBHqKBWyHGc7s')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'wuinSeN1bzYgXy4XmSlR')} -{……, 'oaid': Tensor(shape=[], dtype=String, value= 'SSsCU0Pb46XGzUIa3Erg')} -…… -``` - -## 深度体验 - -下列代码的详细的API文档可以参考[数据接入文档](https://www.mindspore.cn/federated/docs/zh-CN/master/data_join/data_join.html)。 - -### 数据导出 - -用户可以使用已经封装好的接口和配置文件实现数据求交以及导出MindRecord相关文件,方法如下: - -```python -from mindspore_federated import FLDataWorker -from mindspore_federated.common.config import get_config - - -if __name__ == '__main__': - current_dir = os.path.dirname(os.path.abspath(__file__)) - args = get_config(os.path.join(current_dir, "vfl/vfl_data_join_config.yaml")) - dict_cfg = args.__dict__ - - worker = FLDataWorker(config=dict_cfg) - worker.do_worker() -``` - -### 数据读取 - -用户可以使用已经封装好的接口实现导出的MindRecord相关文件的数据读取,方法如下: - -```python -from mindspore_federated.data_join import load_mindrecord - - -if __name__ == "__main__": - dataset = load_mindrecord(input_dir="vfl/output/leader/", shuffle=True, seed=0) -``` - diff --git a/docs/federated/docs/source_zh_cn/deploy_federated_client.md b/docs/federated/docs/source_zh_cn/deploy_federated_client.md deleted file mode 100644 index a7b4e8ffb6485ce5a032f6e82dced2b111b9a379..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/deploy_federated_client.md +++ /dev/null @@ -1,200 +0,0 @@ -# 横向联邦端侧部署 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/deploy_federated_client.md) - -本文档介绍如何编译,部署Federated-Client。 - -## Linux 编译指导 - -### 系统环境和第三方依赖 - -本章节介绍如何完成MindSpore联邦学习的端侧编译,当前联邦学习端侧仅提供Linux上的编译指导,其他系统暂不支持。下表列出了编译所需的系统环境和第三方依赖。 - -| 软件名称 | 版本 | 作用 | -|-----------------------| ------------ | ------------ | -| Ubuntu | 18.04.02LTS | 编译和运行MindSpore的操作系统 | -| [GCC](#安装gcc) | 7.3.0到9.4.0之间 | 用于编译MindSpore的C++编译器 | -| [git](#安装git) | - | MindSpore使用的源代码管理工具 | -| [CMake](#安装cmake) | 3.18.3及以上 | 编译构建MindSpore的工具 | -| [Gradle](#安装gradle) | 6.6.1 | 基于JVM的构建工具 | -| [Maven](#安装maven) | 3.3.1及以上 | Java项目的管理和构建工具 | -| [OpenJDK](#安装openjdk) | 1.8 到 1.15之间 | Java项目的管理和构建工具 | - -#### 安装GCC - -可以通过以下命令安装GCC。 - -```bash -sudo apt-get install gcc-7 git -y -``` - -如果要安装更高版本的GCC,使用以下命令安装GCC 8。 - -```bash -sudo apt-get install gcc-8 -y -``` - -或者安装GCC 9。 - -```bash -sudo apt-get install software-properties-common -y -sudo add-apt-repository ppa:ubuntu-toolchain-r/test -sudo apt-get update -sudo apt-get install gcc-9 -y -``` - -#### 安装git - -可以通过以下命令安装git。 - -```bash -sudo apt-get install git -y -``` - -#### 安装CMake - -可以通过以下命令安装[CMake](https://cmake.org/)。 - -```bash -wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | sudo apt-key add - -sudo apt-add-repository "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" -sudo apt-get install cmake -y -``` - -#### 安装Gradle - -可以通过以下命令安装[Gradle](https://gradle.org/releases/)。 - -```bash -# 下载对应的压缩包,解压。 -# 配置环境变量: - export GRADLE_HOME=GRADLE路径 - export GRADLE_USER_HOME=GRADLE路径 -# 将bin目录添加到PATH中: - export 
PATH=${GRADLE_HOME}/bin:$PATH
-```
-
-#### 安装Maven
-
-可以通过以下命令安装[Maven](https://archive.apache.org/dist/maven/maven-3/)。
-
-```bash
-# 下载对应的压缩包,解压。
-# 配置环境变量:
-    export MAVEN_HOME=MAVEN路径
-# 将bin目录添加到PATH中:
-    export PATH=${MAVEN_HOME}/bin:$PATH
-```
-
-#### 安装OpenJDK
-
-可以通过以下命令安装[OpenJDK](https://jdk.java.net/archive/)。
-
-```bash
-# 下载对应的压缩包,解压。
-# 配置环境变量:
-    export JAVA_HOME=JDK路径
-# 将bin目录添加到PATH中:
-    export PATH=${JAVA_HOME}/bin:$PATH
-```
-
-### 验证是否成功安装
-
-确认[系统环境和第三方依赖](#系统环境和第三方依赖)中所列软件是否安装成功。
-
-```text
-打开命令窗口输入:gcc --version
-输出以下结果表示安装成功:
-    gcc version 版本号
-
-打开命令窗口输入:git --version
-输出以下结果表示安装成功:
-    git version 版本号
-
-打开命令窗口输入:cmake --version
-输出以下结果表示安装成功:
-    cmake version 版本号
-
-打开命令窗口输入:gradle --version
-输出以下结果表示安装成功:
-    Gradle 版本号
-
-打开命令窗口输入:mvn --version
-输出以下结果表示安装成功:
-    Apache Maven 版本号
-
-打开命令窗口输入:java --version
-输出以下结果表示安装成功:
-    openjdk version 版本号
-
-```
-
-### 编译选项
-
-联邦学习device_client目录下的`cli_build.sh`脚本用于联邦学习端侧的编译。
-
-#### cli_build.sh的参数使用说明
-
-| 参数 | 参数说明                 | 取值范围 | 默认值 |
-| ---- | ------------------------ | -------- | ------ |
-| -p   | 依赖外部包的下载存放路径 | 字符串   | third  |
-| -c   | 是否复用之前下载的依赖包 | on、off  | on     |
-
-### 编译示例
-
-1. 首先,在进行编译之前,需从gitee代码仓下载源码。
-
-    ```bash
-    git clone https://gitee.com/mindspore/federated.git ./
-    ```
-
-2. 然后进入目录mindspore_federated/device_client,执行如下命令:
-
-    ```bash
-    bash cli_build.sh
-    ```
-
-3. 由于端侧框架和模型是解耦的,我们提供的x86架构包mindspore-lite-{version}-linux-x64.tar.gz不包含模型相关脚本,因此需要用户自行生成模型脚本对应的jar包。我们提供的模型脚本对应的jar包可采用如下方式获取:
-
-    ```bash
-    cd federated/example/quick_start_flclient
-    bash build.sh -r mindspore-lite-java-flclient.jar #-r 后需要给出最新x86架构包绝对路径(步骤2生成,federated/mindspore_federated/device_client/build/libs/jarX86/mindspore-lite-java-flclient.jar)
-    ```
-
-运行以上指令后生成jar包路径为:federated/example/quick_start_flclient/target/quick_start_flclient.jar。
-
-### 构建依赖环境
-
-1. 将文件`federated/mindspore_federated/device_client/third/mindspore-lite-{version}-linux-x64.tar.gz`解压后,所得到的目录结构如下所示(联邦学习不使用的文件不展示):
-
-    ```sh
-    mindspore-lite-{version}-linux-x64
-    ├── tools
-    └── runtime
-        ├── include # 训练框架头文件
-        ├── lib # 训练框架库
-        │   ├── libminddata-lite.a # 图像处理静态库文件
-        │   ├── libminddata-lite.so # 图像处理动态库文件
-        │   ├── libmindspore-lite-jni.so # MindSpore Lite推理框架依赖的jni动态库
-        │   ├── libmindspore-lite-train.a # MindSpore Lite训练框架依赖的静态库
-        │   ├── libmindspore-lite-train.so # MindSpore Lite训练框架依赖的动态库
-        │   ├── libmindspore-lite-train-jni.so # MindSpore Lite训练框架依赖的jni动态库
-        │   ├── libmindspore-lite.a # MindSpore Lite推理框架依赖的静态库
-        │   ├── libmindspore-lite.so # MindSpore Lite推理框架依赖的动态库
-        │   └── mindspore-lite-java.jar # MindSpore Lite训练框架jar包
-        └── third_party
-            ├── glog
-            │   └── libmindspore_glog.so.0 # glog日志动态库文件
-            └── libjpeg-turbo
-                └── lib
-                    ├── libjpeg.so.62 # 图像处理动态库文件
-                    └── libturbojpeg.so.0 # 图像处理动态库文件
-    ```
-
-2. 可将路径`mindspore-lite-{version}-linux-x64/runtime/lib/`、`mindspore-lite-{version}-linux-x64/runtime/third_party/glog/`以及`mindspore-lite-{version}-linux-x64/runtime/third_party/libjpeg-turbo/lib/`中联邦学习所依赖的so文件放入一个文件夹,比如`/resource/x86libs/`。然后在x86中设置环境变量(下面需提供绝对路径):
-
-    ```sh
-    export LD_LIBRARY_PATH=/resource/x86libs/:$LD_LIBRARY_PATH
-    ```
-
-3. 
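在继续之前,可先用如下Python示例脚本粗略检查依赖库是否就位(脚本仅为示意,其中的目录与库文件名取自上文示例,请按实际情况修改):
-
-    ```python
-    # 粗略检查联邦学习端侧依赖的动态库是否已放入指定目录(目录为上文示例值)
-    import os
-
-    lib_dir = "/resource/x86libs/"
-    required = ["libmindspore-lite.so", "libmindspore-lite-train.so", "libmindspore_glog.so.0"]
-    missing = [name for name in required if not os.path.exists(os.path.join(lib_dir, name))]
-    print("missing libs:", missing if missing else "none")
-    # LD_LIBRARY_PATH中应包含该目录,否则运行时会找不到依赖库
-    print("in LD_LIBRARY_PATH:", "/resource/x86libs" in os.environ.get("LD_LIBRARY_PATH", ""))
-    ```
-
-    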
设置好依赖环境之后,可参考应用实践教程[实现一个端云联邦的图像分类应用(x86)](https://www.mindspore.cn/federated/docs/zh-CN/master/image_classification_application.html)在x86环境中模拟启动多个客户端进行联邦学习。 \ No newline at end of file diff --git a/docs/federated/docs/source_zh_cn/deploy_federated_server.md b/docs/federated/docs/source_zh_cn/deploy_federated_server.md deleted file mode 100644 index 317e44ece767deabdc76762293fcc6e6d5a30f69..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/deploy_federated_server.md +++ /dev/null @@ -1,317 +0,0 @@ -# 横向联邦云侧部署 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/deploy_federated_server.md) - -本文档以LeNet网络为例,讲解如何使用MindSpore Federated部署横向联邦学习集群。 - -MindSpore Federated Learning (FL) Server集群物理架构如图所示: - -![](./images/mindspore_federated_networking.png) - -如上图所示,在横向联邦学习云侧集群中,有三种角色的MindSpore进程:`Federated Learning Scheduler`、`Federated Learning Server`和`Federated Learning Worker`: - -- Federated Learning Scheduler - - `Scheduler`的功能主要包括: - - 1. 协助集群组网:在集群初始化阶段,由`Scheduler`负责收集`Server`信息,并保障集群一致性。 - 2. 开放管理面:向用户提供`RESTful`接口,实现对集群的管理。 - - 在一个联邦学习任务中,只有一个`Scheduler`,其与`Server`通过TCP协议通信。 - -- Federated Learning Server - - `Server`为执行联邦学习任务的主体,用于接收和解析端侧设备上传的数据,具有执行安全聚合、限时通信、模型存储等能力。在一个联邦学习任务中,`Server`可以有多个(用户可配置),`Server`间通过TCP协议通信,对外开放HTTP端口与端侧设备连接。 - - 在MindSpore联邦学习框架中,`Server`还支持弹性伸缩以及容灾,能够在训练任务不中断的情况下,动态调配硬件资源。 - -- Federated Learning Worker - - `Worker`为执行联邦学习任务的附件模块,用于对Server中的模型进行二次有监督训练,而后将训练所得模型下发给Server,在一个联邦学习任务中,`Worker`可以有多个(用户可配置),`Worker`和`Server`间通过TCP协议通信。 - -`Scheduler`和`Server`需部署在单网卡的服务器或者容器中,且处于相同网段。MindSpore自动获取首个可用IP地址作为`Server`地址。 - -> 服务器会校验客户端携带的时间戳,需要确保服务器定期同步时钟,避免服务器出现较大的时钟偏移。 - -## 准备环节 - -> 建议使用[Anaconda](https://www.anaconda.com/)创建虚拟环境进行如下操作。 - -### 安装MindSpore - -MindSpore横向联邦学习云侧集群支持在x86 CPU和GPU CUDA硬件平台上部署。可参考[MindSpore安装指南](https://www.mindspore.cn/install)安装MindSpore最新版本。 - -### 安装MindSpore Federated - -通过[源码](https://gitee.com/mindspore/federated)编译安装。 - -```shell -git clone https://gitee.com/mindspore/federated.git -b master -cd federated -bash build.sh -``` - -对于`bash build.sh`,可通过例如`-jn`选项,例如`-j16`,加速编译;可通过`-S on`选项,从gitee而不是github下载第三方依赖。 - -编译完成后,在`build/package/`目录下找到Federated的whl安装包进行安装: - -```bash -pip install mindspore_federated-{version}-{python_version}-linux_{arch}.whl -``` - -### 验证是否成功安装 - -执行以下命令,验证安装结果。导入Python模块不报错即安装成功: - -```python -from mindspore_federated import FLServerJob -``` - -### 安装和启动Redis服务器 - -联邦学习默认依赖[Redis服务器](https://gitee.com/link?target=https%3A%2F%2Fredis.io%2F)作为缓存数据中间件,运行联邦学习业务,需要安装和运行Redis服务器。 - -> 用户需自行检查Redis版本的安全性,某些Redis版本可能存在漏洞。 - -安装Redis服务器: - -```bash -sudo apt-get install redis -``` - -运行Redis服务器,配置端口号为:23456: - -```bash -redis-server --port 23456 --save "" -``` - -## 启动集群 - -1. [样例路径](https://gitee.com/mindspore/federated/tree/master/example/cross_device_lenet_femnist/)。 - - ```bash - cd example/cross_device_lenet_femnist - ``` - -2. 
据实际运行需要修改yaml配置文件:`default_yaml_config.yaml`,如下为[Lenet的相关配置样例](https://gitee.com/mindspore/federated/blob/master/example/cross_device_lenet_femnist/yamls/lenet/default_yaml_config.yaml)。 - - ```yaml - fl_name: Lenet - fl_iteration_num: 25 - server_mode: FEDERATED_LEARNING - enable_ssl: False - - distributed_cache: - type: redis - address: 127.0.0.1:23456 # ip:port of redis actual machine - plugin_lib_path: "" - - round: - start_fl_job_threshold: 2 - start_fl_job_time_window: 30000 - update_model_ratio: 1.0 - update_model_time_window: 30000 - global_iteration_time_window: 60000 - - summary: - metrics_file: "metrics.json" - failure_event_file: "event.txt" - continuous_failure_times: 10 - data_rate_dir: ".." - participation_time_level: "5,15" - - unsupervised: - cluster_client_num: 1000 - eval_type: SILHOUETTE_SCORE - - encrypt: - encrypt_train_type: NOT_ENCRYPT - pw_encrypt: - share_secrets_ratio: 1.0 - cipher_time_window: 3000 - reconstruct_secrets_threshold: 1 - dp_encrypt: - dp_eps: 50.0 - dp_delta: 0.01 - dp_norm_clip: 1.0 - signds: - sign_k: 0.01 - sign_eps: 100 - sign_thr_ratio: 0.6 - sign_global_lr: 0.1 - sign_dim_out: 0 - - compression: - upload_compress_type: NO_COMPRESS - upload_sparse_rate: 0.4 - download_compress_type: NO_COMPRESS - - ssl: - # when ssl_config is set - # for tcp/http server - server_cert_path: "server.p12" - # for tcp client - client_cert_path: "client.p12" - # common - ca_cert_path: "ca.crt" - crl_path: "" - cipher_list: "ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-CHACHA20-POLY1305:ECDHE-PSK-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-CCM:ECDHE-ECDSA-AES256-CCM:ECDHE-ECDSA-CHACHA20-POLY1305" - cert_expire_warning_time_in_day: 90 - - client_verify: - pki_verify: false - root_first_ca_path: "" - root_second_ca_path: "" - equip_crl_path: "" - replay_attack_time_diff: 600000 - - client: - http_url_prefix: "" - client_epoch_num: 20 - client_batch_size: 32 - client_learning_rate: 0.01 - connection_num: 10000 - - ``` - -3. 准备模型文件,启动方式为:基于权重启动,需要提供相应的模型权重。 - - 获取lenet模型权重: - - ```bash - wget https://ms-release.obs.cn-north-4.myhuaweicloud.com/ms-dependencies/Lenet.ckpt - ``` - -4. 运行Scheduler,管理面地址默认为`127.0.0.1:11202`。 - - ```python - python run_sched.py \ - --yaml_config="yamls/lenet.yaml" \ - --scheduler_manage_address="10.*.*.*:18019" - ``` - -5. 运行Server,默认启动1个Server,HTTP服务器地址默认为`127.0.0.1:6666`。 - - ```python - python run_server.py \ - --yaml_config="yamls/lenet.yaml" \ - --tcp_server_ip="10.*.*.*" \ - --checkpoint_dir="fl_ckpt" \ - --local_server_num=1 \ - --http_server_address="10.*.*.*:8019" - ``` - -6. 
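在停止集群之前,可先查询集群状态,确认各节点在网。下面是一个通过`Scheduler`管理面`RESTful`接口查询状态的Python示例(管理面地址为示例值,请按实际部署修改):
-
-    ```python
-    # 通过Scheduler管理面的RESTful接口查询集群状态(地址为示例值)
-    import json
-    from urllib.request import urlopen
-
-    with urlopen("http://127.0.0.1:11202/state", timeout=5) as resp:
-        state = json.load(resp)
-    print(state.get("cluster_state"), [node.get("role") for node in state.get("nodes", [])])
-    ```
-
-    确认无误后,即可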
停止联邦学习。当前版本联邦学习集群为常驻进程,可执行`finish_cloud.py`脚本,以终止联邦学习服务。执行指令的示例如下,其中`redis_port`传参,需与启动redis时的传参保持一致,代表停止此`Scheduler`对应的集群。 - - ```python - python finish_cloud.py --redis_port=23456 - ``` - - 若console打印如下内容: - - ```text - killed $PID1 - killed $PID2 - killed $PID3 - killed $PID4 - killed $PID5 - killed $PID6 - killed $PID7 - killed $PID8 - ``` - - 则表明停止服务成功。 - -## 弹性伸缩 - -MindSpore联邦学习框架支持`Server`的弹性伸缩,对外通过`Scheduler`管理端口提供`RESTful`服务,使得用户在不中断训练任务的情况下,对硬件资源进行动态调度。 - -以下示例介绍了如何通过对应接口,对控制集群扩容/缩容。 - -### 扩容 - -在集群启动后,进入部署scheduler节点的机器,向`Scheduler`发起请求,查询状态、节点信息。可使用`curl`指令构造`RESTful`请求。 - -```sh -curl -k 'http://10.*.*.*:18015/state' -``` - -`Scheduler`将返回`json`格式的查询结果。 - -```json -{ - "message":"Get cluster state successful.", - "cluster_state":"CLUSTER_READY", - "code":0, - "nodes":[ - {"node_id","{ip}:{port}::{timestamp}::{random}", - "tcp_address":"{ip}:{port}", - "role":"SERVER"} - ] -} -``` - -需要拉起3个新的`Server`进程,并将`local_server_num`参数累加扩容的个数,从而保证全局组网信息的正确性,即扩容后,`local_server_num`的数量应为4,执行指令的示例如下: - -```sh -python run_server.py --yaml_config="yamls/lenet.yaml" --tcp_server_ip="10.*.*.*" --checkpoint_dir="fl_ckpt" --local_server_num=4 --http_server_address="10.*.*.*:18015" -``` - -该指令代表启动四个`Server`节点,总`Server`数量为4。 - -### 缩容 - -直接使用kill -9 pid的方式模拟缩容,使用`curl`指令构造`RESTful`请求,查询状态,发现集群中少了一个node_id,达到缩容目的。 - -```sh -curl -k \ -'http://10.*.*.*:18015/state' -``` - -`Scheduler`将返回`json`格式的查询结果。 - -```json -{ - "message":"Get cluster state successful.", - "cluster_state":"CLUSTER_READY", - "code":0, - "nodes":[ - {"node_id","{ip}:{port}::{timestamp}::{random}", - "tcp_address":"{ip}:{port}", - "role":"SERVER"}, - {"node_id","worker_fl_{timestamp}::{random}", - "tcp_address":"", - "role":"WORKER"}, - {"node_id","worker_fl_{timestamp}::{random}", - "tcp_address":"", - "role":"WORKER"} - ] -} -``` - -> - 在集群扩容/缩容成功后,训练任务会自动恢复,不需要用户进行额外干预。 - -## 安全 - -MindSpore联邦学习框架支持`Server`的SSL安全认证,要开启安全认证,需要在启动命令加上`enable_ssl=True`,config_file_path指定的config.json配置文件需要添加如下字段: - -```json -{ - "server_cert_path": "server.p12", - "crl_path": "", - "client_cert_path": "client.p12", - "ca_cert_path": "ca.crt", - "cert_expire_warning_time_in_day": 90, - "cipher_list": "ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:DHE-DSS-AES128-GCM-SHA256:kEDH+AESGCM:ECDHE-RSA-AES128-SHA256:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA:ECDHE-ECDSA-AES256-SHA:DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-DSS-AES128-SHA256:DHE-RSA-AES256-SHA256:DHE-DSS-AES256-SHA:DHE-RSA-AES256-SHA:!aNULL:!eNULL:!EXPORT:!DES:!RC4:!3DES:!MD5:!PSK", - "connection_num":10000 -} -``` - -- server_cert_path:服务端包含证书和密钥的密文的p12文件路径。 -- crl_path:吊销列表的文件。 -- client_cert_path:客户端包含证书和密钥的密文的p12文件路径。 -- ca_cert_path:根证书。 -- cipher_list:密码套件。 -- cert_expire_warning_time_in_day:证书过期的告警时间。 - -p12文件中的密钥为密文存储。 diff --git a/docs/federated/docs/source_zh_cn/deploy_vfl.md b/docs/federated/docs/source_zh_cn/deploy_vfl.md deleted file mode 100644 index 640c22f706b26e17cd6ad2e57d724fd423d79e7a..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/deploy_vfl.md +++ /dev/null @@ -1,69 +0,0 @@ -# 纵向联邦部署 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/deploy_vfl.md) - -本文档讲解如何使用和部署纵向联邦学习框架。 - -MindSpore 
Vertical Federated Learning (VFL) 物理架构如图所示:
-
-![](./images/deploy_VFL.png)
-
-如上图所示,在纵向联邦的交互中有两个参与方:Leader node和Follower node,每一个参与方都有两种角色的进程:`FLDataWorker`和`VFLTrainer`:
-
-- FLDataWorker
-
-    `FLDataWorker`的功能主要包括:
-
-    1. 数据集合求交:获得纵向联邦参与双方的共有用户交集,支持隐私集合求交协议,可防止联邦学习参与方获得交集外的ID信息。
-    2. 训练数据生成:在获得交集ID之后,扩充数据特征,生成用于训练的mindrecord文件。
-    3. 开放管理面:向用户提供`RESTful`接口,实现对集群的管理。
-
-- VFLTrainer
-
-    `VFLTrainer`为执行纵向联邦训练任务的主体,执行模型拆分后的正反向计算、Embedding张量传输、梯度张量传输、反向优化器更新等任务。当前版本支持单机单卡和单机多卡的训练模式。
-
-`FLDataWorker`和`VFLTrainer`一般部署在同一台服务器或者容器中。
-
-## 准备环节
-
-> 建议使用[Anaconda](https://www.anaconda.com/)创建虚拟环境进行如下操作。
-
-### 安装MindSpore
-
-MindSpore纵向联邦支持在x86 CPU、GPU CUDA和Ascend硬件平台上部署。可参考[MindSpore安装指南](https://www.mindspore.cn/install)安装MindSpore最新版本。
-
-### 安装MindSpore Federated
-
-通过[源码](https://gitee.com/mindspore/federated)编译安装。
-
-```shell
-git clone https://gitee.com/mindspore/federated.git -b master
-cd federated
-bash build.sh
-```
-
-对于`bash build.sh`,可通过`-jn`选项(例如`-j16`)加速编译;可通过`-S on`选项,从gitee而不是github下载第三方依赖。
-
-编译完成后,在`build/package/`目录下找到Federated的whl安装包进行安装:
-
-```shell
-pip install mindspore_federated-{version}-{python_version}-linux_{arch}.whl
-```
-
-#### 验证是否成功安装
-
-执行以下命令,验证安装结果。导入Python模块不报错即安装成功:
-
-```python
-from mindspore_federated import FLServerJob
-```
-
-## 运行样例
-
-FLDataWorker的运行样例可参考[纵向联邦学习数据接入](https://www.mindspore.cn/federated/docs/zh-CN/master/data_join.html)。
-
-VFLTrainer的运行样例可参考[纵向联邦学习模型训练 - Wide&Deep推荐应用](https://www.mindspore.cn/federated/docs/zh-CN/master/split_wnd_application.html)。
diff --git a/docs/federated/docs/source_zh_cn/faq.md b/docs/federated/docs/source_zh_cn/faq.md
deleted file mode 100644
index 7e9f84e7777a46b01f7a929bf09504f1bc87f93a..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_zh_cn/faq.md
+++ /dev/null
@@ -1,9 +0,0 @@
-# FAQ
-
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/faq.md)
-
-**Q: 请问如果集群组网不成功,怎么定位原因?**
-
-A: 请查看服务器的网络情况,例如检查防火墙是否禁止了相关端口的访问;若被禁止,请设置防火墙允许相应端口访问。
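
-
-也可参考下面的Python示例脚本,检查`Scheduler`与`Server`端口的TCP连通性(其中IP与端口均为示例值,请替换为实际部署地址):
-
-```python
-# 检查指定IP和端口的TCP连通性,辅助定位集群组网失败的原因
-import socket
-
-
-def check_port(ip, port, timeout=3):
-    """尝试建立TCP连接,返回端口是否可达。"""
-    try:
-        with socket.create_connection((ip, port), timeout=timeout):
-            return True
-    except OSError:
-        return False
-
-
-# 假设Scheduler管理端口为5554、Server的HTTP端口为5555(示例值)
-for ip, port in [("127.0.0.1", 5554), ("127.0.0.1", 5555)]:
-    print(ip, port, "reachable" if check_port(ip, port) else "unreachable")
-```
-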
diff --git a/docs/federated/docs/source_zh_cn/federated_install.md b/docs/federated/docs/source_zh_cn/federated_install.md
deleted file mode 100644
index 7de6d374ee4b68d05636407250713eab53cc0472..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_zh_cn/federated_install.md
+++ /dev/null
@@ -1,25 +0,0 @@
-# 获取MindSpore Federated
-
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/federated_install.md)
-
-[MindSpore Federated](https://gitee.com/mindspore/federated)框架代码现已独立建仓,分为端侧和云侧。其云侧能力依赖MindSpore和MindSpore Federated,利用MindSpore进行云侧集群聚合训练以及与端侧进行通信,因此需要分别获取MindSpore whl包和MindSpore Federated whl包;端侧能力依赖MindSpore Lite和MindSpore Federated java包,其中MindSpore Federated java主要负责数据预处理、调用MindSpore Lite进行模型训练和推理,以及利用隐私保护机制和云侧进行模型相关的上传和下载。
-
-## 获取MindSpore whl包
-
-包括源码编译和下载发布版两种方式,支持x86 CPU、GPU CUDA等硬件平台,根据硬件平台类型,选择进行安装即可。安装步骤可参考[MindSpore安装指南](https://www.mindspore.cn/install)。
-
-## 获取MindSpore Lite java包
-
-包括源码编译和下载发布版两种方式。目前,MindSpore Lite联邦学习功能只支持Linux和Android平台,且只支持CPU。安装步骤可参考[下载MindSpore Lite](https://www.mindspore.cn/lite/docs/zh-CN/master/use/downloads.html)和[编译MindSpore Lite](https://www.mindspore.cn/lite/docs/zh-CN/master/build/build.html)。
-
-## 获取MindSpore Federated whl包
-
-包括源码编译和下载发布版两种方式,支持x86 CPU、GPU CUDA等硬件平台,根据硬件平台类型,选择进行安装即可。安装步骤可参考[编译MindSpore Federated whl](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_server.html)。
-
-## 获取MindSpore Federated java包
-
-包括源码编译和下载发布版两种方式。目前,MindSpore Federated联邦学习功能只支持Linux和Android平台。安装步骤可参考[编译MindSpore Federated java](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_client.html)。
-
-## Linux编译环境要求
-
-目前源码编译只支持Linux,编译环境要求可参考[MindSpore源码编译](https://www.mindspore.cn/install)和[MindSpore Lite源码编译](https://www.mindspore.cn/lite/docs/zh-CN/master/build/build.html)。
diff --git a/docs/federated/docs/source_zh_cn/horizontal_server.rst b/docs/federated/docs/source_zh_cn/horizontal_server.rst
deleted file mode 100644
index d5a4319b56f8fa8a8d28a34bb504b31ea40dcb9b..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_zh_cn/horizontal_server.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-联邦服务器
-==============
-
-.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg
-   :target: https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/horizontal_server.rst
-   :alt: 查看源文件
-
-.. 
toctree:: - :maxdepth: 1 - - horizontal/federated_server - horizontal/federated_server_yaml diff --git a/docs/federated/docs/source_zh_cn/image_classfication_dataset_process.md b/docs/federated/docs/source_zh_cn/image_classfication_dataset_process.md deleted file mode 100644 index 6f991cc4e4eb3cd315e67cef25426b2396fe326d..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/image_classfication_dataset_process.md +++ /dev/null @@ -1,451 +0,0 @@ -# 联邦学习图像分类数据集处理 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/image_classfication_dataset_process.md) - -本教程采用`leaf`数据集中的联邦学习数据集`FEMNIST`,该数据集包含62个不同类别的手写数字和字母(数字0~9、26个小写字母、26个大写字母),图像大小为`28 x 28`像素,数据集包含3500个用户的手写数字和字母(最多可模拟3500个客户端参与联邦学习),总数据量为805263,平均每个用户包含数据量为226.83,所有用户数据量的方差为88.94。 - -参考[leaf数据集官方指导](https://github.com/TalwalkarLab/leaf)下载数据集。 - -1. 下载数据集前的环境要求。 - - ```sh - numpy==1.16.4 - scipy # conda install scipy - tensorflow==1.13.1 # pip install tensorflow - Pillow # pip install Pillow - matplotlib # pip install matplotlib - jupyter # conda install jupyter notebook==5.7.8 tornado==4.5.3 - pandas # pip install pandas - ``` - -2. 使用git下载官方数据集生成脚本。 - - ```sh - git clone https://github.com/TalwalkarLab/leaf.git - ``` - - 下载项目后,目录结构如下: - - ```sh - leaf/data/femnist - ├── data # 用来存放指令生成的数据集 - ├── preprocess # 存放数据预处理的相关代码 - ├── preprocess.sh # femnist数据集生成shell脚本 - └── README.md # 官方数据集下载指导文档 - ``` - -3. 以`femnist`数据集为例,运行以下指令进入指定路径。 - - ```sh - cd leaf/data/femnist - ``` - -4. 用指令`./preprocess.sh -s niid --sf 1.0 -k 0 -t sample`生成的数据集包含3500个用户,且按照9:1对每个用户的数据划分训练和测试集。 - - 指令中参数含义可参考`leaf/data/femnist/README.md`文件中的说明。 - - 运行之后目录结构如下: - - ```text - leaf/data/femnist/35_client_sf1_data/ - ├── all_data # 所有数据集混合在一起,不区分训练测试集,共包含35个json文件,每个json文件包含100个用户的数据 - ├── test # 按照9:1对每个用户的数据划分训练和测试集后的测试集,共包含35个json文件,每个json文件包含100个用户的数据 - ├── train # 按照9:1对每个用户的数据划分训练和测试集后的训练集,共包含35个json文件,每个json文件包含100个用户的数据 - └── ... # 其他文件,暂不需要用到,不作介绍 - ``` - - 其中每个json文件包含以下三个部分: - - - `users`: 用户列表。 - - `num_samples`: 每个用户的样本数量列表。 - - `user_data`: 一个以用户名为key,以它们各自的数据为value的字典对象;对于每个用户,数据表示为图像列表,每张图像表示为大小为784的整数列表(将`28 x 28`图像数组展平所得)。 - - 在重新运行`preprocess.sh`之前,请确保删除数据目录中的`rem_user_data`、`sampled_data`、`test`和`train`子文件夹。 - -5. 将35个json文件划分为3500个json文件(每个json文件代表一个用户)。 - - 参考代码如下: - - ```python - import os - import json - - def mkdir(path): - if not os.path.exists(path): - os.mkdir(path) - - def partition_json(root_path, new_root_path): - """ - partition 35 json files to 3500 json file - - Each raw .json file is an object with 3 keys: - 1. 'users', a list of users - 2. 'num_samples', a list of the number of samples for each user - 3. 'user_data', an object with user names as keys and their respective data as values; for each user, data is represented as a list of images, with each image represented as a size-784 integer list (flattened from 28 by 28) - - Each new .json file is an object with 3 keys: - 1. 'user_name', the name of user - 2. 'num_samples', the number of samples for the user - 3. 
'user_data', an dict object with 'x' as keys and their respective data as values; with 'y' as keys and their respective label as values; - - Args: - root_path (str): raw root path of 35 json files - new_root_path (str): new root path of 3500 json files - """ - paths = os.listdir(root_path) - count = 0 - file_num = 0 - for i in paths: - file_num += 1 - file_path = os.path.join(root_path, i) - print('======== process ' + str(file_num) + ' file: ' + str(file_path) + '======================') - with open(file_path, 'r') as load_f: - load_dict = json.load(load_f) - users = load_dict['users'] - num_users = len(users) - num_samples = load_dict['num_samples'] - for j in range(num_users): - count += 1 - print('---processing user: ' + str(count) + '---') - cur_out = {'user_name': None, 'num_samples': None, 'user_data': {}} - cur_user_id = users[j] - cur_data_num = num_samples[j] - cur_user_path = os.path.join(new_root_path, cur_user_id + '.json') - cur_out['user_name'] = cur_user_id - cur_out['num_samples'] = cur_data_num - cur_out['user_data'].update(load_dict['user_data'][cur_user_id]) - with open(cur_user_path, 'w') as f: - json.dump(cur_out, f) - f = os.listdir(new_root_path) - print(len(f), ' users have been processed!') - # partition train json files - partition_json("leaf/data/femnist/35_client_sf1_data/train", "leaf/data/femnist/3500_client_json/train") - # partition test json files - partition_json("leaf/data/femnist/35_client_sf1_data/test", "leaf/data/femnist/3500_client_json/test") - ``` - - 其中`root_path`为`leaf/data/femnist/35_client_sf1_data/{train,test}`,`new_root_path`自行设置,用于存放生成的3500个用户json文件,需分别对训练和测试文件夹进行处理。 - - 新生成的3500个用户json文件,每个文件均包含以下三个部分: - - - `user_name`: 用户名。 - - `num_samples`: 用户的样本数。 - - `user_data`: 一个以'x'为key,以用户数据为value的字典对象;以'y'为key,以用户数据对应的标签为value。 - - 运行该脚本打印如下,代表运行成功: - - ```sh - ======== process 1 file: /leaf/data/femnist/35_client_sf1_data/train/all_data_16_niid_0_keep_0_train_9.json====================== - ---processing user: 1--- - ---processing user: 2--- - ---processing user: 3--- - ...... - ``` - -6. 
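在进行本步骤之前,可先用如下示例代码核对上一步生成的用户json文件数量是否为3500(train与test各3500个;路径沿用上文示例,请按实际情况修改):
-
-    ```python
-    # 核对上一步生成的用户json文件数量(示例路径)
-    import os
-
-    for tag in ("train", "test"):
-        d = os.path.join("leaf/data/femnist/3500_client_json", tag)
-        num = len([f for f in os.listdir(d) if f.endswith(".json")])
-        print(tag, num)
-    ```
-
-    核对无误后,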
将json文件转换为图片文件。 - - 可参考如下代码: - - ```python - import os - import json - import numpy as np - from PIL import Image - - name_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', - 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', - 'V', 'W', 'X', 'Y', 'Z', - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', - 'v', 'w', 'x', 'y', 'z' - ] - - def mkdir(path): - if not os.path.exists(path): - os.mkdir(path) - - def json_2_numpy(img_size, file_path): - """ - read json file to numpy - Args: - img_size (list): contain three elements: the height, width, channel of image - file_path (str): root path of 3500 json files - return: - image_numpy (numpy) - label_numpy (numpy) - """ - # open json file - with open(file_path, 'r') as load_f_train: - load_dict = json.load(load_f_train) - num_samples = load_dict['num_samples'] - x = load_dict['user_data']['x'] - y = load_dict['user_data']['y'] - size = (num_samples, img_size[0], img_size[1], img_size[2]) - image_numpy = np.array(x, dtype=np.float32).reshape(size) # mindspore doesn't support float64 and int64 - label_numpy = np.array(y, dtype=np.int32) - return image_numpy, label_numpy - - def json_2_img(json_path, save_path): - """ - transform single json file to images - - Args: - json_path (str): the path json file - save_path (str): the root path to save images - - """ - data, label = json_2_numpy([28, 28, 1], json_path) - for i in range(data.shape[0]): - img = data[i] * 255 # PIL don't support the 0/1 image ,need convert to 0~255 image - im = Image.fromarray(np.squeeze(img)) - im = im.convert('L') - img_name = str(label[i]) + '_' + name_list[label[i]] + '_' + str(i) + '.png' - path1 = os.path.join(save_path, str(label[i])) - mkdir(path1) - img_path = os.path.join(path1, img_name) - im.save(img_path) - print('-----', i, '-----') - - def all_json_2_img(root_path, save_root_path): - """ - transform json files to images - Args: - json_path (str): the root path of 3500 json files - save_path (str): the root path to save images - """ - usage = ['train', 'test'] - for i in range(2): - x = usage[i] - files_path = os.path.join(root_path, x) - files = os.listdir(files_path) - - for name in files: - user_name = name.split('.')[0] - json_path = os.path.join(files_path, name) - save_path1 = os.path.join(save_root_path, user_name) - mkdir(save_path1) - save_path = os.path.join(save_path1, x) - mkdir(save_path) - print('=============================' + name + '=======================') - json_2_img(json_path, save_path) - - all_json_2_img("leaf/data/femnist/3500_client_json/", "leaf/data/femnist/3500_client_img/") - ``` - - 运行该脚本打印如下,代表运行成功: - - ```sh - =============================f0644_19.json======================= - ----- 0 ----- - ----- 1 ----- - ----- 2 ----- - ...... - ``` - -7. 
由于有些用户文件夹下的数据集较小,若数量小于batch size,需要进行随机扩充。 - - 可参考下面代码对整个数据集`"leaf/data/femnist/3500_client_img/"`进行检查并扩充: - - ```python - import os - import shutil - from random import choice - - def count_dir(path): - num = 0 - for root, dirs, files in os.walk(path): - for file in files: - num += 1 - return num - - def get_img_list(path): - img_path_list = [] - label_list = os.listdir(path) - for i in range(len(label_list)): - label = label_list[i] - imgs_path = os.path.join(path, label) - imgs_name = os.listdir(imgs_path) - for j in range(len(imgs_name)): - img_name = imgs_name[j] - img_path = os.path.join(imgs_path, img_name) - img_path_list.append(img_path) - return img_path_list - - def data_aug(data_root_path, batch_size = 32): - users = os.listdir(data_root_path) - tags = ["train", "test"] - aug_users = [] - for i in range(len(users)): - user = users[i] - for tag in tags: - data_path = os.path.join(data_root_path, user, tag) - num_data = count_dir(data_path) - if num_data < batch_size: - aug_users.append(user + "_" + tag) - print("user: ", user, " ", tag, " data number: ", num_data, " < ", batch_size, " should be aug") - aug_num = batch_size - num_data - img_path_list = get_img_list(data_path) - for j in range(aug_num): - img_path = choice(img_path_list) - info = img_path.split(".") - aug_img_path = info[0] + "_aug_" + str(j) + ".png" - shutil.copy(img_path, aug_img_path) - print("[aug", j, "]", "============= copy file:", img_path, "to ->", aug_img_path) - print("the number of all aug users: " + str(len(aug_users))) - print("aug user name: ", end=" ") - for k in range(len(aug_users)): - print(aug_users[k], end = " ") - - if __name__ == "__main__": - data_root_path = "leaf/data/femnist/3500_client_img/" - batch_size = 32 - data_aug(data_root_path, batch_size) - ``` - -8. 
将扩充后图片数据集转换为联邦学习框架可用的bin文件格式。 - - 可参考下面代码: - - ```python - import numpy as np - import os - import mindspore.dataset as ds - import mindspore.dataset.vision as vision - import mindspore.dataset.transforms as transforms - import mindspore - - def mkdir(path): - if not os.path.exists(path): - os.mkdir(path) - - def count_id(path): - files = os.listdir(path) - ids = {} - for i in files: - ids[i] = int(i) - return ids - - def create_dataset_from_folder(data_path, img_size, batch_size=32, repeat_size=1, num_parallel_workers=1, shuffle=False): - """ create dataset for train or test - Args: - data_path: Data path - batch_size: The number of data records in each group - repeat_size: The number of replicated data records - num_parallel_workers: The number of parallel workers - """ - # define dataset - ids = count_id(data_path) - mnist_ds = ds.ImageFolderDataset(dataset_dir=data_path, decode=False, class_indexing=ids) - # define operation parameters - resize_height, resize_width = img_size[0], img_size[1] # 32 - - transform = [ - vision.Decode(True), - vision.Grayscale(1), - vision.Resize(size=(resize_height, resize_width)), - vision.Grayscale(3), - vision.ToTensor(), - ] - compose = transforms.Compose(transform) - - # apply map operations on images - mnist_ds = mnist_ds.map(input_columns="label", operations=transforms.TypeCast(mindspore.int32)) - mnist_ds = mnist_ds.map(input_columns="image", operations=compose) - - # apply DatasetOps - buffer_size = 10000 - if shuffle: - mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script - mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True) - mnist_ds = mnist_ds.repeat(repeat_size) - return mnist_ds - - def img2bin(root_path, root_save): - """ - transform images to bin files - - Args: - root_path: the root path of 3500 images files - root_save: the root path to save bin files - - """ - - use_list = [] - train_batch_num = [] - test_batch_num = [] - mkdir(root_save) - users = os.listdir(root_path) - for user in users: - use_list.append(user) - user_path = os.path.join(root_path, user) - train_test = os.listdir(user_path) - for tag in train_test: - data_path = os.path.join(user_path, tag) - dataset = create_dataset_from_folder(data_path, (32, 32, 1), 32) - batch_num = 0 - img_list = [] - label_list = [] - for data in dataset.create_dict_iterator(): - batch_x_tensor = data['image'] - batch_y_tensor = data['label'] - trans_img = np.transpose(batch_x_tensor.asnumpy(), [0, 2, 3, 1]) - img_list.append(trans_img) - label_list.append(batch_y_tensor.asnumpy()) - batch_num += 1 - - if tag == "train": - train_batch_num.append(batch_num) - elif tag == "test": - test_batch_num.append(batch_num) - - imgs = np.array(img_list) # (batch_num, 32,3,32,32) - labels = np.array(label_list) - path1 = os.path.join(root_save, user) - mkdir(path1) - image_path = os.path.join(path1, user + "_" + "bn_" + str(batch_num) + "_" + tag + "_data.bin") - label_path = os.path.join(path1, user + "_" + "bn_" + str(batch_num) + "_" + tag + "_label.bin") - - imgs.tofile(image_path) - labels.tofile(label_path) - print("user: " + user + " " + tag + "_batch_num: " + str(batch_num)) - print("total " + str(len(use_list)) + " users finished!") - - root_path = "leaf/data/femnist/3500_client_img/" - root_save = "leaf/data/femnist/3500_clients_bin" - img2bin(root_path, root_save) - ``` - - 运行该脚本打印如下,代表运行成功: - - ```sh - user: f0141_43 test_batch_num: 1 - user: f0141_43 train_batch_num: 10 - user: f0137_14 test_batch_num: 1 - user: f0137_14 train_batch_num: 11 - ...... 
- total 3500 users finished! - ``` - -9. 生成`3500_clients_bin`文件夹内共包含3500个用户文件夹,其目录结构如下: - - ```sh - leaf/data/femnist/3500_clients_bin - ├── f0000_14 # 用户编号 - │ ├── f0000_14_bn_10_train_data.bin # 用户f0000_14的训练数据 (bn_后面的数字10代表batch number) - │ ├── f0000_14_bn_10_train_label.bin # 用户f0000_14的训练标签 - │ ├── f0000_14_bn_1_test_data.bin # 用户f0000_14的测试数据 (bn_后面的数字1代表batch number) - │ └── f0000_14_bn_1_test_label.bin # 用户f0000_14的测试标签 - ├── f0001_41 # 用户编号 - │ ├── f0001_41_bn_11_train_data.bin # 用户f0001_41的训练数据 (bn_后面的数字11代表batch number) - │ ├── f0001_41_bn_11_train_label.bin # 用户f0001_41的训练标签 - │ ├── f0001_41_bn_1_test_data.bin # 用户f0001_41的测试数据 (bn_后面的数字1代表batch number) - │ └── f0001_41_bn_1_test_label.bin # 用户f0001_41的测试标签 - │ ... - └── f4099_10 # 用户编号 - ├── f4099_10_bn_4_train_data.bin # 用户f4099_10的训练数据 (bn_后面的数字4代表batch number) - ├── f4099_10_bn_4_train_label.bin # 用户f4099_10的训练标签 - ├── f4099_10_bn_1_test_data.bin # 用户f4099_10的测试数据 (bn_后面的数字1代表batch number) - └── f4099_10_bn_1_test_label.bin # 用户f4099_10的测试标签 - ``` - -根据以上1~9步骤生成的`3500_clients_bin`文件夹可直接作为端云联邦图像分类任务的输入数据。 - diff --git a/docs/federated/docs/source_zh_cn/image_classification_application.md b/docs/federated/docs/source_zh_cn/image_classification_application.md deleted file mode 100644 index ff194ff3b2518306db66d1d5c9d129221f576c91..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/image_classification_application.md +++ /dev/null @@ -1,331 +0,0 @@ -# 实现一个端云联邦的图像分类应用(x86) - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/image_classification_application.md) - -根据参与客户端的类型,联邦学习可分为云云联邦学习(cross-silo)和端云联邦学习(cross-device)。在云云联邦学习场景中,参与联邦学习的客户端是不同的组织(例如,医疗或金融)或地理分布的数据中心,即在多个数据孤岛上训练模型。在端云联邦学习场景中,参与的客户端为大量的移动或物联网设备。本框架将介绍如何在MindSpore端云联邦框架上使用网络LeNet实现一个图片分类应用,并提供在x86环境中模拟启动多客户端参与联邦学习的相关教程。 - -在动手进行实践之前,确保你已经正确安装了MindSpore。如果没有,可以参考[MindSpore安装页面](https://www.mindspore.cn/install)完成安装。 - -## 准备工作 - -我们提供了可供用户直接使用的[联邦学习图像分类数据集FEMNIST](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/federated/3500_clients_bin.zip),以及`.ms`格式的[端侧模型文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/lenet_train.ms)。用户也可以根据实际需求,参考以下教程自行生成数据集和模型。 - -### 生成端侧模型文件 - -1. 定义网络和训练过程。 - - 具体网络和训练过程的定义可参考[快速入门](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/quick_start.html#网络构建)。 - -2. 
将模型导出为MindIR格式文件。 - - 代码片段如下: - - ```python - import argparse - import numpy as np - import mindspore as ms - import mindspore.nn as nn - - def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): - """weight initial for conv layer""" - weight = weight_variable() - return nn.Conv2d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - weight_init=weight, - has_bias=False, - pad_mode="valid", - ) - - def fc_with_initialize(input_channels, out_channels): - """weight initial for fc layer""" - weight = weight_variable() - bias = weight_variable() - return nn.Dense(input_channels, out_channels, weight, bias) - - def weight_variable(): - """weight initial""" - return ms.common.initializer.TruncatedNormal(0.02) - - class LeNet5(nn.Cell): - def __init__(self, num_class=10, channel=3): - super(LeNet5, self).__init__() - self.num_class = num_class - self.conv1 = conv(channel, 6, 5) - self.conv2 = conv(6, 16, 5) - self.fc1 = fc_with_initialize(16 * 5 * 5, 120) - self.fc2 = fc_with_initialize(120, 84) - self.fc3 = fc_with_initialize(84, self.num_class) - self.relu = nn.ReLU() - self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.flatten = nn.Flatten() - - def construct(self, x): - x = self.conv1(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.flatten(x) - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.relu(x) - x = self.fc3(x) - return x - - parser = argparse.ArgumentParser(description="export mindir for lenet") - parser.add_argument("--device_target", type=str, default="CPU") - parser.add_argument("--mindir_path", type=str, - default="lenet_train.mindir") # the mindir file path of the model to be export - - args, _ = parser.parse_known_args() - device_target = args.device_target - mindir_path = args.mindir_path - - ms.set_context(mode=ms.GRAPH_MODE, device_target=device_target) - - if __name__ == "__main__": - np.random.seed(0) - network = LeNet5(62) - criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction="mean") - net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) - net_with_criterion = nn.WithLossCell(network, criterion) - train_network = nn.TrainOneStepCell(net_with_criterion, net_opt) - train_network.set_train() - - data = ms.Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32)) - label = ms.Tensor(np.random.randint(0, 1, (32, 62)).astype(np.float32)) - ms.export(train_network, data, label, file_name=mindir_path, - file_format='MINDIR') # Add the export statement to obtain the model file in MindIR format. - ``` - - 参数`--mindir_path`用于设置生成的MindIR格式文件路径。 - -3. 
将MindIR文件转化为联邦学习端侧框架可用的ms文件。 - - 模型转换可参考[训练模型转换教程](https://www.mindspore.cn/lite/docs/zh-CN/master/converter/converter_tool.html)。 - - 模型转换示例如下: - - 假设待转换的模型文件为`lenet_train.mindir`,执行如下转换命令: - - ```sh - ./converter_lite --fmk=MINDIR --trainModel=true --modelFile=lenet_train.mindir --outputFile=lenet_train - ``` - - 转换成功输出如下: - - ```sh - CONVERT RESULT SUCCESS:0 - ``` - - 这表明MindSpore模型成功转换为MindSpore端侧模型,并生成了新文件`lenet_train.ms`。如果转换失败输出如下: - - ```sh - CONVERT RESULT FAILED: - ``` - - 生成的`.ms`格式的模型文件为后续客户端所需的模型文件。 - -## 模拟启动多客户端参与联邦学习 - -### 为客户端准备好模型文件 - -本例在端侧使用lenet模拟实际用的网络,其中lenet的`.ms`格式的[端侧模型文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/lenet_train.ms),由于真实场景一个客户端只包含一个.ms格式的模型文件,在模拟场景中,需要拷贝多份.ms文件,并按照`lenet_train{i}.ms`格式进行命名。其中i代表客户端编号,由于`run_client_x86.py`中,已自动为每个客户端拷贝.ms文件。 -具体见[启动脚本](https://gitee.com/mindspore/federated/blob/master/example/cross_device_lenet_femnist/simulate_x86/run_client_x86.py)中的copy_ms函数。 - -### 启动云侧服务 - -用户可先参考[横向云侧部署教程](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_server.html)部署云侧环境,并启动云侧服务。 - -### 启动客户端 - -启动客户端之前请先参照[横向端侧部署教程](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_client.html)进行端侧环境部署。 - -使用提供的[run_client_x86.py](https://gitee.com/mindspore/federated/blob/master/example/cross_device_lenet_femnist/simulate_x86/run_client_x86.py)脚本进行端侧联邦学习的启动,通过相关参数的设置,来启动不同的联邦学习接口。 -待云侧服务启动成功之后,使用提供run_client_x86.py的脚本,调用联邦学习框架jar包`mindspore-lite-java-flclient.jar` 和模型脚本对应的jar包`quick_start_flclient.jar`(可参考[横向端侧部署中编译出包流程](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_client.html)获取)来模拟启动多客户端参与联邦学习任务。 - -以LeNet网络为例,`run_client_x86.py`脚本中部分入参含义如下,用户可根据实际情况进行设置: - -- `--fl_jar_path` - - 设置联邦学习jar包路径,x86环境联邦学习jar包获取可参考[横向端侧部署中编译出包流程](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_client.html)。 - -- `--case_jar_path` - - 设置模型脚本所生成的jar包`quick_start_flclient.jar`的路径,x86环境联邦学习jar包获取可参考[横向联邦端侧部署中编译出包流程](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_client.html)。 - -- `--lite_jar_path` - - 设置mindspore lite的端侧jar包`mindspore-lite-java.jar`的路径,位于端侧包mindspore-lite-{version}-linux-x64.tar.gz中,x86环境联邦学习jar包获取可参考[横向端侧部署中构建环境依赖](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_client.html)。 - -- `--train_data_dir` - - 训练数据集root路径,LeNet图片分类任务在该root路径中存放的是每个客户端的训练data.bin文件与label.bin文件,例如`data/femnist/3500_clients_bin/`。 - -- `--fl_name` - - 联邦学习使用的模型脚本包路径。我们提供了两个类型的模型脚本供大家参考([有监督情感分类任务](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert)、[LeNet图片分类任务](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet)),对于有监督情感分类任务,该参数可设置为所提供的脚本文件[AlBertClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/AlbertClient.java) 的包路径`com.mindspore.flclient.demo.albert.AlbertClient`;对于LeNet图片分类任务,该参数可设置为所提供的脚本文件[LenetClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java) 的包路径`com.mindspore.flclient.demo.lenet.LenetClient`。同时,用户可参考这两个类型的模型脚本,自定义模型脚本,然后将该参数设置为自定义的模型文件ModelClient.java(需继承于类[Client.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/model/Client.java))的包路径即可。 - -- `--train_model_dir` - - 
设置联邦学习使用的训练模型路径,为上面教程中拷贝的多份.ms文件所存放的目录,比如`ms/lenet`,必须为绝对路径。 - -- `--domain_name` - - 用于设置端云通信url,目前,可支持https和http通信,对应格式分别为:https://......、http://......,当`if_use_elb`设置为true时,格式必须为: 或者 ,其中`127.0.0.1`对应提供云侧服务的机器ip(即云侧参数`--scheduler_ip`),`6666`对应云侧参数`--fl_server_port`。 - - 注意1,当该参数设置为`http://......`时代表使用HTTP通信,可能会存在通信安全风险,请知悉。 - - 注意2,当该参数设置为`https://......`代表使用HTTPS通信。此时必须进行SSL证书认证,需要通过参数`--cert_path`设置证书路径。 - -- `--task` - - 用于设置本此启动的任务类型,为`train`代表启动训练任务,为`inference`代表启动多条数据推理任务,为`getModel`代表启动获取云侧模型的任务,设置其他字符串代表启动单条数据推理任务。默认为`train`。由于初始的模型文件(.ms文件)是未训练过的,建议先启动训练任务,待训练完成之后,再启动推理任务(注意两次启动的`client_num`保持一致,以保证`inference`使用的模型文件与`train`保持一致)。 - -- `--batch_size` - - 设置联邦学习训练和推理时使用的单步训练样本数,即batch size。需与模型的输入数据的batch size保持一致。 - -- `--client_num` - - 设置client数量,与启动server端时的`start_fl_job_cnt`保持一致,真实场景不需要此参数。 - -若想进一步了解`run_client_x86.py`脚本中其他参数含义,可参考脚本中注释部分。 - -联邦学习接口基本启动指令示例如下: - -```sh - rm -rf client_*\ - && rm -rf ms/* \ - && python3 run_client_x86.py \ - --fl_jar_path="federated/mindspore_federated/device_client/build/libs/jarX86/mindspore-lite-java-flclient.jar" \ - --case_jar_path="federated/example/quick_start_flclient/target/case_jar/quick_start_flclient.jar" \ - --lite_jar_path="federated/mindspore_federated/device_client/third/mindspore-lite-2.0.0-linux-x64/runtime/lib/mindspore-lite-java.jar" \ - --train_data_dir="federated/tests/st/simulate_x86/data/3500_clients_bin/" \ - --eval_data_dir="null" \ - --infer_data_dir="null" \ - --vocab_path="null" \ - --ids_path="null" \ - --path_regex="," \ - --fl_name="com.mindspore.flclient.demo.lenet.LenetClient" \ - --origin_train_model_path="federated/tests/st/simulate_x86/ms_files/lenet/lenet_train.ms" \ - --origin_infer_model_path="null" \ - --train_model_dir="ms" \ - --infer_model_dir="ms" \ - --ssl_protocol="TLSv1.2" \ - --deploy_env="x86" \ - --domain_name="http://10.*.*.*:8010" \ - --cert_path="CARoot.pem" --use_elb="false" \ - --server_num=1 \ - --task="train" \ - --thread_num=1 \ - --cpu_bind_mode="NOT_BINDING_CORE" \ - --train_weight_name="null" \ - --infer_weight_name="null" \ - --name_regex="::" \ - --server_mode="FEDERATED_LEARNING" \ - --batch_size=32 \ - --input_shape="null" \ - --client_num=8 -``` - -注意,启动指令中涉及路径的必须给出绝对路径。 - -以上指令代表启动8个客户端参与联邦学习训练任务,若启动成功,会在当前文件夹生成8个客户端对应的日志文件,查看日志文件内容可了解每个客户端的运行情况: - -```text -./ -├── client_0 -│ └── client.log # 客户端0的日志文件 -│ ...... -└── client_7 - └── client.log # 客户端4的日志文件 -``` - -针对不同的接口和场景,只需根据参数含义,修改特定参数值即可,比如: - -- 启动联邦学习训练任务SyncFLJob.flJobRun() - - 当`基本启动指令`中 `--task`设置为`train`时代表启动该任务。 - - 可通过指令`grep -r "average loss:" client_0/client.log`查看`client_0`在训练过程中每个epoch的平均loss,会有类似如下打印: - - ```sh - INFO: ----------epoch:0,average loss:4.1258564 ---------- - ...... - ``` - - 也可通过指令`grep -r "evaluate acc:" client_0/client.log`查看`client_0`在每个联邦学习迭代中聚合后模型的验证精度,会有类似如下打印: - - ```sh - INFO: [evaluate] evaluate acc: 0.125 - ...... - ``` - - 在云侧,可以通过设置yaml配置文件的`cluster_client_num`参数与`eval_type`参数来指定进行无监督聚类指标统计的客户端group id数量与算法类型,在云侧生成的`metrics.json`统计文件可以查询到无监督指标信息: - - ```text - "unsupervisedEval":0.640 - "unsupervisedEval":0.675 - "unsupervisedEval":0.677 - "unsupervisedEval":0.706 - ...... - ``` - -- 启动推理任务SyncFLJob.modelInference() - - 当`基本启动指令`中 `--task`设置为`inference`时代表启动该任务。 - - 可通过指令`grep -r "the predicted labels:" client_0/client.log`查看`client_0`的推理结果: - - ```sh - INFO: [model inference] the predicted labels: [0, 0, 0, 1, 1, 1, 2, 2, 2] - ...... 
- ``` - -- 启动获取云侧最新模型任务SyncFLJob.getModel() - - 当`基本启动指令`中 `--task`设置为`getModel`时代表启动该任务。 - - 在日志文件中若有如下内容代表获取云侧最新模型成功: - - ```sh - INFO: [getModel] get response from server ok! - ``` - -### 关闭客户端进程 - -可参考[finish.py](https://gitee.com/mindspore/federated/blob/master/example/cross_device_lenet_femnist/simulate_x86/finish.py)脚本,具体如下: - -关闭客户端指令如下: - -```sh -python finish.py --kill_tag=mindspore-lite-java-flclient -``` - -其中参数`--kill_tag`用于搜索该关键字对客户端进程进行kill,只需要设置`--jarPath`中的特殊关键字即可。默认为`mindspore-lite-java-flclient`,即联邦学习jar包名。 -用户可通过指令`ps -ef |grep "mindspore-lite-java-flclient"`查看进程是否还存在。 - -50个客户端参与联邦学习训练任务实验结果。 - -目前`3500_clients_bin`文件夹中包含3500个客户端的数据,本脚本最多可模拟3500个客户端参与联邦学习。 - -下图给出了50个客户端(设置`server_num`为16)进行联邦学习的测试集精度: - -![lenet_50_clients_acc](images/lenet_50_clients_acc.png) - -其中联邦学习总迭代数为100,客户端本地训练epoch数为20,batchSize设置为32。 - -图中测试精度指对于每个联邦学习迭代,各客户端测试集在云侧聚合后的模型上的精度。 - -AVG:对于每个联邦学习迭代,50个客户端测试集精度的平均值。 - -TOP5:对于每个联邦学习迭代,测试集精度最高的5个客户端的精度平均值。 - -LOW5:对于每个联邦学习迭代,测试集精度最低的5个客户端的精度平均值。 diff --git a/docs/federated/docs/source_zh_cn/image_classification_application_in_cross_silo.md b/docs/federated/docs/source_zh_cn/image_classification_application_in_cross_silo.md deleted file mode 100644 index 7f245d2e503232a8b5564134d9993f1f9455908b..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/image_classification_application_in_cross_silo.md +++ /dev/null @@ -1,313 +0,0 @@ -# 实现一个云云联邦的图像分类应用(x86) - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/image_classification_application_in_cross_silo.md) - -根据参与客户端的类型,联邦学习可分为云云联邦学习(cross-silo)和端云联邦学习(cross-device)。在云云联邦学习场景中,参与联邦学习的客户端是不同的组织(例如,医疗或金融)或地理分布的数据中心,即在多个数据孤岛上训练模型。在端云联邦学习场景中,参与的客户端为大量的移动或物联网设备。本框架将介绍如何在MindSpore Federated云云联邦框架上,使用网络LeNet实现一个图片分类应用。 - -启动云云联邦的图像分类应用的完整脚本可参考[这里](https://gitee.com/mindspore/federated/tree/master/example/cross_silo_femnist)。 - -## 下载数据集 - -本示例采用[leaf数据集](https://github.com/TalwalkarLab/leaf)中的联邦学习数据集`FEMNIST`,该数据集包含62个不同类别的手写数字和字母(数字0~9、26个小写字母、26个大写字母),图像大小为`28 x 28`像素,数据集包含3500个用户的手写数字和字母(最多可模拟3500个客户端参与联邦学习),总数据量为805263,平均每个用户包含数据量为226.83,所有用户数据量的方差为88.94。 - -可参考文档[端云联邦学习图像分类数据集处理](https://www.mindspore.cn/federated/docs/zh-CN/master/image_classfication_dataset_process.html)中步骤1~7获取图片形式的3500个用户数据集`3500_client_img`。 - -由于原始3500个用户数据集中每个用户数据量比较少,在云云联邦任务中会收敛太快,无法明显体现云云联邦框架的收敛效果,下面提供一个参考脚本,将指定数量的用户数据集合并为一个用户,以增加参与云云联邦任务的单个用户数据量,更好地模拟云云联邦框架实验。 - -```python -import os -import shutil - - -def mkdir(path): - if not os.path.exists(path): - os.mkdir(path) - - -def combine_users(root_data_path, new_data_path, raw_user_num, new_user_num): - mkdir(new_data_path) - user_list = os.listdir(root_data_path) - num_per_user = int(raw_user_num / new_user_num) - for i in range(new_user_num): - print( - "========================== combine the raw {}~{} users to the new user: dataset_{} ==========================".format( - i * num_per_user, i * num_per_user + num_per_user - 1, i)) - new_user = "dataset_" + str(i) - new_user_path = os.path.join(new_data_path, new_user) - mkdir(new_user_path) - for j in range(num_per_user): - index = i * new_user_num + j - user = user_list[index] - user_path = os.path.join(root_data_path, user) - tags = os.listdir(user_path) - print("------------- process the raw user: {} -------------".format(user)) - for t in tags: - tag_path = os.path.join(user_path, t) - label_list = os.listdir(tag_path) - new_tag_path = 
os.path.join(new_user_path, t) - mkdir(new_tag_path) - for label in label_list: - label_path = os.path.join(tag_path, label) - img_list = os.listdir(label_path) - new_label_path = os.path.join(new_tag_path, label) - mkdir(new_label_path) - - for img in img_list: - img_path = os.path.join(label_path, img) - new_img_name = user + "_" + img - new_img_path = os.path.join(new_label_path, new_img_name) - shutil.copy(img_path, new_img_path) - -if __name__ == "__main__": - root_data_path = "cross_silo_femnist/femnist/3500_clients_img" - new_data_path = "cross_silo_femnist/femnist/35_7_client_img" - raw_user_num = 35 - new_user_num = 7 - combine_users(root_data_path, new_data_path, raw_user_num, new_user_num) -``` - -其中`root_data_path`为原始3500个用户数据集路径,`new_data_path`为合并后数据集的路径,`raw_user_num`指定用于合并的用户数据集总数(需<=3500),`new_user_num`用于设置将原始数据集合并为多少个用户。如示例代码中将从`cross_silo_femnist/femnist/3500_clients_img`中选取前35个用户,合并为7个用户数据集后存放在路径`cross_silo_femnist/femnist/35_7_client_img`(合并后的7个用户,每个用户包含原始的5个用户数据集)。 - -如下打印代表合并数据集成功: - -```sh -========================== combine the raw 0~4 users to the new user: dataset_0 ========================== -------------- process the raw user: f1798_42 ------------- -------------- process the raw user: f2149_81 ------------- -------------- process the raw user: f4046_46 ------------- -------------- process the raw user: f1093_13 ------------- -------------- process the raw user: f1124_24 ------------- -========================== combine the raw 5~9 users to the new user: dataset_1 ========================== -------------- process the raw user: f0586_11 ------------- -------------- process the raw user: f0721_31 ------------- -------------- process the raw user: f3527_33 ------------- -------------- process the raw user: f0146_33 ------------- -------------- process the raw user: f1272_09 ------------- -========================== combine the raw 10~14 users to the new user: dataset_2 ========================== -------------- process the raw user: f0245_40 ------------- -------------- process the raw user: f2363_77 ------------- -------------- process the raw user: f3596_19 ------------- -------------- process the raw user: f2418_82 ------------- -------------- process the raw user: f2288_58 ------------- -========================== combine the raw 15~19 users to the new user: dataset_3 ========================== -------------- process the raw user: f2249_75 ------------- -------------- process the raw user: f3681_31 ------------- -------------- process the raw user: f3766_48 ------------- -------------- process the raw user: f0537_35 ------------- -------------- process the raw user: f0614_14 ------------- -========================== combine the raw 20~24 users to the new user: dataset_4 ========================== -------------- process the raw user: f2302_58 ------------- -------------- process the raw user: f3472_19 ------------- -------------- process the raw user: f3327_11 ------------- -------------- process the raw user: f1892_07 ------------- -------------- process the raw user: f3184_11 ------------- -========================== combine the raw 25~29 users to the new user: dataset_5 ========================== -------------- process the raw user: f1692_18 ------------- -------------- process the raw user: f1473_30 ------------- -------------- process the raw user: f0909_04 ------------- -------------- process the raw user: f1956_19 ------------- -------------- process the raw user: f1234_26 ------------- -========================== combine the raw 30~34 users to the new 
user: dataset_6 ========================== -------------- process the raw user: f0031_02 ------------- -------------- process the raw user: f0300_24 ------------- -------------- process the raw user: f4064_46 ------------- -------------- process the raw user: f2439_77 ------------- -------------- process the raw user: f1717_16 ------------- -``` - -文件夹 `cross_silo_femnist/femnist/35_7_client_img`目录结构如下: - -```text -35_7_client_img # 将FeMnist数据集中35个用户合并为7个客户端数据(各包含5个用户数据) -├── dataset_0 # 客户端0的数据集 -│ ├── train # 训练数据集 -│ │ ├── 0 # 存放类别0对应的图片数据 -│ │ ├── 1 # 存放类别1对应的图片数据 -│ │ │ ...... -│ │ └── 61 # 存放类别61对应的图片数据 -│ └── test # 测试数据集,目录结构同train -│ ...... -│ -└── dataset_6 # 客户端6的数据集 - ├── train # 训练数据集 - │ ├── 0 # 存放类别0对应的图片数据 - │ ├── 1 # 存放类别1对应的图片数据 - │ │ ...... - │ └── 61 # 存放类别61对应的图片数据 - └── test # 测试数据集,目录结构同train -``` - -## 定义网络 - -我们选择相对简单的LeNet网络。LeNet网络不包括输入层的情况下,共有7层:2个卷积层、2个下采样层(池化层)、3个全连接层。每层都包含不同数量的训练参数,如下图所示: - -![LeNet5](images/LeNet_5.jpg) - -> 更多的LeNet网络的介绍不在此赘述,希望详细了解LeNet网络,可以查询。 - -本任务使用的网络可参考脚本[test_cross_silo_femnist.py](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_femnist/test_cross_silo_femnist.py)。 - -若想具体了解MindSpore中网络定义流程可参考[初学入门](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/quick_start.html#网络构建)。 - -## 启动云云联邦任务 - -### 安装MindSpore和MindSpore Federated - -包括源码和下载发布版两种方式,支持CPU、GPU、Ascend硬件平台,根据硬件平台选择安装即可。安装步骤可参考[MindSpore安装指南](https://www.mindspore.cn/install),[MindSpore Federated安装指南](https://www.mindspore.cn/federated/docs/zh-CN/master/federated_install.html)。 - -目前联邦学习框架只支持Linux环境中部署,cross-silo联邦学习框架需要MindSpore版本号>=1.5.0。 - -### 启动任务 - -参考[示例](https://gitee.com/mindspore/federated/tree/master/example/cross_silo_femnist),启动集群。参考示例目录结构如下: - -```text -cross_silo_femnist/ -├── config.json # 配置文件 -├── finish_cross_silo_femnist.py # 关闭云云联邦任务脚本 -├── run_cross_silo_femnist_sched.py # 启动云云联邦scheduler脚本 -├── run_cross_silo_femnist_server.py # 启动云云联邦server脚本 -├── run_cross_silo_femnist_worker.py # 启动云云联邦worker脚本 -├── run_cross_silo_femnist_worker_distributed.py # 启动云云联邦分布式训练worker脚本 -└── test_cross_silo_femnist.py # 客户端使用的训练脚本 -``` - -1. 启动Scheduler - - `run_cross_silo_femnist_sched.py`是为用户启动`Scheduler`而提供的Python脚本,并支持通过`argparse`传参修改配置。执行指令如下,代表启动本次联邦学习任务的`Scheduler`,其TCP端口为`5554`。 - - ```sh - python run_cross_silo_femnist_sched.py --scheduler_manage_address=127.0.0.1:5554 - ``` - - 打印如下代表启动成功: - - ```sh - [INFO] FEDERATED(35566,7f4275895740,python):2022-10-09-15:23:22.450.205 [mindspore_federated/fl_arch/ccsrc/scheduler/scheduler.cc:35] Run] Scheduler started successfully. - [INFO] FEDERATED(35566,7f41f259d700,python):2022-10-09-15:23:22.450.357 [mindspore_federated/fl_arch/ccsrc/common/communicator/http_request_handler.cc:90] Run] Start http server! - ``` - -2. 启动Server - - `run_cross_silo_femnist_server.py`是为用户启动若干`Server`而提供的Python脚本,并支持通过`argparse`传参修改配置。执行指令如下,代表启动本次联邦学习任务的`Server`,其http起始端口为`5555`,`server`数量为`4`个。 - - ```sh - python run_cross_silo_femnist_server.py --local_server_num=4 --http_server_address=10.*.*.*:5555 - ``` - - 以上指令等价于启动了4个`Server`进程,每个`Server`的联邦学习服务端口分别为`5555`、`5556`、`5557`和`5558`。 - -3. 
启动Worker - - `run_cross_silo_femnist_worker.py`是为用户启动若干`worker`而提供的Python脚本,并支持通过`argparse`传参修改配置。执行指令如下,代表启动本次联邦学习任务的`worker`,其http起始端口为`5555`,`worker`数量为`4`个: - - ```sh - python run_cross_silo_femnist_worker.py --dataset_path=/data_nfs/code/fed_user_doc/federated/tests/st/cross_silo_femnist/35_7_client_img/ --http_server_address=10.*.*.*:5555 - ``` - - 当前云云联邦的`worker`节点支持单机多卡&多机多卡的分布式训练方式,`run_cross_silo_femnist_worker_distributed.py`是为用户启动`worker`节点的分布式训练而提供的Python脚本,并支持通过`argparse`传参修改配置。执行指令如下,代表启动本次联邦学习任务的分布式`worker`,其中`device_num`表示`worker`集群启动的进程数目,`run_distribute`表示启动集群的分布式训练,其http起始端口为`5555`,`worker`进程数量为`4`个: - - ```sh - python run_cross_silo_femnist_worker_distributed.py --device_num=4 --run_distribute=True --dataset_path=/data_nfs/code/fed_user_doc/federated/tests/st/cross_silo_femnist/35_7_client_img/ --http_server_address=10.*.*.*:5555 - ``` - -当执行以上三个指令之后,进入当前目录下`worker_0`文件夹,通过指令`grep -rn "test acc" *`查看`worker_0`日志,可看到如下类似打印: - -```sh -local epoch: 0, loss: 3.787421340711655, trian acc: 0.05342741935483871, test acc: 0.075 -``` - -则说明云云联邦启动成功,`worker_0`正在训练,其他worker可通过类似方式查看。 - -若worker已分布式多卡训练的方式启动,进入当前目录下`worker_distributed/log_output/`文件夹,通过指令`grep -rn "test acc" *`查看`worker`分布式集群的日志,可看到如下类似打印: - -```text -local epoch: 0, loss: 2.3467453340711655, trian acc: 0.06532451988877687, test acc: 0.076 -``` - -以上脚本中参数配置说明请参考[yaml配置说明](https://www.mindspore.cn/federated/docs/zh-CN/master/horizontal/federated_server_yaml.html)。 - -### 日志查看 - -成功启动任务之后,会在当前目录`cross_silo_femnist`下生成相应日志文件,日志文件目录结构如下: - -```text -cross_silo_femnist -├── scheduler -│ └── scheduler.log # 运行scheduler过程中打印日志 -├── server_0 -│ └── server.log # server_0运行过程中打印日志 -├── server_1 -│ └── server.log # server_1运行过程中打印日志 -├── server_2 -│ └── server.log # server_2运行过程中打印日志 -├── server_3 -│ └── server.log # server_3运行过程中打印日志 -├── worker_0 -│ ├── ckpt # 存放worker_0在每个联邦学习迭代结束时获取的聚合后的模型ckpt -│ │ ├── 0-fl-ms-bs32-0epoch.ckpt -│ │ ├── 0-fl-ms-bs32-1epoch.ckpt -│ │ │ -│ │ │ ...... -│ │ │ -│ │ └── 0-fl-ms-bs32-19epoch.ckpt -│ └── worker.log # 记录worker_0参与联邦学习任务过程中输出日志 -└── worker_1 - ├── ckpt # 存放worker_1在每个联邦学习迭代结束时获取的聚合后的模型ckpt - │ ├── 1-fl-ms-bs32-0epoch.ckpt - │ ├── 1-fl-ms-bs32-1epoch.ckpt - │ │ - │ │ ...... 
- │ │ - │ └── 1-fl-ms-bs32-19epoch.ckpt - └── worker.log # 记录worker_1参与联邦学习任务过程中输出日志 -``` - -### 关闭任务 - -若想中途退出,则可用以下指令: - -```sh -python finish_cross_silo_femnist.py --redis_port=2345 -``` - -或者等待训练任务结束之后集群会自动退出,不需要手动关闭。 - -### 实验结果 - -- 使用数据: - - 上面`下载数据集`部分生成的`35_7_client_img/`数据集 - -- 客户端本地训练epoch数:20 - -- 云云联邦学习总迭代数:20 - -- 实验结果(每个迭代聚合后模型在客户端的测试集上精度) - -`worker_0`测试结果: - -```sh -worker_0/worker.log:7409:local epoch: 0, loss: 3.787421340711655, trian acc: 0.05342741935483871, test acc: 0.075 -worker_0/worker.log:14419:local epoch: 1, loss: 3.725699281115686, trian acc: 0.05342741935483871, test acc: 0.075 -worker_0/worker.log:21429:local epoch: 2, loss: 3.5285709657335795, trian acc: 0.19556451612903225, test acc: 0.16875 -worker_0/worker.log:28439:local epoch: 3, loss: 3.0393165519160608, trian acc: 0.4889112903225806, test acc: 0.4875 -worker_0/worker.log:35449:local epoch: 4, loss: 2.575952764115026, trian acc: 0.6854838709677419, test acc: 0.60625 -worker_0/worker.log:42459:local epoch: 5, loss: 2.2081101375296512, trian acc: 0.7782258064516129, test acc: 0.6875 -worker_0/worker.log:49470:local epoch: 6, loss: 1.9229739431736557, trian acc: 0.8054435483870968, test acc: 0.69375 -worker_0/worker.log:56480:local epoch: 7, loss: 1.7005576549999293, trian acc: 0.8296370967741935, test acc: 0.65625 -worker_0/worker.log:63490:local epoch: 8, loss: 1.5248727620766704, trian acc: 0.8407258064516129, test acc: 0.6375 -worker_0/worker.log:70500:local epoch: 9, loss: 1.3838803705352127, trian acc: 0.8568548387096774, test acc: 0.7 -worker_0/worker.log:77510:local epoch: 10, loss: 1.265225578921041, trian acc: 0.8679435483870968, test acc: 0.7125 -worker_0/worker.log:84520:local epoch: 11, loss: 1.167484122101638, trian acc: 0.8659274193548387, test acc: 0.70625 -worker_0/worker.log:91530:local epoch: 12, loss: 1.082880981700859, trian acc: 0.8770161290322581, test acc: 0.65625 -worker_0/worker.log:98540:local epoch: 13, loss: 1.0097520119572772, trian acc: 0.8840725806451613, test acc: 0.64375 -worker_0/worker.log:105550:local epoch: 14, loss: 0.9469810053708015, trian acc: 0.9022177419354839, test acc: 0.7 -worker_0/worker.log:112560:local epoch: 15, loss: 0.8907848935604703, trian acc: 0.9022177419354839, test acc: 0.6625 -worker_0/worker.log:119570:local epoch: 16, loss: 0.8416629644123349, trian acc: 0.9082661290322581, test acc: 0.70625 -worker_0/worker.log:126580:local epoch: 17, loss: 0.798475691030866, trian acc: 0.9122983870967742, test acc: 0.70625 -worker_0/worker.log:133591:local epoch: 18, loss: 0.7599438544427897, trian acc: 0.9243951612903226, test acc: 0.6875 -worker_0/worker.log:140599:local epoch: 19, loss: 0.7250227383907605, trian acc: 0.9294354838709677, test acc: 0.7125 -``` - -其他客户端的测试结果基本相同,不再一一列出。 \ No newline at end of file diff --git a/docs/federated/docs/source_zh_cn/images/HFL.png b/docs/federated/docs/source_zh_cn/images/HFL.png deleted file mode 100644 index d66167a07acbb730fc6fac6e5f482b490650d9fe..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/HFL.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/LeNet_5.jpg b/docs/federated/docs/source_zh_cn/images/LeNet_5.jpg deleted file mode 100644 index 7894b0e181d965c5e9cbba91fe240c1890d37bda..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/LeNet_5.jpg and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/SILHOUETTE.png b/docs/federated/docs/source_zh_cn/images/SILHOUETTE.png 
deleted file mode 100644 index 051fd6fbe72b59d5ec47b3c90663f42a0f8d7d22..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/SILHOUETTE.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/VFL.png b/docs/federated/docs/source_zh_cn/images/VFL.png deleted file mode 100644 index 7e79ee943f5f658e80ea0d17f43a0f4f7eefa22f..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/VFL.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/create_android_project.png b/docs/federated/docs/source_zh_cn/images/create_android_project.png deleted file mode 100644 index a519264c4158fba67eb1ff5f5fbc3eae65b32363..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/create_android_project.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/cross-silo_fastrcnn-2workers-loss.png b/docs/federated/docs/source_zh_cn/images/cross-silo_fastrcnn-2workers-loss.png deleted file mode 100644 index c8be83d387fc0df853616dca972e169dfe8e4b31..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/cross-silo_fastrcnn-2workers-loss.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/data_join.png b/docs/federated/docs/source_zh_cn/images/data_join.png deleted file mode 100644 index 82af0a7bc57276af86086cbea4c4c6866fd9e754..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/data_join.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/deploy_VFL.png b/docs/federated/docs/source_zh_cn/images/deploy_VFL.png deleted file mode 100644 index 824c4cc4a9f3c636e41a8113179d183ec9414d4f..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/deploy_VFL.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/download_compression_client.png b/docs/federated/docs/source_zh_cn/images/download_compression_client.png deleted file mode 100644 index 7953b4579f3b31210b4b0556828a18652d27d6ae..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/download_compression_client.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/download_compression_server.png b/docs/federated/docs/source_zh_cn/images/download_compression_server.png deleted file mode 100644 index 6e77f1daf39e7c444b2744c7fb0e1c16fd7006ef..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/download_compression_server.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/eval_flow.png b/docs/federated/docs/source_zh_cn/images/eval_flow.png deleted file mode 100644 index ff47c7a6e333066addc9095698100838746d99ae..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/eval_flow.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/inverse_ecdh_psi_flow.png b/docs/federated/docs/source_zh_cn/images/inverse_ecdh_psi_flow.png deleted file mode 100644 index 23456f0cec4fbc3d3cca16135352977281846d0b..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/inverse_ecdh_psi_flow.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/label_dp.png b/docs/federated/docs/source_zh_cn/images/label_dp.png deleted file mode 100644 index 
e0a2667bf9da107ef93e1897218b686e13a60000..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/label_dp.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/laplace_pdf.png b/docs/federated/docs/source_zh_cn/images/laplace_pdf.png deleted file mode 100644 index ffe54d02a3dab253d88c4bf79567ae89f83d229a..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/laplace_pdf.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/lenet_50_clients_acc.png b/docs/federated/docs/source_zh_cn/images/lenet_50_clients_acc.png deleted file mode 100644 index c1282811f7161d77ec2ea563d96983ef293dbf43..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/lenet_50_clients_acc.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/lenet_signds_loss_auc.png b/docs/federated/docs/source_zh_cn/images/lenet_signds_loss_auc.png deleted file mode 100644 index 7304b69c4d0abf039549dce758b906d688213e4f..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/lenet_signds_loss_auc.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/mindspore_federated_networking.png b/docs/federated/docs/source_zh_cn/images/mindspore_federated_networking.png deleted file mode 100644 index 4340cb66b638e072ffdb11167743cc45c36a9536..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/mindspore_federated_networking.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/signds_flow.png b/docs/federated/docs/source_zh_cn/images/signds_flow.png deleted file mode 100644 index 2fc591a202bb54490aa0c6e46cf2fe2a94cc407b..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/signds_flow.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/signds_framework.png b/docs/federated/docs/source_zh_cn/images/signds_framework.png deleted file mode 100644 index 3a90a24dd9dc767a4905e3e3ef3340e8d05f7581..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/signds_framework.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/signds_step_length.png b/docs/federated/docs/source_zh_cn/images/signds_step_length.png deleted file mode 100644 index 1b4e0830cb38b20312fa9fe7e3621fe12c388864..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/signds_step_length.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/splitnn_pangu_alpha.png b/docs/federated/docs/source_zh_cn/images/splitnn_pangu_alpha.png deleted file mode 100644 index c4603cd6f3cbb775e90bcc554cff6de7c77b37a2..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/splitnn_pangu_alpha.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/splitnn_pangu_alpha_result.png b/docs/federated/docs/source_zh_cn/images/splitnn_pangu_alpha_result.png deleted file mode 100644 index 2d0c29c3fdcf0464080a7d61d399f54425716c81..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/splitnn_pangu_alpha_result.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/splitnn_wide_and_deep.png b/docs/federated/docs/source_zh_cn/images/splitnn_wide_and_deep.png deleted file mode 
100644 index 3a0075e69492b678f69a6eff759246938504ddd9..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/splitnn_wide_and_deep.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/start_android_project.png b/docs/federated/docs/source_zh_cn/images/start_android_project.png deleted file mode 100644 index 3a9336add10acbbef60dc429b8a3bad1ca198c38..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/start_android_project.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/two_cluster.png b/docs/federated/docs/source_zh_cn/images/two_cluster.png deleted file mode 100644 index 6290d00d5e81f289f370a90ee6ee175fdc5ea638..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/two_cluster.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/upload_compression_client.png b/docs/federated/docs/source_zh_cn/images/upload_compression_client.png deleted file mode 100644 index bf7f72cec5c0e8d4c036c2c2a5e71861da3e0abe..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/upload_compression_client.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/upload_compression_server.png b/docs/federated/docs/source_zh_cn/images/upload_compression_server.png deleted file mode 100644 index 42ba9dc14f6305bebb05e43c531378a02f84f3d8..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/upload_compression_server.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/vfl_1.png b/docs/federated/docs/source_zh_cn/images/vfl_1.png deleted file mode 100644 index c35343919ebd93de3181a506e95dcde582039021..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/vfl_1.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/vfl_backward.png b/docs/federated/docs/source_zh_cn/images/vfl_backward.png deleted file mode 100644 index a2ee272ebba07681968f1884ec548497c6cd7b9d..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/vfl_backward.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/vfl_feature_reconstruction.png b/docs/federated/docs/source_zh_cn/images/vfl_feature_reconstruction.png deleted file mode 100644 index 6085a486a555711967db3be2c979d145c95b2d67..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/vfl_feature_reconstruction.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/vfl_feature_reconstruction_defense.png b/docs/federated/docs/source_zh_cn/images/vfl_feature_reconstruction_defense.png deleted file mode 100644 index b52afbce114c59948af812721c50ec279c5f4808..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/vfl_feature_reconstruction_defense.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/vfl_forward.png b/docs/federated/docs/source_zh_cn/images/vfl_forward.png deleted file mode 100644 index 2e284106b7fcb6dc84c537a5325bb387389443f7..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/vfl_forward.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/vfl_mnist_detail.png b/docs/federated/docs/source_zh_cn/images/vfl_mnist_detail.png deleted 
file mode 100644 index b1b5764dfbd7776c5e5b92aa3f2440f71f8045fa..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/vfl_mnist_detail.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/vfl_normal_communication_compress.png b/docs/federated/docs/source_zh_cn/images/vfl_normal_communication_compress.png deleted file mode 100644 index 927bfb72aca6a8dea90be2fd9b8e1e1c1f216332..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/vfl_normal_communication_compress.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/vfl_pangu_communication_compress.png b/docs/federated/docs/source_zh_cn/images/vfl_pangu_communication_compress.png deleted file mode 100644 index 776e07a86aa8c6ba10e40aa2d2f2dbb5ec67889f..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/vfl_pangu_communication_compress.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/vfl_with_tee.png b/docs/federated/docs/source_zh_cn/images/vfl_with_tee.png deleted file mode 100644 index 6eeb4ac0048f8462866155601009d559d632302b..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/vfl_with_tee.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/weight_diff_decode.png b/docs/federated/docs/source_zh_cn/images/weight_diff_decode.png deleted file mode 100644 index fc7c07cbd48813b259b27dcb8e158976a90025f2..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/weight_diff_decode.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/images/weight_diff_encode.png b/docs/federated/docs/source_zh_cn/images/weight_diff_encode.png deleted file mode 100644 index 1103b9c6f1349699eb2d48cbb00fe8557660dba8..0000000000000000000000000000000000000000 Binary files a/docs/federated/docs/source_zh_cn/images/weight_diff_encode.png and /dev/null differ diff --git a/docs/federated/docs/source_zh_cn/index.rst b/docs/federated/docs/source_zh_cn/index.rst deleted file mode 100644 index a8caa1b9ecab5312034e6aeb56a4d22f84ba650c..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/index.rst +++ /dev/null @@ -1,177 +0,0 @@ -.. MindSpore documentation master file, created by - sphinx-quickstart on Thu Mar 24 11:00:00 2020. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - -MindSpore Federated 文档 -========================= - -MindSpore Federated是面向MindSpore开发者的开源联邦学习工具,在用户数据留存在本地的情况下,使能全场景智能应用。 - -联邦学习是一种加密的分布式机器学习技术,用于解决数据孤岛问题,在多方或者多资源计算节点间进行高效率,安全且可靠的机器学习。支持机器学习的各参与方在不直接共享本地数据的前提下,共建AI模型,包括但不限于广告推荐、分类、检测等主流深度学习模型,主要应用在金融,医疗,推荐等领域。 - -MindSpore Federated提供样本联合的横向联邦模式和特征联合的纵向联邦模式。可支持面向亿级无状态终端设备的商用化部署,也可支持跨可信区的数据中心之间的云云联邦。 - -代码仓地址: - -使用MindSpore Federated横向框架的优势 ----------------------------------------- - -横向联邦架构: - -.. raw:: html - - - -1. 隐私安全 - - 支持基于多方安全计算(MPC)的精度无损的安全聚合方案,防止模型窃取。 - - 支持基于本地差分隐私的性能无损的加密方案,防止模型泄漏隐私数据。 - - 支持基于符号维度选择(SignDS)的梯度保护方案,防止模型隐私数据泄露的同时,可将通信开销降低99%。 - -2. 分布式联邦聚合 - - 云侧松耦合集群化处理方式,和分布式梯度二次聚合范式,支持千万级数量的大规模异构终端部署场景,实现高性能、高可用的联邦聚合计算,可应对网络不稳定,负载突变等问题。 - -3. 联邦效率提升 - - 支持自适应调频策略,支持梯度压缩算法,提高联邦学习效率,节省带宽资源。 - - 支持多种联邦聚合策略,提高联邦收敛的平滑度,兼顾全局和局部的精度最优化。 - -4. 
灵活易用 - - 仅一行代码即可切换单机训练与联邦学习模式。 - - 网络模型可编程,聚合算法可编程,安全算法可编程,安全等级可定制。 - - 支持联邦训练模型的效果评估,提供联邦任务的监控能力。 - -使用MindSpore Federated纵向框架的优势 ----------------------------------------- - -纵向联邦架构: - -.. raw:: html - - - -1. 隐私安全 - - 支持高性能隐私集合求交协议(PSI),可防止联邦参与方获得交集外的ID信息,可应对数据不均衡场景。 - - 支持结合量化与差分隐私的特征保护软件方案,防止攻击者从中间特征重构出原始隐私数据。 - - 支持基于可信执行环境的特征保护硬件方案,提供高强度且高效的特征保护能力。 - - 支持基于差分隐私的标签保护方案,防止泄漏用户标签数据。 - -2. 联邦训练 - - 支持多类型的拆分学习网络结构。 - - 面向大模型跨域训练,流水线并行优化。 - -使用MindSpore Federated的工作流程 ------------------------------------ - -1. `识别场景、积累数据 `_ - - 识别出可使用联邦学习的业务场景,在客户端为联邦任务积累本地数据。 - -2. `模型选型、框架部署 `_ - - 进行模型原型的选型或开发,并使用工具生成方便部署的联邦学习模型。 - -3. `应用部署 `_ - - 将对应组件部署到业务应用中,并在服务器上设置联邦配置任务和部署脚本。 - -常见应用场景 ------------------ - -1. `图像分类 `_ - - 使用联邦学习实现图像分类应用。 - -2. `文本分类 `_ - - 使用联邦学习实现文本分类应用。 - -.. toctree:: - :maxdepth: 1 - :caption: 安装部署 - - federated_install - deploy_federated_server - deploy_federated_client - deploy_vfl - -.. toctree:: - :maxdepth: 1 - :caption: 横向应用实践 - - image_classfication_dataset_process - image_classification_application - sentiment_classification_application - image_classification_application_in_cross_silo - object_detection_application_in_cross_silo - -.. toctree:: - :maxdepth: 1 - :caption: 纵向应用实践 - - data_join - split_wnd_application - split_pangu_alpha_application - -.. toctree:: - :maxdepth: 1 - :caption: 安全和隐私 - - local_differential_privacy_training_noise - local_differential_privacy_training_signds - local_differential_privacy_eval_laplace - pairwise_encryption_training - private_set_intersection - secure_vertical_federated_learning_with_EmbeddingDP - secure_vertical_federated_learning_with_TEE - secure_vertical_federated_learning_with_DP - -.. toctree:: - :maxdepth: 1 - :caption: 通信压缩 - - communication_compression - vfl_communication_compress - -.. toctree:: - :maxdepth: 1 - :caption: 横向联邦API参考 - - horizontal_server - cross_device - horizontal/cross_silo - -.. toctree:: - :maxdepth: 1 - :caption: 纵向联邦API参考 - - Data_Join - vertical/vertical_communicator - vertical_federated_trainer - -.. toctree:: - :maxdepth: 1 - :caption: 参考文档 - - faq - -.. 
toctree:: - :glob: - :maxdepth: 1 - :caption: RELEASE NOTES - - RELEASE diff --git a/docs/federated/docs/source_zh_cn/interface_description_federated_client.md b/docs/federated/docs/source_zh_cn/interface_description_federated_client.md deleted file mode 100644 index cf716ea2c0aed7cd5d7b7ee13deda38ccb772810..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/interface_description_federated_client.md +++ /dev/null @@ -1,350 +0,0 @@ -# 使用示例 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/interface_description_federated_client.md) - -注意,在使用以下接口前,可先参照文档[横向端侧部署](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_client.html)进行相关环境的部署。 - -## 联邦学习启动接口flJobRun() - -调用flJobRun()接口前,需先实例化参数类FLParameter,进行相关参数设置, 相关参数如下: - -| 参数名称 | 参数类型 | 是否必须 | 描述信息 | 备注 | -| -------------------- | -------------------------- | -------- | ------------------------------------------------------------ | ------------------------------------------------------------ | -| dataMap | Map/> | Y | 联邦学习数据集路径 | Map/>类型的数据集,map中key为RunType枚举类型,value为对应的数据集列表,key为RunType.TRAINMODE时代表对应的value为训练相关的数据集列表,key为RunType.EVALMODE时代表对应的value为验证相关的数据集列表,key为RunType.INFERMODE时代表对应的value为推理相关的数据集列表。 | -| flName | String | Y | 联邦学习使用的模型脚本包路径 | 我们提供了两个类型的模型脚本供大家参考([有监督情感分类任务](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert))、[LeNet图片分类任务](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet)),对于有监督情感分类任务,该参数可设置为所提供的脚本文件[AlBertClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/AlbertClient.java) 的包路径`com.mindspore.flclient.demo.albert.AlbertClient`;对于LeNet图片分类任务,该参数可设置为所提供的脚本文件[LenetClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java) 的包路径`com.mindspore.flclient.demo.lenet.LenetClient`。同时,用户可参考这两个类型的模型脚本,自定义模型脚本,然后将该参数设置为自定义的模型文件ModelClient.java(需继承于类[Client.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/model/Client.java))的包路径即可。 | -| trainModelPath | String | Y | 联邦学习使用的训练模型路径,为.ms文件的绝对路径 | 建议将路径设置到训练App自身目录下,保护模型本身的数据访问安全性。 | -| inferModelPath | String | Y | 联邦学习使用的推理模型路径,为.ms文件的绝对路径 | 对于普通联邦学习模式(训练和推理使用同一个模型),该参数需设置为与trainModelPath相同;对于混合学习模式(训练和推理使用不同的模型,且云侧也包含训练过程),该参数设置为实际的推理模型路径。建议将路径设置到训练App自身目录下,保护模型本身的数据访问安全性。 | -| sslProtocol | String | N | 端云HTTPS通信所使用的TLS协议版本 | 设置了白名单,目前只支持"TLSv1.3"或者"TLSv1.2"。非必须设置,默认值为"TLSv1.2"。只在HTTPS通信场景中使用。 | -| deployEnv | String | Y | 联邦学习的部署环境 | 设置了白名单,目前只支持"x86", "android"。 | -| certPath | String | N | 端云https通信所使用的自签名根证书路径 | 当部署环境为"x86",且端云采用自签名证书进行https通信认证时,需要设置该参数,该证书需与生成云侧自签名证书所使用的CA根证书一致才能验证通过,此参数用于非Android场景。 | -| domainName | String | Y | 端云通信url | 目前,https和http通信均支持,对应格式分别为:https://......、http://......,当`useElb`设置为true时,格式必须为:https://127.0.0.0:6666 或者http://127.0.0.0:6666 ,其中`127.0.0.0`对应提供云侧服务的机器ip(即云侧参数`--scheduler_ip`),`6666`对应云侧参数`--fl_server_port`。 | -| ifUseElb | boolean | N | 用于多server场景设置是否将客户端的请求随机发送给一定范围内的不同server | 设置为true代表客户端会将请求随机发给一定范围内的server地址,false代表客户端的请求会发给固定的server地址,此参数用于非Android场景,默认值为false。 | -| serverNum | int | N | 客户端可选择连接的server数量 | 
当ifUseElb设置为true时,可设置为与云侧启动server端时的`server_num`参数保持一致,用于随机选择不同的server发送信息,此参数用于非Android场景,默认值为1。 | -| ifPkiVerify | boolean | N | 端云身份认证开关 | 设置为true代表开启端云安全认证,设置为false代表不开启,默认值为false。身份认证需要HUKS提供证书,该参数只在Android环境中使用(目前只支持华为手机)。 | -| threadNum | int | N | 联邦学习训练和推理时使用的线程数 | 默认值为1 | -| cpuBindMode | BindMode | N | 联邦学习训练和推理时线程所需绑定的cpu内核 | BindMode枚举类型,其中BindMode.NOT_BINDING_CORE代表不绑定内核,由系统自动分配,BindMode.BIND_LARGE_CORE代表绑定大核,BindMode.BIND_MIDDLE_CORE代表绑定中核。默认值为BindMode.NOT_BINDING_CORE。 | -| batchSize | int | Y | 联邦学习训练和推理时使用的单步训练样本数,即batch size | 需与模型的输入数据的batch size保持一致。 | -| iflJobResultCallback | IFLJobResultCallback | N | 联邦学习回调函数对象 | 用户可根据实际场景所需,实现工程中接口类[IFLJobResultCallback.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/IFLJobResultCallback.java)的具体方法后,作为回调函数对象设置到联邦学习任务中。我们提供了一个简单的实现用例[FLJobResultCallback.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/IFLJobResultCallback.java)作为该参数默认值。 | - -注意1,当使用http通信时,可能会存在通信安全风险,请知悉。 - -注意2,在Android环境中,进行https通信时还需对以下参数进行设置,设置示例如下: - -```java -FLParameter flParameter = FLParameter.getInstance(); -SecureSSLSocketFactory sslSocketFactory = SecureSSLSocketFactory.getInstance(applicationContext) -SecureX509TrustManager x509TrustManager = new SecureX509TrustManager(applicationContext); -flParameter.setSslSocketFactory(sslSocketFactory); -flParameter.setX509TrustManager(x509TrustManager); -``` - -其中 `SecureSSLSocketFactory` 、`SecureX509TrustManager` 两个对象需在Android工程中实现,需要用户根据手机中证书种类自行进行设计。 - -注意3,在x86环境中,进行https通信时,目前只支持自签名证书认证,还需对以下参数进行设置,设置示例如下: - -```java -FLParameter flParameter = FLParameter.getInstance(); -String certPath = "CARoot.pem"; // 端云https通信所使用的自签名根证书路径 -flParameter.setCertPath(certPath); -``` - -注意4,在Android环境中, 当pkiVerify设置为true且云侧设置encrypt_train_type=PW_ENCRYPT时,还需要对以下参数进行设置,设置示例如下: - -```java -FLParameter flParameter = FLParameter.getInstance(); -String equipCrlPath = certPath; -long validIterInterval = 3600000; -flParameter.setEquipCrlPath(equipCrlPath); -flParameter.setValidInterval(validIterInterval); -``` - -其中`equipCrlPath`是设备之间证书校验需要的CRL证书,即证书吊销列表,一般可以预置 "Huawei CBG Certificate Revocation Lists" 中的设备证书CRL;`validIterInterval`一般可以设置为每轮端云聚合需要的时间(单位:毫秒,默认值为3600000),在PW_ENCRYPT模式下用来辅助防范重放攻击。 - -注意5,每次联邦学习任务启动前,会实例化类FLParameter进行相关参数设置。而实例化FLParameter时会自动随机生成一个clientID,用于与云侧交互过程中唯一标识该客户端,若用户需要自行设置clientID,可在实例化类FLParameter之后,调用其setClientID方法进行设置,则接着启动联邦学习任务后会使用用户设置的clientID。 - -创建SyncFLJob对象,并通过SyncFLJob类的flJobRun()方法启动同步联邦学习任务。 - -示例代码(基本http通信)如下: - -1. 
有监督情感分类任务示例代码
-
-    ```java
-    // 构造dataMap
-    String trainTxtPath = "data/albert/supervise/client/1.txt";
-    String evalTxtPath = "data/albert/supervise/eval/eval.txt";   // 非必须,getModel之后不进行验证可不设置
-    String vocabFile = "data/albert/supervise/vocab.txt";         // 数据预处理的词典文件路径
-    String idsFile = "data/albert/supervise/vocab_map_ids.txt";   // 词典的映射id文件路径
-    Map<RunType, List<String>> dataMap = new HashMap<>();
-    List<String> trainPath = new ArrayList<>();
-    trainPath.add(trainTxtPath);
-    trainPath.add(vocabFile);
-    trainPath.add(idsFile);
-    List<String> evalPath = new ArrayList<>();    // 非必须,getModel之后不进行验证可不设置
-    evalPath.add(evalTxtPath);    // 非必须,getModel之后不进行验证可不设置
-    evalPath.add(vocabFile);    // 非必须,getModel之后不进行验证可不设置
-    evalPath.add(idsFile);    // 非必须,getModel之后不进行验证可不设置
-    dataMap.put(RunType.TRAINMODE, trainPath);
-    dataMap.put(RunType.EVALMODE, evalPath);    // 非必须,getModel之后不进行验证可不设置
-
-    String flName = "com.mindspore.flclient.demo.albert.AlbertClient";    // AlBertClient.java 包路径
-    String trainModelPath = "ms/albert/train/albert_ad_train.mindir0.ms";    // 绝对路径
-    String inferModelPath = "ms/albert/train/albert_ad_train.mindir0.ms";    // 绝对路径,和trainModelPath保持一致
-    String sslProtocol = "TLSv1.2";
-    String deployEnv = "android";
-    String domainName = "http://10.*.*.*:6668";
-    boolean ifUseElb = true;
-    int serverNum = 4;
-    int threadNum = 4;
-    BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
-    int batchSize = 32;
-
-    FLParameter flParameter = FLParameter.getInstance();
-    flParameter.setFlName(flName);
-    flParameter.setDataMap(dataMap);
-    flParameter.setTrainModelPath(trainModelPath);
-    flParameter.setInferModelPath(inferModelPath);
-    flParameter.setSslProtocol(sslProtocol);
-    flParameter.setDeployEnv(deployEnv);
-    flParameter.setDomainName(domainName);
-    flParameter.setUseElb(ifUseElb);
-    flParameter.setServerNum(serverNum);
-    flParameter.setThreadNum(threadNum);
-    flParameter.setCpuBindMode(cpuBindMode);
-    flParameter.setBatchSize(batchSize);
-
-    // start FLJob
-    SyncFLJob syncFLJob = new SyncFLJob();
-    syncFLJob.flJobRun();
-    ```
-
-2. 
LeNet图片分类任务示例代码
-
-    ```java
-    // 构造dataMap
-    String trainImagePath = "SyncFLClient/data/3500_clients_bin/f0178_39/f0178_39_bn_9_train_data.bin";
-    String trainLabelPath = "SyncFLClient/data/3500_clients_bin/f0178_39/f0178_39_bn_9_train_label.bin";
-    String evalImagePath = "SyncFLClient/data/3500_clients_bin/f0178_39/f0178_39_bn_1_test_data.bin";    // 非必须,getModel之后不进行验证可不设置
-    String evalLabelPath = "SyncFLClient/data/3500_clients_bin/f0178_39/f0178_39_bn_1_test_label.bin";    // 非必须,getModel之后不进行验证可不设置
-    Map<RunType, List<String>> dataMap = new HashMap<>();
-    List<String> trainPath = new ArrayList<>();
-    trainPath.add(trainImagePath);
-    trainPath.add(trainLabelPath);
-    List<String> evalPath = new ArrayList<>();    // 非必须,getModel之后不进行验证可不设置
-    evalPath.add(evalImagePath);    // 非必须,getModel之后不进行验证可不设置
-    evalPath.add(evalLabelPath);    // 非必须,getModel之后不进行验证可不设置
-    dataMap.put(RunType.TRAINMODE, trainPath);
-    dataMap.put(RunType.EVALMODE, evalPath);    // 非必须,getModel之后不进行验证可不设置
-
-    String flName = "com.mindspore.flclient.demo.lenet.LenetClient";    // LenetClient.java 包路径
-    String trainModelPath = "SyncFLClient/lenet_train.mindir0.ms";    // 绝对路径
-    String inferModelPath = "SyncFLClient/lenet_train.mindir0.ms";    // 绝对路径
-    String sslProtocol = "TLSv1.2";
-    String deployEnv = "android";
-    String domainName = "http://10.*.*.*:6668";
-    boolean ifUseElb = true;
-    int serverNum = 4;
-    int threadNum = 4;
-    BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
-    int batchSize = 32;
-
-    FLParameter flParameter = FLParameter.getInstance();
-    flParameter.setFlName(flName);
-    flParameter.setDataMap(dataMap);
-    flParameter.setTrainModelPath(trainModelPath);
-    flParameter.setInferModelPath(inferModelPath);
-    flParameter.setSslProtocol(sslProtocol);
-    flParameter.setDeployEnv(deployEnv);
-    flParameter.setDomainName(domainName);
-    flParameter.setUseElb(ifUseElb);
-    flParameter.setServerNum(serverNum);
-    flParameter.setThreadNum(threadNum);
-    flParameter.setCpuBindMode(cpuBindMode);
-    flParameter.setBatchSize(batchSize);
-
-    // start FLJob
-    SyncFLJob syncFLJob = new SyncFLJob();
-    syncFLJob.flJobRun();
-    ```
-
-## 多条数据输入推理接口modelInference()
-
-调用modelInference()接口前,需先实例化参数类FLParameter,进行相关参数设置,相关参数如下:
-
-| 参数名称       | 参数类型                   | 是否必须 | 描述信息                                               | 备注                                                         |
-| -------------- | -------------------------- | -------- | ------------------------------------------------------ | ------------------------------------------------------------ |
-| flName | String | Y | 联邦学习使用的模型脚本包路径 | 我们提供了两个类型的模型脚本供大家参考([有监督情感分类任务](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert)、[LeNet图片分类任务](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet)),对于有监督情感分类任务,该参数可设置为所提供的脚本文件[AlBertClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/AlbertClient.java) 的包路径`com.mindspore.flclient.demo.albert.AlbertClient`;对于LeNet图片分类任务,该参数可设置为所提供的脚本文件[LenetClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java) 的包路径`com.mindspore.flclient.demo.lenet.LenetClient`。同时,用户可参考这两个类型的模型脚本,自定义模型脚本,然后将该参数设置为自定义的模型文件ModelClient.java(需继承于类[Client.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/model/Client.java))的包路径即可。 |
-| dataMap | Map<RunType, List<String>> | Y | 联邦学习数据集路径 | Map<RunType, List<String>>类型的数据集,map中key为RunType枚举类型,value为对应的数据集列表,key为RunType.TRAINMODE时代表对应的value为训练相关的数据集列表,key为RunType.EVALMODE时代表对应的value为验证相关的数据集列表,key为RunType.INFERMODE时代表对应的value为推理相关的数据集列表。 |
-| inferModelPath | String | Y | 联邦学习推理模型路径,为.ms文件的绝对路径 | 对于普通联邦学习模式(训练和推理使用同一个模型),该参数需设置为与trainModelPath相同;对于混合学习模式(训练和推理使用不同的模型,且云侧也包含训练过程),该参数设置为实际的推理模型路径。建议将路径设置到训练App自身目录下,保护模型本身的数据访问安全性。 |
-| threadNum | int | N | 联邦学习训练和推理时使用的线程数 | 默认值为1 |
-| cpuBindMode | BindMode | N | 联邦学习训练和推理时线程所需绑定的cpu内核 | BindMode枚举类型,其中BindMode.NOT_BINDING_CORE代表不绑定内核,由系统自动分配,BindMode.BIND_LARGE_CORE代表绑定大核,BindMode.BIND_MIDDLE_CORE代表绑定中核。默认值为BindMode.NOT_BINDING_CORE。 |
-| batchSize | int | Y | 联邦学习训练和推理时使用的单步训练样本数,即batch size | 需与模型的输入数据的batch size保持一致。 |
-
-创建SyncFLJob对象,并通过SyncFLJob类的modelInference()方法启动端侧推理任务,返回推理的标签数组。
-
-示例代码如下:
-
-1. 有监督情感分类任务示例代码
-
-    ```java
-    // 构造dataMap
-    String inferTxtPath = "data/albert/supervise/eval/eval.txt";
-    String vocabFile = "data/albert/supervise/vocab.txt";
-    String idsFile = "data/albert/supervise/vocab_map_ids.txt";
-    Map<RunType, List<String>> dataMap = new HashMap<>();
-    List<String> inferPath = new ArrayList<>();
-    inferPath.add(inferTxtPath);
-    inferPath.add(vocabFile);
-    inferPath.add(idsFile);
-    dataMap.put(RunType.INFERMODE, inferPath);
-
-    String flName = "com.mindspore.flclient.demo.albert.AlbertClient";    // AlBertClient.java 包路径
-    String inferModelPath = "ms/albert/train/albert_ad_train.mindir0.ms";    // 绝对路径,和trainModelPath保持一致
-    int threadNum = 4;
-    BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
-    int batchSize = 32;
-
-    FLParameter flParameter = FLParameter.getInstance();
-    flParameter.setFlName(flName);
-    flParameter.setDataMap(dataMap);
-    flParameter.setInferModelPath(inferModelPath);
-    flParameter.setThreadNum(threadNum);
-    flParameter.setCpuBindMode(cpuBindMode);
-    flParameter.setBatchSize(batchSize);
-
-    // inference
-    SyncFLJob syncFLJob = new SyncFLJob();
-    int[] labels = syncFLJob.modelInference();
-    ```
-
-2. 
LeNet图片分类示例代码 - - ```java - // 构造dataMap - String inferImagePath = "SyncFLClient/data/3500_clients_bin/f0178_39/f0178_39_bn_1_test_data.bin"; - String inferLabelPath = "SyncFLClient/data/3500_clients_bin/f0178_39/f0178_39_bn_1_test_label.bin"; - Map> dataMap = new HashMap<>(); - List inferPath = new ArrayList<>(); - inferPath.add(inferImagePath); - inferPath.add(inferLabelPath); - dataMap.put(RunType.INFERMODE, inferPath); - - String flName = "com.mindspore.flclient.demo.lenet.LenetClient"; // LenetClient.java 包路径 - String inferModelPath = "SyncFLClient/lenet_train.mindir0.ms"; - int threadNum = 4; - BindMode cpuBindMode = BindMode.NOT_BINDING_CORE; - int batchSize = 32; - - FLParameter flParameter = FLParameter.getInstance(); - flParameter.setFlName(flName); - flParameter.setDataMap(dataMap); - flParameter.setInferModelPath(inferModelPath); - flParameter.setThreadNum(threadNum); - flParameter.setCpuBindMode(BindMode.valueOf(cpuBindMode)); - flParameter.setBatchSize(batchSize); - - // inference - SyncFLJob syncFLJob = new SyncFLJob(); - int[] labels = syncFLJob.modelInference(); - ``` - -## 获取云侧最新模型接口getModel () - -调用getModel()接口前,需先实例化参数类FLParameter,进行相关参数设置,相关参数如下: - -| 参数名称 | 参数类型 | 是否必须 | 描述信息 | 备注 | -| -------------- | --------- | -------- | ------------------------------------------------------------ | ------------------------------------------------------------ | -| flName | String | Y | 联邦学习使用的模型脚本包路径 | 我们提供了两个类型的模型脚本供大家参考([有监督情感分类任务](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert)、[LeNet图片分类任务](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet)),对于有监督情感分类任务,该参数可设置为所提供的脚本文件[AlBertClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/AlbertClient.java) 的包路径`com.mindspore.flclient.demo.albert.AlbertClient`;对于LeNet图片分类任务,该参数可设置为所提供的脚本文件[LenetClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java) 的包路径`com.mindspore.flclient.demo.lenet.LenetClient`。同时,用户可参考这两个类型的模型脚本,自定义模型脚本,然后将该参数设置为自定义的模型文件ModelClient.java(需继承于类[Client.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/model/Client.java))的包路径即可。 | -| trainModelPath | String | Y | 联邦学习使用的训练模型路径,为.ms文件的绝对路径 | 建议将路径设置到训练App自身目录下,保护模型本身的数据访问安全性。 | -| inferModelPath | String | Y | 联邦学习推理模型路径,为.ms文件的绝对路径 | 对于普通联邦学习模式(训练和推理使用同一个模型),该参数需设置为与trainModelPath相同;对于混合学习模式(训练和推理使用不同的模型,且云侧也包含训练过程),该参数设置为实际的推理模型路径。建议将路径设置到训练App自身目录下,保护模型本身的数据访问安全性。 | -| sslProtocol | String | N | 端云HTTPS通信所使用的TLS协议版本 | 设置了白名单,目前只支持"TLSv1.3"或者"TLSv1.2"。非必须设置,默认值为"TLSv1.2"。只在HTTPS通信场景中使用。 | -| deployEnv | String | Y | 联邦学习的部署环境 | 设置了白名单,目前只支持"x86", "android"。 | -| certPath | String | N | 端云https通信所使用的自签名根证书路径 | 当部署环境为"x86",且端云采用自签名证书进行https通信认证时,需要设置该参数,该证书需与生成云侧自签名证书所使用的CA根证书一致才能验证通过,此参数用于非Android场景。 | -| domainName | String | Y | 端云通信url | 目前,https和http通信均支持,对应格式分别为:https://......、http://......,当`useElb`设置为true时,格式必须为:https://127.0.0.0:6666 或者http://127.0.0.0:6666 ,其中`127.0.0.0`对应提供云侧服务的机器ip(即云侧参数`--scheduler_ip`),`6666`对应云侧参数`--fl_server_port`。 | -| ifUseElb | boolean | N | 用于多server场景设置是否将客户端的请求随机发送给一定范围内的不同server | 设置为true代表客户端会将请求随机发给一定范围内的server地址,false代表客户端的请求会发给固定的server地址,此参数用于非Android场景,默认值为false。 | -| serverNum | int | N | 客户端可选择连接的server数量 
| 当ifUseElb设置为true时,可设置为与云侧启动server端时的`server_num`参数保持一致,用于随机选择不同的server发送信息,此参数用于非Android场景,默认值为1。 |
-| serverMod | ServerMod | Y | 联邦学习训练模式。 | ServerMod枚举类型的联邦学习训练模式,其中ServerMod.FEDERATED_LEARNING代表普通联邦学习模式(训练和推理使用同一个模型);ServerMod.HYBRID_TRAINING代表混合学习模式(训练和推理使用不同的模型,且云侧也包含训练过程)。 |
-
-注意1,当使用http通信时,可能会存在通信安全风险,请知悉。
-
-注意2,在Android环境中,进行https通信时还需对以下参数进行设置,设置示例如下:
-
-```java
-FLParameter flParameter = FLParameter.getInstance();
-SecureSSLSocketFactory sslSocketFactory = SecureSSLSocketFactory.getInstance(applicationContext);
-SecureX509TrustManager x509TrustManager = new SecureX509TrustManager(applicationContext);
-flParameter.setSslSocketFactory(sslSocketFactory);
-flParameter.setX509TrustManager(x509TrustManager);
-```
-
-其中 `SecureSSLSocketFactory` 、`SecureX509TrustManager` 两个对象需在Android工程中实现,需要用户根据手机中证书种类自行进行设计。
-
-注意3,在X86环境中,进行https通信时,目前只支持自签名证书认证,还需对以下参数进行设置,设置示例如下:
-
-```java
-FLParameter flParameter = FLParameter.getInstance();
-String certPath = "CARoot.pem"; // 端云https通信所使用的自签名根证书路径
-flParameter.setCertPath(certPath);
-```
-
-注意4,在调用getModel方法前,会实例化类FLParameter进行相关参数设置。而实例化FLParameter时会自动随机生成一个clientID,用于与云侧交互过程中唯一标识该客户端,若用户需要自行设置clientID,可在实例化类FLParameter之后,调用其setClientID方法进行设置,则接着启动getModel任务后会使用用户设置的clientID。
-
-创建SyncFLJob对象,并通过SyncFLJob类的getModel()方法启动获取云侧最新模型任务,返回getModel请求状态码。
-
-示例代码如下:
-
-1. 有监督情感分类任务版本
-
-    ```java
-    String flName = "com.mindspore.flclient.demo.albert.AlbertClient";    // AlBertClient.java 包路径
-    String trainModelPath = "ms/albert/train/albert_ad_train.mindir0.ms";    // 绝对路径
-    String inferModelPath = "ms/albert/train/albert_ad_train.mindir0.ms";    // 绝对路径,和trainModelPath保持一致
-    String sslProtocol = "TLSv1.2";
-    String deployEnv = "android";
-    String domainName = "http://10.*.*.*:6668";
-    boolean ifUseElb = true;
-    int serverNum = 4;
-    ServerMod serverMod = ServerMod.FEDERATED_LEARNING;
-
-    FLParameter flParameter = FLParameter.getInstance();
-    flParameter.setFlName(flName);
-    flParameter.setTrainModelPath(trainModelPath);
-    flParameter.setInferModelPath(inferModelPath);
-    flParameter.setSslProtocol(sslProtocol);
-    flParameter.setDeployEnv(deployEnv);
-    flParameter.setDomainName(domainName);
-    flParameter.setUseElb(ifUseElb);
-    flParameter.setServerNum(serverNum);
-    flParameter.setServerMod(serverMod);
-
-    // getModel
-    SyncFLJob syncFLJob = new SyncFLJob();
-    syncFLJob.getModel();
-    ```
-
-2. 
LeNet图片分类任务版本 - - ```java - String flName = "com.mindspore.flclient.demo.lenet.LenetClient"; // LenetClient.java 包路径 - String trainModelPath = "SyncFLClient/lenet_train.mindir0.ms"; //绝对路径 - String inferModelPath = "SyncFLClient/lenet_train.mindir0.ms"; //绝对路径,和trainModelPath保持一致 - String sslProtocol = "TLSv1.2"; - String deployEnv = "android"; - String domainName = "http://10.*.*.*:6668"; - boolean ifUseElb = true; - int serverNum = 4; - ServerMod serverMod = ServerMod.FEDERATED_LEARNING; - - FLParameter flParameter = FLParameter.getInstance(); - flParameter.setFlName(flName); - flParameter.setTrainModelPath(trainModelPath); - flParameter.setInferModelPath(inferModelPath); - flParameter.setSslProtocol(sslProtocol); - flParameter.setDeployEnv(deployEnv); - flParameter.setDomainName(domainName); - flParameter.setUseElb(useElb); - flParameter.setServerNum(serverNum); - flParameter.setServerMod(ServerMod.valueOf(serverMod)); - - // getModel - SyncFLJob syncFLJob = new SyncFLJob(); - syncFLJob.getModel(); - ``` diff --git a/docs/federated/docs/source_zh_cn/java_api_callback.md b/docs/federated/docs/source_zh_cn/java_api_callback.md deleted file mode 100644 index eb4ffcb3928f80a85bef13fb96e38646a04605cf..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/java_api_callback.md +++ /dev/null @@ -1,66 +0,0 @@ -# Callback - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/java_api_callback.md) - -```java -import com.mindspore.flclient.model.Callback -``` - -Callback定义了端侧联邦学习中用于记录训练、评估和预测不同阶段结果的钩子函数。 - -## 公有成员函数 - -| function | -| -------------------------------- | -| [abstract Status stepBegin()](#stepbegin) | -| [abstract Status stepEnd()](#stepend) | -| [abstract Status epochBegin()](#epochbegin) | -| [abstract Status epochEnd()](#epochend) | - -## stepBegin - -```java - public abstract Status stepBegin() -``` - -单步执行前处理函数。 - -- 返回值 - - 前处理执行结果状态。 - -## stepEnd - -```java -public abstract Status stepEnd() -``` - -单步执行后处理函数。 - -- 返回值 - - 后处理执行结果状态。 - -## epochBegin - -```java -public abstract Status epochBegin() -``` - -epoch执行前处理函数。 - -- 返回值 - - 前处理执行结果状态。 - -## epochEnd - -```java -public abstract Status epochEnd() -``` - -epoch执行后处理函数。 - -- 返回值 - - 前处理执行结果状态。 diff --git a/docs/federated/docs/source_zh_cn/java_api_client.md b/docs/federated/docs/source_zh_cn/java_api_client.md deleted file mode 100644 index 730960fa29c7a246cba71e5910f1a0c7b9b9907a..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/java_api_client.md +++ /dev/null @@ -1,173 +0,0 @@ -# Client - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/java_api_client.md) - -```java -import com.mindspore.flclient.model.Client -``` - -Client定义了端侧联邦学习算法执行流程对象。 - -## 公有成员函数 - -| function | -| -------------------------------- | -| [abstract List initCallbacks(RunType runType, DataSet dataSet)](#initcallbacks) | -| [abstract Map initDataSets(Map\> files)](#initdatasets) | -| [abstract float getEvalAccuracy(List evalCallbacks)](#getevalaccuracy) | -| [abstract List getInferResult(List inferCallbacks)](#getinferresult) | -| [Status trainModel(int epochs)](#trainmodel) | -| [float evalModel()](#evalmodel) | -| [Map genUnsupervisedEvalData(List evalCallbacks)](#genunsupervisedevaldata) 
| -| [List inferModel()](#infermodel) | -| [Status setLearningRate(float lr)](#setlearningrate) | -| [void setBatchSize(int batchSize)](#setbatchsize) | - -## initCallbacks - -```java -public abstract List initCallbacks(RunType runType, DataSet dataSet) -``` - -初始化callback列表。 - -- 参数 - - - `runType`: RunType类,标识训练、评估还是预测阶段。 - - `dataSet`: DataSet类,训练、评估还是预测阶段数据集。 - -- 返回值 - - 初始化的callback列表。 - -## initDataSets - -```java -public abstract Map initDataSets(Map> files) -``` - -初始化dataset列表。 - -- 参数 - - - `files`: 训练、评估和预测阶段使用的数据文件。 - -- 返回值 - - 训练、评估和预测阶段数据集样本量。 - -## getEvalAccuracy - -```java -public abstract float getEvalAccuracy(List evalCallbacks) -``` - -获取评估阶段的精度。 - -- 参数 - - - `evalCallbacks`: 评估阶段使用的callback列表。 - -- 返回值 - - 评估阶段精度。 - -## getInferResult - -```java -public abstract List getInferResult(List inferCallbacks) -``` - -获取预测结果。 - -- 参数 - - - `inferCallbacks`: 预测阶段使用的callback列表。 - -- 返回值 - - 预测结果。 - -## trainModel - -```java -public Status trainModel(int epochs) -``` - -开启模型训练。 - -- 参数 - - - `epochs`: 训练的epoch数。 - -- 返回值 - - 模型训练结果。 - -## evalModel - -```java -public float evalModel() -``` - -执行模型评估过程。 - -- 返回值 - - 模型评估精度。 - -## genUnsupervisedEvalData - -```java -public Map genUnsupervisedEvalData(List evalCallbacks) -``` - -生成无监督训练评估数据,子类需要覆写该函数。 - -- 参数 - - - `evalCallbacks`: 推理回调类,该类生成数据。 - -- 返回值 - - 无监督训练评估数据。 - -## inferModel - -```java -public List inferModel() -``` - -执行模型预测过程。 - -- 返回值 - - 模型预测结果。 - -## setLearningRate - -```java -public Status setLearningRate(float lr) -``` - -设置学习率。 - -- 参数 - - - `lr`: 学习率。 - -- 返回值 - - 设置结果。 - -## setBatchSize - -```java -public void setBatchSize(int batchSize) -``` - -设置执行批次数。 - -- 参数 - - - `batchSize`: 批次数。 \ No newline at end of file diff --git a/docs/federated/docs/source_zh_cn/java_api_clientmanager.md b/docs/federated/docs/source_zh_cn/java_api_clientmanager.md deleted file mode 100644 index 5d1af142706522e091ba9bcd6a8bf34d71938d73..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/java_api_clientmanager.md +++ /dev/null @@ -1,44 +0,0 @@ -# ClientManager - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/java_api_clientmanager.md) - -```java -import com.mindspore.flclient.model.ClientManager -``` - -ClientManager定义了端侧联邦学习自定义算法模型注册接口。 - -## 公有成员函数 - -| function | -| -------------------------------- | -| [static void registerClient(Client client)](#registerclient) | -| [static Client getClient(String name)](#getclient) | - -## registerClient - -```java -public static void registerClient(Client client) -``` - -注册Client对象。 - -- 参数 - - - `client`: 需要注册的Client对象。 - -## getClient - -```java -public static Client getClient(String name) -``` - -获取Client对象。 - -- 参数 - - - `name`: Client对象名称。 - -- 返回值 - - Client对象。 diff --git a/docs/federated/docs/source_zh_cn/java_api_dataset.md b/docs/federated/docs/source_zh_cn/java_api_dataset.md deleted file mode 100644 index d397da7f43a56cfddef096e47fdc976a068a1ec6..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/java_api_dataset.md +++ /dev/null @@ -1,63 +0,0 @@ -# DataSet - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/java_api_dataset.md) - -```java -import com.mindspore.flclient.model.DataSet -``` - -DataSet定义了端侧联邦学习数据集对象。 
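-
-用户需继承该抽象类并实现其全部抽象方法。下面给出一个最小的自定义实现示意(非官方实现,仅用于理解各接口的职责):示例中的类名`DemoBinDataSet`、数据在内存中的组织方式、输入buffer的排布顺序,以及对父类`batchSize`、`batchNum`成员的赋值方式均为本文的假设,实际请以`DataSet`源码和自身数据格式为准。
-
-```java
-import com.mindspore.flclient.model.DataSet;
-import com.mindspore.flclient.model.Status;
-
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Random;
-
-public class DemoBinDataSet extends DataSet {
-    // 假设每条样本为定长float特征加一个int标签,仅作示意
-    private final List<float[]> features = new ArrayList<>();
-    private final List<Integer> labels = new ArrayList<>();
-
-    @Override
-    public Status dataPreprocess(List<String> files) {
-        // 此处应按自身数据格式解析files并填充features与labels,解析逻辑从略;
-        // 解析成功后设置父类的batchSize与batchNum(假设父类提供这两个成员)
-        batchSize = 32;    // 假设的batch size,需与模型输入保持一致
-        batchNum = features.size() / batchSize;
-        return Status.SUCCESS;
-    }
-
-    @Override
-    public void fillInputBuffer(List<ByteBuffer> inputsBuffer, int batchIdx) {
-        // 将第batchIdx个batch写入模型输入buffer,假设buffer 0为特征、buffer 1为标签
-        ByteBuffer featBuffer = inputsBuffer.get(0);
-        ByteBuffer labelBuffer = inputsBuffer.get(1);
-        featBuffer.clear();
-        labelBuffer.clear();
-        for (int i = batchIdx * batchSize; i < (batchIdx + 1) * batchSize; i++) {
-            for (float v : features.get(i)) {
-                featBuffer.putFloat(v);
-            }
-            labelBuffer.putInt(labels.get(i));
-        }
-    }
-
-    @Override
-    public void shuffle() {
-        // 使用同一随机种子分别打乱特征与标签,保证二者的对应关系不变
-        long seed = System.nanoTime();
-        Collections.shuffle(features, new Random(seed));
-        Collections.shuffle(labels, new Random(seed));
-    }
-
-    @Override
-    public void padding() {
-        // 样本数不足整数个batch时,循环复制已有样本补齐,仅作示意
-        int target = (features.size() + batchSize - 1) / batchSize * batchSize;
-        for (int i = 0; features.size() < target; i++) {
-            features.add(features.get(i));
-            labels.add(labels.get(i));
-        }
-    }
-}
-```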
- -## 公有成员函数 - -| function | -| -------------------------------- | -| [abstract void fillInputBuffer(List var1, int var2)](#fillinputbuffer) | -| [abstract void shuffle()](#shuffle) | -| [abstract void padding()](#padding) | -| [abstract Status dataPreprocess(List var1)](#datapreprocess) | - -## fillInputBuffer - -```java -public abstract void fillInputBuffer(List var1, int var2) -``` - -填充输入buffer数据。 - -- 参数 - - - `var1`: 需要填充的buffer内存。 - - `var2`: 需要填充的batch索引。 - -## shuffle - -```java - public abstract void shuffle() -``` - -打乱数据。 - -## padding - -```java - public abstract void padding() -``` - -补齐数据。 - -## dataPreprocess - -```java -public abstract Status dataPreprocess(List var1) -``` - -数据前处理。 - -- 参数 - - - `var1`: 使用的训练、评估或推理数据集。 - -- 返回值 - - 数据处理结果。 diff --git a/docs/federated/docs/source_zh_cn/java_api_flparameter.md b/docs/federated/docs/source_zh_cn/java_api_flparameter.md deleted file mode 100644 index ad6fbb09e68732d355f67fbc4a5428b9d1e38487..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/java_api_flparameter.md +++ /dev/null @@ -1,636 +0,0 @@ -# FLParameter - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/java_api_flparameter.md) - -```java -import com.mindspore.flclient.FLParameter -``` - -FLParameter定义联邦学习相关参数,供用户进行设置。 - -## 公有成员函数 - -| **function** | -| ------------------------------------------------------------ | -| public static synchronized FLParameter getInstance() | -| public String getDeployEnv() | -| public void setDeployEnv(String env) | -| public String getDomainName() | -| public void setDomainName(String domainName) | -| public String getClientID() | -| public void setClientID(String clientID) | -| public String getCertPath() | -| public void setCertPath(String certPath) | -| public SSLSocketFactory getSslSocketFactory() | -| public void setSslSocketFactory(SSLSocketFactory sslSocketFactory) | -| public X509TrustManager getX509TrustManager() | -| public void setX509TrustManager(X509TrustManager x509TrustManager) | -| public IFLJobResultCallback getIflJobResultCallback() | -| public void setIflJobResultCallback(IFLJobResultCallback iflJobResultCallback) | -| public String getFlName() | -| public void setFlName(String flName) | -| public String getTrainModelPath() | -| public void setTrainModelPath(String trainModelPath) | -| public String getInferModelPath() | -| public void setInferModelPath(String inferModelPath) | -| public String getSslProtocol() | -| public void setSslProtocol(String sslProtocol) | -| public int getTimeOut() | -| public void setTimeOut(int timeOut) | -| public int getSleepTime() | -| public void setSleepTime(int sleepTime) | -| public boolean isUseElb() | -| public void setUseElb(boolean useElb) | -| public int getServerNum() | -| public void setServerNum(int serverNum) | -| public boolean isPkiVerify() | -| public void setPkiVerify(boolean ifPkiVerify) | -| public String getEquipCrlPath() | -| public void setEquipCrlPath(String certPath) | -| public long getValidInterval() | -| public void setValidInterval(long validInterval) | -| public int getThreadNum() | -| public void setThreadNum(int threadNum) | -| public int getCpuBindMode() | -| public void setCpuBindMode(BindMode cpuBindMode) | -| public List getHybridWeightName(RunType runType) | -| public void setHybridWeightName(List hybridWeightName, RunType runType) | -| public Map/> getDataMap() | -| 
public void setDataMap(Map/> dataMap) | -| public ServerMod getServerMod() | -| public void setServerMod(ServerMod serverMod) | -| public int getBatchSize() | -| public void setBatchSize(int batchSize) | - -## getInstance - -```java -public static synchronized FLParameter getInstance() -``` - -获取FLParameter单例。 - -- 返回值 - - FLParameter类型的单例对象。 - -## getDeployEnv - -```java -public String getDeployEnv() -``` - -获取用户设置联邦学习的部署环境。 - -- 返回值 - - String类型的联邦学习的部署环境。 - -## setDeployEnv - -```java -public void setDeployEnv(String env) -``` - -用于设置联邦学习的部署环境,设置了白名单,目前只支持"x86"、"android"。 - -- 参数 - - - `env`: 联邦学习的部署环境。 - -## getDomainName - -```java -public String getDomainName() -``` - -获取用户设置的域名domainName。 - -- 返回值 - - String类型的域名。 - -## setDomainName - -```java -public void setDomainName(String domainName) -``` - -用于设置端云通信url,目前,可支持https和http通信,对应格式分别为:https://......、http://......,当`useElb`设置为true时,格式必须为:https://127.0.0.0:6666 或者http://127.0.0.0:6666 ,其中`127.0.0.0`对应提供云侧服务的机器ip(即云侧参数`--scheduler_ip`),`6666`对应云侧参数`--fl_server_port`。 - -- 参数 - - - `domainName`: 域名。 - -## getClientID - -```java -public String getClientID() -``` - -每次联邦学习任务启动前会自动生成一个唯一标识客户端的clientID(若用户需要自行设置clientID,可在启动联邦学习训练任务前使用setClientID进行设置),该方法用于获取该ID,可用于端云安全认证场景中生成相关证书。 - -- 返回值 - - String类型的唯一标识客户端的clientID。 - -## setClientID - -```java -public void setClientID(String clientID) -``` - -用于用户设置唯一标识客户端的clientID。 - -- 参数 - - - `clientID`: 唯一标识客户端的clientID。 - -## getCertPath - -```java -public String getCertPath() -``` - -获取用户设置的端云https通信所使用的自签名根证书路径certPath。 - -- 返回值 - - String类型的自签名根证书路径certPath。 - -## setCertPath - -```java -public void setCertPath(String certPath) -``` - -用于设置端云HTTPS通信所使用的自签名根证书路径certPath。当部署环境为"x86",且端云采用自签名证书进行https通信认证时,需要设置该参数,该证书需与生成云侧自签名证书所使用的CA根证书一致才能验证通过,此参数用于非Android场景。 - -- 参数 - - `certPath`: 端云https通信所使用的自签名根证书路径。 - -## getSslSocketFactory - -```java -public SSLSocketFactory getSslSocketFactory() -``` - -获取用户设置的ssl证书认证库sslSocketFactory。 - -- 返回值 - - SSLSocketFactory类型的ssl证书认证库sslSocketFactory。 - -## setSslSocketFactory - -```java -public void setSslSocketFactory(SSLSocketFactory sslSocketFactory) -``` - -用于设置ssl证书认证库sslSocketFactory。 - -- 参数 - - `sslSocketFactory`: ssl证书认证库。 - -## getX509TrustManager - -```java -public X509TrustManager getX509TrustManager() -``` - - 获取用户设置的ssl证书认证管理器x509TrustManager。 - -- 返回值 - - X509TrustManager类型的ssl证书认证管理器x509TrustManager。 - -## setX509TrustManager - -```java -public void setX509TrustManager(X509TrustManager x509TrustManager) -``` - -用于设置ssl证书认证管理器x509TrustManager。 - -- 参数 - - `x509TrustManager`:ssl证书认证管理器。 - -## getIflJobResultCallback - -```java -public IFLJobResultCallback getIflJobResultCallback() -``` - - 获取用户设置的联邦学习回调函数对象iflJobResultCallback。 - -- 返回值 - - IFLJobResultCallback类型的联邦学习回调函数对象iflJobResultCallback。 - -## setIflJobResultCallback - -```java -public void setIflJobResultCallback(IFLJobResultCallback iflJobResultCallback) -``` - -用于设置联邦学习回调函数对象iflJobResultCallback,用户可根据实际场景所需,实现工程中接口类[IFLJobResultCallback.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/IFLJobResultCallback.java)的具体方法后,作为回调函数对象设置到联邦学习任务中。 - -- 参数 - - `iflJobResultCallback`:联邦学习回调函数。 - -## getFlName - -```java -public String getFlName() -``` - -用于获取用户设置的模型脚本包路径。 - -- 返回值 - - String类型的模型脚本包路径。 - -## setFlName - -```java -public void setFlName(String flName) -``` - 
-设置模型脚本包路径。我们提供了两个类型的模型脚本供大家参考([有监督情感分类任务](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert)、[LeNet图片分类任务](https://gitee.com/mindspore/federated/tree/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet)),对于有监督情感分类任务,该参数可设置为所提供的脚本文件[AlBertClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/albert/AlbertClient.java) 的包路径`com.mindspore.flclient.demo.albert.AlbertClient`;对于LeNet图片分类任务,该参数可设置为所提供的脚本文件[LenetClient.java](https://gitee.com/mindspore/federated/blob/master/example/quick_start_flclient/src/main/java/com/mindspore/flclient/demo/lenet/LenetClient.java) 的包路径`com.mindspore.flclient.demo.lenet.LenetClient`。同时,用户可参考这两个类型的模型脚本,自定义模型脚本,然后将该参数设置为自定义的模型文件ModelClient.java(需继承于类[Client.java](https://gitee.com/mindspore/federated/blob/master/mindspore_federated/device_client/src/main/java/com/mindspore/flclient/model/Client.java))的包路径即可。 - -- 参数 - - `flName`: 模型脚本包路径。 - -## getTrainModelPath - -```java -public String getTrainModelPath() -``` - -用于获取用户设置的训练模型路径trainModelPath。 - -- 返回值 - - String类型的训练模型路径trainModelPath。 - -## setTrainModelPath - -```java -public void setTrainModelPath(String trainModelPath) -``` - -设置训练模型路径trainModelPath。 - -- 参数 - - `trainModelPath`: 训练模型路径。 - -## getInferModelPath - -```java -public String getInferModelPath() -``` - -用于获取用户设置的推理模型路径inferModelPath。 - -- 返回值 - - String类型的推理模型路径inferModelPath。 - -## setInferModelPath - -```java -public void setInferModelPath(String inferModelPath) -``` - -设置推理模型路径inferModelPath。 - -- 参数 - - `inferModelPath`: 推理模型路径。 - -## getSslProtocol - -```java -public String getSslProtocol() -``` - -用于获取用户设置的端云HTTPS通信所使用的TLS协议版本。 - -- 返回值 - - String类型的端云HTTPS通信所使用的TLS协议版本。 - -## setSslProtocol - -```java -public void setSslProtocol(String sslProtocol) -``` - -用于设置端云HTTPS通信所使用的TLS协议版本,设置了白名单,目前只支持"TLSv1.3"或者"TLSv1.2"。只在HTTPS通信场景中使用。 - -- 参数 - - `sslProtocol`: 端云HTTPS通信所使用的TLS协议版本。 - -## getTimeOut - -```java -public int getTimeOut() -``` - -用于获取用户设置的端侧通信的超时时间timeOut。 - -- 返回值 - - int类型的端侧通信的超时时间timeOut。 - -## setTimeOut - -```java -public void setTimeOut(int timeOut) -``` - -用于设置端侧通信的超时时间timeOut。 - -- 参数 - - `timeOut`: 端侧通信的超时时间。 - -## getSleepTime - -```java -public int getSleepTime() -``` - -用于获取用户设置的重复请求的等待时间sleepTime。 - -- 返回值 - - int类型的重复请求的等待时间sleepTime。 - -## setSleepTime - -```java -public void setSleepTime(int sleepTime) -``` - -用于设置重复请求的等待时间sleepTime。 - -- 参数 - - `sleepTime`: 重复请求的等待时间。 - -## isUseElb - -```java -public boolean isUseElb() -``` - -是否模拟弹性负载均衡,即客户端将请求随机发给一定范围内的server地址。 - -- 返回值 - - boolean类型,true代表客户端会将请求随机发给一定范围内的server地址, false客户端的请求会发给固定的server地址。 - -## setUseElb - -```java -public void setUseElb(boolean useElb) -``` - -用于设置是否模拟弹性负载均衡,即客户端将请求随机发给一定范围内的server地址。 - -- 参数 - - `useElb`: 是否模拟弹性负载均衡,默认为false。 - -## getServerNum - -```java -public int getServerNum() -``` - -用于获取用户设置的模拟弹性负载均衡时可发送请求的server数量。 - -- 返回值 - - int类型的模拟弹性负载均衡时可发送请求的server数量。 - -## setServerNum - -```java -public void setServerNum(int serverNum) -``` - -用于设置模拟弹性负载均衡时可发送请求的server数量。 - -- 参数 - - `serverNum`: 模拟弹性负载均衡时可发送请求的server数量,默认为1。 - -## isPkiVerify - -```java -public boolean isPkiVerify() -``` - -是否进行端云认证。 - -- 返回值 - - boolean类型,true代表进行端云认证,false代表不进行端云认证。 - -## setPkiVerify - -```java -public void setPkiVerify(boolean pkiVerify) -``` - -用于设置是否进行端云认证。 - -- 参数 - - - `pkiVerify`: 是否进行端云认证。 - -## getEquipCrlPath - -```java -public String getEquipCrlPath() 
-``` - -获取用户设置的设备证书的CRL证书路径equipCrlPath,此参数用于Android环境。 - -- 返回值 - - String类型的证书路径equipCrlPath。 - -## setEquipCrlPath - -```java -public void setEquipCrlPath(String certPath) -``` - -用于设置设备证书的CRL证书路径,用于验证数字证书是否被吊销,此参数用于Android环境。 - -- 参数 - - `certPath`: 证书路径。 - -## getValidInterval - -```java -public long getValidInterval() -``` - -获取用户设置的有效迭代时间间隔validIterInterval,此参数用于Android环境。 - -- 返回值 - - long类型的有效迭代时间间隔validIterInterval。 - -## setValidInterval - -```java -public void setValidInterval(long validInterval) -``` - -用于设置有效迭代时间间隔validIterInterval,建议时长为端云间一个训练epoch的时长(单位:毫秒),用于防范重放攻击,此参数用于Android环境。 - -- 参数 - - `validInterval`: 有效迭代时间间隔。 - -## getThreadNum - -```java -public int getThreadNum() -``` - -获取联邦学习训练和推理时使用的线程数,默认值为1。 - -- 返回值 - - int类型的线程数threadNum。 - -## setThreadNum - -```java -public void setThreadNum(int threadNum) -``` - -设置联邦学习训练和推理时使用的线程数。 - -- 参数 - - `threadNum`: 线程数。 - -## getCpuBindMode - -```java -public int getCpuBindMode() -``` - -获取联邦学习训练和推理时线程所需绑定的cpu内核。 - -- 返回值 - - 将枚举类型的cpu内核cpuBindMode转换为int型返回。 - -## setCpuBindMode - -```java -public void setCpuBindMode(BindMode cpuBindMode) -``` - -设置联邦学习训练和推理时线程所需绑定的cpu内核。 - -- 参数 - - `cpuBindMode`: BindMode枚举类型,其中BindMode.NOT_BINDING_CORE代表不绑定内核,由系统自动分配,BindMode.BIND_LARGE_CORE代表绑定大核,BindMode.BIND_MIDDLE_CORE代表绑定中核。 - -## getHybridWeightName - -```java -public List getHybridWeightName(RunType runType) -``` - -混合学习模式时使用。获取用户设置的训练权重名和推理权重名。 - -- 参数 - - - `runType`: RunType枚举类型,只支持设置为RunType.TRAINMODE(代表获取训练权重名)、RunType.INFERMODE(代表获取推理权重名)。 - -- 返回值 - - List 类型,根据参数runType返回相应的权重名列表。 - -## setHybridWeightName - -```java -public void setHybridWeightName(List hybridWeightName, RunType runType) -``` - -由于混合学习模式时,云侧下发的权重,一部分导入到训练模型,一部分导入到推理模型,但框架本身无法判断,需要用户自行设置相关训练权重名和推理权重名。该方法提供给用户进行设置。 - -- 参数 - - `hybridWeightName`: List 类型的权重名列表。 - - `runType`: RunType枚举类型,只支持设置为RunType.TRAINMODE(代表设置训练权重名)、RunType.INFERMODE(代表设置推理权重名)。 - -## getDataMap - -```java -public Map> getDataMap() -``` - -获取用户设置的联邦学习数据集。 - -- 返回值 - - Map>类型的数据集。 - -## setDataMap - -```java -public void setDataMap(Map> dataMap) -``` - -设置联邦学习数据集。 - -- 参数 - - `dataMap`: Map>类型的数据集,map中key为RunType枚举类型,value为对应的数据集列表,key为RunType.TRAINMODE时代表对应的value为训练相关的数据集列表,key为RunType.EVALMODE时代表对应的value为验证相关的数据集列表, key为RunType.INFERMODE时代表对应的value为推理相关的数据集列表。 - -## getServerMod - -```java -public ServerMod getServerMod() -``` - -获取联邦学习训练模式。 - -- 返回值 - - ServerMod枚举类型的联邦学习训练模式。 - -## setServerMod - -```java -public void setServerMod(ServerMod serverMod) -``` - -设置联邦学习训练模式。 - -- 参数 - - `serverMod`: ServerMod枚举类型的联邦学习训练模式,其中ServerMod.FEDERATED_LEARNING代表普通联邦学习模式(训练和推理使用同一个模型)ServerMod.HYBRID_TRAINING代表混合学习模式(训练和推理使用不同的模型,且云侧也包含训练过程)。 - -## getBatchSize - -```java -public int getBatchSize() -``` - -获取联邦学习训练和推理时使用的单步训练样本数,即batch size。 - -- 返回值 - - int类型的单步训练样本数batchSize。 - -## setBatchSize - -```java -public void setBatchSize(int batchSize) -``` - -设置联邦学习训练和推理时使用的单步训练样本数,即batch size。需与模型的输入数据的batch size保持一致。 - -- 参数 - - `batchSize`: 单步训练样本数,即batch size。 \ No newline at end of file diff --git a/docs/federated/docs/source_zh_cn/java_api_syncfljob.md b/docs/federated/docs/source_zh_cn/java_api_syncfljob.md deleted file mode 100644 index 1ed1c27bbf1649ff6642c89b668dab69e1c2b838..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/java_api_syncfljob.md +++ /dev/null @@ -1,64 +0,0 @@ -# SyncFLJob - 
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/java_api_syncfljob.md) - -```java -import com.mindspore.flclient.SyncFLJob -``` - -SyncFLJob定义了端侧联邦学习启动接口flJobRun()、端侧推理接口modelInference()、获取云侧最新模型的接口getModel()、停止联邦学习训练任务的接口stopFLJob()。 - -## 公有成员函数 - -| **function** | -| -------------------------------- | -| public FLClientStatus flJobRun() | -| public int[] modelInference() | -| public FLClientStatus getModel() | -| public void stopFLJob() | - -## flJobRun - -```java -public FLClientStatus flJobRun() -``` - -启动端侧联邦学习任务,具体使用方法可参考[接口介绍文档](https://www.mindspore.cn/federated/docs/zh-CN/master/interface_description_federated_client.html)。 - -- 返回值 - - 返回flJobRun请求状态码。 - -## modelInference - -```java -public int[] modelInference() -``` - -启动端侧推理任务,具体使用方法可参考[接口介绍文档](https://www.mindspore.cn/federated/docs/zh-CN/master/interface_description_federated_client.html)。 - -- 返回值 - - 根据输入推理出的标签组成的int[]。 - -## getModel - -```java -public FLClientStatus getModel() -``` - -获取云侧最新模型,具体使用方法可参考[接口介绍文档](https://www.mindspore.cn/federated/docs/zh-CN/master/interface_description_federated_client.html)。 - -- 返回值 - - 返回getModel请求状态码。 - -## stopFLJob - -```java -public void stopFLJob() -``` - -在联邦学习训练任务中,可通过调用该接口停止训练任务。 - -当一个线程调用SyncFLJob.flJobRun()时,可在联邦学习训练过程中,使用另外一个线程调用SyncFLJob.stopFLJob()停止联邦学习训练任务。 diff --git a/docs/federated/docs/source_zh_cn/local_differential_privacy_eval_laplace.md b/docs/federated/docs/source_zh_cn/local_differential_privacy_eval_laplace.md deleted file mode 100644 index 311cd71fda89b130c57671f48c9d17c5b2a15495..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/local_differential_privacy_eval_laplace.md +++ /dev/null @@ -1,236 +0,0 @@ -# 横向联邦-局部差分隐私推理结果保护 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/local_differential_privacy_eval_laplace.md) - -## 隐私保护背景 - -评价联邦无监督模型训练的好坏,可通过端侧反馈的$loss$判断,也可利用端侧推理结果结合云侧聚类及聚类评估指标,来进一步监测联邦无监督模型训练进度。后者涉及到端侧推理数据上云,为满足隐私保护要求,需要对端侧推理数据进行隐私保护处理,同时云侧仍可进行聚类评估。该任务相较训练任务为辅助任务,则尽量使用轻量级算法,不能引入较训练阶段计算或通讯开销更大的隐私保护算法,本文介绍了一种利用局部差分隐私Laplace噪声机制保护推理结果的轻量级方案。 - -将隐私保护技术有效地融入到产品服务中,一方面有利于提升用户以及业界对产品及技术的信任度,另一方面有助于在满足当前隐私合规要求之下更好地开展联邦任务,打造全生命周期(训练-推理-评估)的隐私保护。 - -## 算法分析 - -### $L1$与$L2$范式 - -长度为$k$的向量$V$的$L1$范数为 $||V||_1=\sum^{k}_{i=1}{|V_i|}$,则在二维空间中,两个向量差的$L1$范数就是曼哈顿距离。 - -$L2$范数为 $||V||_2=\sqrt{\sum^{k}_{i=1}{V^2_i}}$。 - -推理结果一般为$softmax$结果,和为$1$,向量的各个维度值表示所属该维度对应类别的概率。 - -### $L1$与$L2$敏感度 - -本地差分隐私对要上传的数据引入不确定性,敏感度描述了不确定性的上界。在优化器和联邦训练中,可以给梯度添加$L2$敏感度的高斯噪声,因为添加前会对梯度向量进行裁剪操作。此处$softmax$推理结果满足和为$1$,因此添加$L1$的拉普拉斯噪声。对于$L2$灵敏度远低于$L1$灵敏度的应用程序,高斯机制允许增加更少的噪声,但该场景没有$L2$相关的约束限制,仅使用$L1$敏感度。 - -$L1$敏感度在本地差分隐私中表现为定义域内任意输入的最大距离: - -$\Delta f=max||X-Y||_1$ - -在本场景中,$X=, Y=, \sum X = 1, \sum Y = 1, |x_1-y_1|+|x_2-y_2|+...+|x_k-y_k|\leq1=\Delta f$。 - -### Laplace分布 - -拉普拉斯分布是连续的,均值为0的拉普拉斯的概率密度函数为: - -$Lap(x|b)=\frac{1}{2b}exp(-\frac{|x|}{b})$ - -### Laplace机制 - -$M(x,\epsilon)=X+Lap(\Delta f/\epsilon)$ - -其中,$Lap(\Delta f/\epsilon)$是和$X$同shape,独立同分布的随机变量向量。 - -在此场景中,$b$(又叫$scale$、$lambda$,$beta$)为$1/\epsilon$。 - -### 证明拉普拉斯机制是满足$\epsilon-LDP$的 - -任意两个不同的客户端,经过拉普拉斯机制处理之后,都输出相同结果来达到混淆不可区分的目的概率比有上确界。将$b=\Delta f/\epsilon$代入得到: - -$Lap(\Delta f/\epsilon)=\frac{\epsilon}{2\Delta f}exp(-\frac{\epsilon|x|}{\Delta f})$ - -$\frac{P(Z|X)}{P(Z|Y)}$ - 
-$=\prod^k_{i=1}(\frac{exp(-\frac{\epsilon|x_i-z_i|}{\Delta f})}{exp(-\frac{\epsilon |y_i-z_i|}{\Delta f})})$
-
-$=\prod^k_{i=1}exp(\epsilon\frac{|x_i-z_i|-|y_i-z_i|}{\Delta f})$
-
-$\leq\prod^k_{i=1}exp(\epsilon\frac{|x_i-y_i|}{\Delta f})$
-
-$=exp(\epsilon\frac{||X-Y||_1}{\Delta f})$
-
-$\leq exp(\epsilon)$
-
-其中第三步利用三角不等式$|x_i-z_i|-|y_i-z_i|\leq|x_i-y_i|$,最后一步利用$||X-Y||_1\leq\Delta f$。
-
-#### $\epsilon$ 的确定与对应的概率密度图
-
-结合数据特点计算出可用性较高的隐私预算,比如要求大概率输出$1e-5$量级的噪声,否则会直接严重影响聚类结果。下面给出产生指定量级噪声对应的隐私预算计算方法。
-
-要求以$90\%$的概率输出$1e-5$量级以内的噪声,对概率密度曲线积分即可得到$\epsilon$的取值(噪声关于$0$对称,单侧积分为$0.45$即对应双侧的$90\%$):
-
-$x\geq 0, Lap(x|b)=\frac{1}{2b}exp(-\frac{x}{b})$
-
-$\int^{E^{-5}}_0 {Lap(x|b)dx}$
-
-$=-\frac{1}{2}exp(-\frac{x}{b})\Big|^{E^{-5}}_{0}$
-
-$=\frac{1}{2}(exp(0)-exp(-\frac{E^{-5}}{b}))$
-
-$=0.5(1-exp(-\frac{E^{-5}}{b})) = 0.45$
-
-即:
-
-$exp(-\frac{E^{-5}}{b})=0.1$
-
-$b=-E^{-5}/ln(0.1)=E^{-5}/2.3026=1/\epsilon$
-
-$\epsilon=2.3026E^5$
-
-当隐私预算取该值时,拉普拉斯概率密度函数如下:
-
-![laplace](./images/laplace_pdf.png)
-
-### 聚类评价指标的影响性分析
-
-以**Calinski-Harabasz Index**评估方法举例,该评价指标计算过程分为两步:
-
-1. 每个类计算该类中所有`点`到`该类中心`距离的平方和;
-
-2. 计算每个`类`与`类中心`距离平方和;
-
-源码实现与加噪之后的影响性分析:
-
-```python
-# 1.云侧聚类算法得到所属类序号,有影响
-n_labels = argmax(X)
-
-extra_disp, intra_disp = 0.0, 0.0
-# 2.计算所有点的类中心,不影响
-mean = np.mean(X, axis=0)
-for k in range(n_labels):
-    # 3.得到第k类中的所有点,基于1的影响
-    cluster_k = X[labels == k]
-    # 4.得到该类中心,基于1的影响
-    mean_k = np.mean(cluster_k, axis=0)
-    # 5.该类与所有类中心距离,基于1的影响
-    extra_disp += len(cluster_k) * np.sum((mean_k - mean) ** 2)
-    # 6.点到该类中心距离,有影响
-    intra_disp += np.sum((cluster_k - mean_k) ** 2)
-
-return (
-    1.0
-    if intra_disp == 0.0
-    else extra_disp * (n_samples - n_labels) / (intra_disp * (n_labels - 1.0))
-)
-```
-
-综合分析,主要影响是加噪之后对聚类算法本身的影响,以及距离计算上的误差。在计算类中心时,由于噪声和的期望为$0$,所以引入的误差较小。
-
-以**SILHOUETTE SCORE**举例,该评价指标计算过程分为两步:
-
-1. 计算一个样本点$i$与同簇的其他所有样本点的平均距离,记为$a_i$;该值越小,表示样本$i$越应该分到这个簇。
-
-2. 计算样本$i$到其他某簇$C_j$的所有样本的平均距离$b_{ij}$,称为样本$i$与簇$C_j$的不相似度。定义样本$i$的簇间不相似度为:$b_i = min(b_{i1}, b_{i2}, …, b_{ik})$。该值越大,说明样本$i$越不应该属于这个簇。
-
-![flow](./images/two_cluster.png)
-
-$s_i=(b_i-a_i) / max(a_i, b_i)$
-
-$a_i$越小,$b_i$越大,结果$1-a_i / b_i$就越接近$1$,聚类效果越好。
-
-伪代码实现与加噪之后的影响性分析:
-
-```c++
-// 计算距离矩阵,空间换时间,上三角存储,加噪有影响
-euclidean_distance_matrix(&distance_matrix, group_ids);
-
-// 对每个点都进行相同的计算,最后计算均值
-for (size_t i = 0; i < n_samples; ++i) {
-    std::unordered_map<size_t, std::vector<float>> b_i_map;
-    for (size_t j = 0; j < n_samples; ++j) {
-        size_t label_j = labels[j];
-        float distance = distance_matrix[i][j];
-        // 同簇计算ai
-        if (label_j == label_i) {
-            a_distances.push_back(distance);
-        } else {
-            // 非同簇计算bi
-            b_i_map[label_j].push_back(distance);
-        }
-    }
-    if (a_distances.size() > 0) {
-        // 计算该点距离同簇其他点平均距离
-        a_i = std::accumulate(a_distances.begin(), a_distances.end(), 0.0) / a_distances.size();
-    }
-    for (auto &item : b_i_map) {
-        auto &b_i_distances = item.second;
-        float b_i_distance = std::accumulate(b_i_distances.begin(), b_i_distances.end(), 0.0) / b_i_distances.size();
-        b_i = std::min(b_i, b_i_distance);
-    }
-    if (a_i == 0) {
-        s_i[i] = 0;
-    } else {
-        s_i[i] = (b_i - a_i) / std::max(a_i, b_i);
-    }
-}
-return std::accumulate(s_i.begin(), s_i.end(), 0.0) / n_samples;
-```
-
-同上,主要影响是加噪之后对聚类算法本身的影响,以及距离计算上的误差。
-
-### 端侧Java实现
-
-Java基本库中没有生成Laplace分布随机数的函数,这里采用如下随机数的组合策略产生(依赖`java.security.SecureRandom`与`java.lang.Math.log`)。
-
-源码如下:
-
-```java
-float genLaplaceNoise(SecureRandom secureRandom, float beta) {
-    float u1 = secureRandom.nextFloat();
-    float u2 = secureRandom.nextFloat();
-    // u1决定噪声符号,u2决定噪声幅值,组合即得均值为0、scale为beta的Laplace分布
-    if (u1 <= 0.5f) {
-        return (float) (-beta * log(1. - u2));
-    } else {
-        return (float) (beta * log(u2));
-    }
-}
-```
-
-在端侧获得新一轮模型后,立即执行推理计算,等待训练结束之后,连同新模型和隐私保护之后的推理结果一同上传至云侧,云侧最终执行聚类和分数计算等操作。流程见下图,其中红色部分为隐私保护处理的输出结果:
-
-![flow](./images/eval_flow.png)
-
-## 快速上手
-
-### 准备工作
-
-若要使用该功能,首先需要成功完成任一端云联邦场景的训练聚合过程,[实现一个端云联邦的图像分类应用(x86)](https://www.mindspore.cn/federated/docs/zh-CN/master/image_classification_application.html)详细介绍了数据集、网络模型等准备工作,以及模拟启动多客户端参与联邦学习的流程。
-
-### 配置项
-
-[云侧yaml配置文件](https://gitee.com/mindspore/federated/blob/master/tests/st/cross_device_cloud/default_yaml_config.yaml)给出了开启端云联邦的完整配置项,该方案涉及到的新增配置文件项如下:
-
-```yaml
-encrypt:
-  privacy_eval_type: LAPLACE
-  laplace_eval:
-    laplace_eval_eps: 230260
-```
-
-其中`privacy_eval_type`目前仅支持`NOT_ENCRYPT`和`LAPLACE`,分别表示不使用隐私保护方法处理推理结果和使用`LAPLACE`机制处理。
-
-`laplace_eval_eps`表示使用`LAPLACE`处理时所使用的隐私预算。
-
-## 实验结果
-
-推理结果评估相关的基本配置如下:
-
-```yaml
-unsupervised:
-  cluster_client_num: 1000
-  eval_type: SILHOUETTE_SCORE
-```
-
-观察在使用`NOT_ENCRYPT`和使用`laplace_eval_eps=230260`的`LAPLACE`机制下,$loss$与分数之间的关系如图所示:
-
-![flow](./images/SILHOUETTE.png)
-
-红色虚线为使用Laplace机制保护推理结果后的SILHOUETTE分数。由于模型中含有$dropout$和高斯输入,两次训练的$loss$略有不同,基于不同模型得到的分数也略有不同,但整体趋势保持一致,可以辅助$loss$一起监测模型训练进展。
diff --git a/docs/federated/docs/source_zh_cn/local_differential_privacy_training_noise.md b/docs/federated/docs/source_zh_cn/local_differential_privacy_training_noise.md
deleted file mode 100644
index 63d870b228ae757289f18fba3b396724b383a25e..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_zh_cn/local_differential_privacy_training_noise.md
+++ /dev/null
@@ -1,43 +0,0 @@
-# 横向联邦-局部差分隐私加噪训练
-
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/local_differential_privacy_training_noise.md)
-
-联邦学习过程中,用户数据仅用于客户端设备的本地训练,不需要上传至中心服务器,可以避免泄露用户个人数据。然而,传统联邦学习框架中,模型以明文形式上云,仍然存在间接泄露用户隐私的风险。攻击者获取到客户端上传的明文模型后,可以通过重构、模型逆向等攻击方式,恢复参与学习的用户个人数据,导致用户隐私泄露。
-
-MindSpore Federated联邦学习框架提供了基于本地差分隐私(LDP)的算法,在客户端上传本地模型前对其进行加噪。在保证模型可用性的前提下,解决横向联邦学习中的隐私泄露问题。
-
-## 原理概述
-
-差分隐私(differential privacy)是一种保护用户数据隐私的机制。差分隐私定义为:
-
-$$
-Pr[\mathcal{K}(D)\in S] \le e^{\epsilon} Pr[\mathcal{K}(D') \in S]+\delta
-$$
-
-对于两个差别只有一条记录的数据集$D, D'$,通过随机算法$\mathcal{K}$,输出结果为集合$S$子集的概率满足上述公式。$\epsilon$为差分隐私预算,$\delta$为扰动项,$\epsilon$和$\delta$越小,说明$\mathcal{K}$在$D$和$D'$上输出的数据分布越接近。
-
-在横向联邦学习中,假设客户端本地训练之后的模型权重矩阵是$W$,由于模型在训练过程中会“记住”训练集的特征,所以攻击者可以借助$W$还原出用户的训练数据集[1]。
-
-MindSpore Federated提供基于本地差分隐私的安全聚合算法,防止客户端上传本地模型时泄露用户隐私数据。
-
-MindSpore Federated客户端会生成一个与本地模型$W$相同维度的差分噪声矩阵$G$,然后将二者相加,得到一个满足差分隐私定义的权重$W_p$:
-
-$$
-W_p=W+G
-$$
-
-MindSpore Federated客户端将加噪后的模型$W_p$上传至云侧服务器进行联邦聚合。噪声矩阵$G$相当于给原模型加上了一层掩码,在降低模型泄露敏感数据风险的同时,也会影响模型训练的收敛性。如何在模型隐私性和可用性之间取得更好的平衡,仍然是一个值得研究的问题。实验表明,当参与方的数量$n$足够大时(一般指1000以上),大部分噪声能够相互抵消,本地差分机制对聚合模型的精度和收敛性没有明显影响(文末附有一个简单的数值示意)。
-
-## 使用方式
-
-本地差分隐私训练目前只支持端云联邦学习场景。开启差分隐私训练的方式很简单,只需要在启动云侧服务时,通过[yaml](https://www.mindspore.cn/federated/docs/zh-CN/master/horizontal/federated_server_yaml.html#)设置`encrypt_train_type`字段为`DP_ENCRYPT`即可。
-
-此外,为了控制隐私保护的效果,我们还提供了3个参数:`dp_eps`、`dp_delta`以及`dp_norm_clip`,它们也是通过yaml文件进行设置。
-
-`dp_eps`和`dp_norm_clip`的合法取值范围是大于0,`dp_delta`的合法取值范围是0<`dp_delta`<1。一般来说,`dp_eps`和`dp_delta`越小,隐私保护效果也越好,但是对模型收敛性的影响越大。建议`dp_delta`取成客户端数量的倒数,`dp_eps`大于50。
-
-`dp_norm_clip`是差分隐私机制对模型权重加噪前对权重大小的调整系数,会影响模型的收敛性,一般建议取0.5~2。
-
-## 参考文献
-
-[1] Ligeng Zhu, Zhijian Liu, and Song Han. [Deep Leakage from Gradients](http://arxiv.org/pdf/1906.08935.pdf). NeurIPS, 2019.
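-
-## 附:加噪聚合的数值示意
-
-为直观理解上文“噪声和的期望为$0$、参与方足够多时大部分噪声相互抵消”的结论,下面给出一个示意性的Python片段。它仅演示原理,并非MindSpore Federated的实际实现,其中的客户端数量、裁剪阈值与噪声标准差均为假设值:
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(0)
-n_clients = 1000          # 参与聚合的客户端数量,假设值
-dp_norm_clip = 1.0        # 对应yaml中的dp_norm_clip,假设值
-sigma = 0.05              # 高斯噪声标准差,实际由dp_eps、dp_delta换算,此处直接给定
-
-# 模拟各客户端本地训练后的模型权重W(用8维向量示意)
-weights = rng.uniform(-1, 1, size=(n_clients, 8))
-
-def clip_and_perturb(w):
-    """先按dp_norm_clip裁剪L2范数,再叠加噪声矩阵G,得到W_p = W + G。"""
-    norm = np.linalg.norm(w)
-    if norm > dp_norm_clip:
-        w = w * dp_norm_clip / norm
-    return w + rng.normal(0, sigma, size=w.shape)
-
-clipped = np.stack([w * min(1.0, dp_norm_clip / np.linalg.norm(w)) for w in weights])
-noisy = np.stack([clip_and_perturb(w) for w in weights])
-
-# 服务端对W_p做平均聚合:n足够大时,结果与无噪聚合非常接近
-print(np.max(np.abs(noisy.mean(axis=0) - clipped.mean(axis=0))))
-```
-
-运行后可以看到,加噪聚合均值与无噪聚合均值的最大偏差约在$\sigma/\sqrt{n}$量级,远小于单个客户端所加噪声的幅值。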
diff --git a/docs/federated/docs/source_zh_cn/local_differential_privacy_training_signds.md b/docs/federated/docs/source_zh_cn/local_differential_privacy_training_signds.md deleted file mode 100644 index 1cb2950dcf0180aac49b5c1fe2e753510cf0b570..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/local_differential_privacy_training_signds.md +++ /dev/null @@ -1,179 +0,0 @@ -# 横向联邦-局部差分隐私SignDS训练 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/local_differential_privacy_training_signds.md) - -## 隐私保护背景 - -联邦学习通过让参与方只上传本地训练后的新模型或更新模型的update信息,实现了client用户不上传原始数据集就能参与全局模型训练的目的,打通了数据孤岛。这种普通场景的联邦学习对应MindSpore联邦学习框架中的默认方案,启动`server`时,`encrypt_train_type`开关默认为`not_encrypt`,联邦学习教程中的`安装部署`与`应用实践`都默认使用这种方式,是没有任何加密扰动等保护隐私处理的普通联邦求均方案,为方便描述,下文以`not_encrypt`来特指这种默认方案。 - -这种联邦学习方案并不是毫无隐私泄漏的,使用上述`not_encrypt`方案进行训练,服务端Server收到客户端Client上传的本地训练模型,仍可通过一些攻击方法[1]重构用户训练数据,从而泄露用户隐私,所以`not_encrypt`方案需要进一步增加用户隐私保护机制。 - -联邦学习中客户端Client每轮接收的全局模型`oldModel`都是由服务端Server下发的,不涉及用户隐私问题。但各客户端Client本地训练若干epoch后得到的本地模型`newModel`拟合了其本地隐私数据,所以隐私保护重点是二者的权重差值`newModel`-`oldModel`=`update`。 - -MindSpore Federated框架中已实现的`DP_ENCRYPT`差分噪声方案通过向`update`迭加高斯随机噪声进行扰动,实现隐私保护。但随着模型维度增大,`update`范数增大会使噪声增大,从而需要较多的客户端Client参与同一轮聚合,以中和噪声影响,否则模型收敛性和精度会降低。如果设置的噪声过小,虽然收敛性和精度与`not_encrypt`方案性能接近,但隐私保护力度不够。同时每个客户端Client都需要发送扰动后的模型,随着模型增大,通信开销也会随之增大。我们期望手机为代表的客户端Client,以尽可能少的通信开销,即可实现全局模型的收敛。 - -## 算法流程介绍 - -SignDS[2]是Sign Dimension Select的缩写,处理对象是客户端Client的`update`。准备工作:把`update`的每一层Tensor拉平展开成一维向量,连接在一起,拼接向量维度数量记为$d$。 - -一句话概括算法:每个参与方仅上传重要维度的信息,信息包括它们的梯度方向和隐私保护的步长。分别对应下图中的SignDS和MagRR(Magnitude Random Response)模块。 - -![signds](./images/signds_framework.png) - -下面举例来说明:现有3个客户端Client1,2,3,其`update`拉平展开后为$d=8$维向量,服务端Server计算这3个客户端Client的`avg`,并用该值更新全局模型,即完成一轮联邦学习。 - -| Client | d_1 | d_2 | d_3 | d_4 | d_5 | d_6 | d_7 | d_8 | -| :----: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :---: | -| 1 | 0.4 | 0.1 | -0.2 | 0.3 | 0.5 | 0.1 | -0.2 | -0.3 | -| 2 | 0.5 | 0.2 | 0 | 0.1 | 0.3 | 0.2 | -0.1 | -0.2 | -| 3 | 0.3 | 0.1 | -0.1 | 0.5 | 0.2 | 0.3 | 0 | 0.1 | -| avg | 0.4 | 0.13 | -0.1 | 0.3 | 0.33 | 0.2 | -0.1 | -0.13 | - -### SignDS - -应选择重要性较高的维度,重要性衡量标准是**取值的大小**,需要对update进行排序。update取值正负代表不同的更新方向,故每轮联邦学习中,客户端Client的sign值各有**0.5的概率**取`1`或`-1`。如果sign=1,则将最大的$k$个`update`维度记为`topk`集合,剩余的记为`non-topk`集合;如果sign=-1,则将最小的$k$个记为`topk`集合。 - -如果服务端Server指定总共选择的维度数量`h`,客户端Client会直接使用该值,否则各客户端Client会本地计算出最优的输出维度`h`。 - -随后SignDS算法会输出应从`topk`集合和`non-topk`集合中选择的维度数量(记为$v$),如下表中示例,两个集合总共挑选维度h=3。 - -客户端Client按照SignDS算法输出的维度数量,均匀随机挑选维度,将维度序号和sign值发送至服务端Server,维度序号如果按照先从`topk`挑选,再从`non-topk`挑选的顺序输出,则需要对维度序号列表`index`进行洗牌打乱操作,下表为该算法各客户端Client最终传输至服务端Server的部分信息: - -| Client | index | sign | -| :----: | :---: | :--: | -| 1 | 1,5,8 | 1 | -| 2 | 2,3,4 | -1 | -| 3 | 3,6,7 | 1 | - -### MagRR - -服务端Server收到客户端发来的维度方向,但不清楚在该方向要更新的步长是多少。通常来讲,在训练初期,步长往往很大,随着训练逐渐收敛,步长缩小。步长变化的大致趋势如下图所示: - -![step_length](./images/signds_step_length.png) - -服务端Server希望估计一个针对实际步长$r$的动态范围$[0,2∗r_{est}]$,进而计算全局学习率$lr_{global}=2∗r_{est}*num_{clients}$。 - -$r$的调整采用类似二分法思路。具体流程如下: - -1. 训练开始前,服务端初始化一个较小的$r_{est}$(不会对模型收敛方向造成过大影响); -2. 每轮本地训练后,参与方计算真实幅值$r$(topk维度的均值),并根据当前云侧下发的$r_{est}$将$r$以一定规则转换为$b$; -3. 
参与方对$b$进行本地差分Binary Randomized Response(BRR)扰动,并将结果上传。 - -整个训练过程分为两个阶段,即**快增长**阶段和**收缩**阶段。参与方在两个阶段进行$r \rightarrow b$转换和服务端更新$r_{est}$的规则略有不同: - -- 快增长阶段,选取一个较小的$r_{est}$,如$e^{−5}$。此时,需要以一定倍数扩大$r_{est}$。 - 因此定义: - - $$ - b = \begin{cases} - 0 & r \in [2*r_{est}, \infty] \\ - 1 & r \in [0,2*r_{est})] - \end{cases} - $$ - - 服务端聚合所有端侧随机响应结果进行频率统计,计算众数$B$, - 若$B=0$,则认为目前$r_{est}$未到达𝑟的范围,需继续增大$r_{est}$; - 若$B=1$,则认为$r_{est}$已到达𝑟的范围,保持$r_{est}$不变。 -- 收缩阶段,需要根据$r$的变化微调$r_{est}$。因此定义: - - $$ - b = \begin{cases} - 0 & r \in [r_{est}, \infty] \\ - 1 & r \in [0,r_{est})] - \end{cases} - $$ - - 计算$B$,若$B=0$,则认为目前$r_{est}$和$r$较为接近,保持$r_{est}$不变; - 若$B=1$,则认为$r$普遍小于$r_{est}$,则将$r_{est}$减半。 - -服务端Server根据各客户端Client上传的维度序号,sign值和$r_{est}$,构建带隐私保护的`update`,对所有`update`进行聚合平均并更新当前`oldModel`即完成一轮联邦学习。下表展示了$2∗r_{est}*num_{clients}=1$时的聚合情况。 - -| Client | d_1 | d_2 | d_3 | d_4 | d_5 | d_6 | d_7 | d_8 | -| :----: | :---: | :----: | :----: | :----: | :---: | :---: | :---: | :---: | -| 1 | **1** | 0 | 0 | 0 | **1** | 0 | 0 | **1** | -| 2 | 0 | **-1** | **-1** | **-1** | 0 | 0 | 0 | 0 | -| 3 | 0 | 0 | **1** | 0 | 0 | **1** | **1** | 0 | -| avg | 1/3 | -1/3 | 0 | -1/3 | 1/3 | 1/3 | 1/3 | 1/3 | - -SignDS方案使端侧client只上传算法输出的int类型维度序号列表,一个布尔类型的随机Sign值和对估计值的反馈结果到云侧,相比普通场景中上传数万float级别的完整模型权重或梯度,通讯开销显著降低。从实际重构攻击的角度来看,云侧仅获得维度序号、代表梯度更新方向的一个Sign值和隐私保护的步长估计反馈值,攻击更加难以实现。整体方案的数据流字段如下图所示: - -![flow](./images/signds_flow.png) - -## 隐私保护证明 - -差分隐私噪声方案通过加噪的方式,让攻击者无法确定原始信息,从而实现隐私保护;而差分隐私SignDS方案只激活部分维度,且用sign值代替原始值,很大程度上保护了用户隐私。进一步的,利用差分隐私指数机制让攻击者无法确认激活的维度是否是重要(来自`topk`集合),且无法确认输出维度中来自`topk`的维度数量是否超过给定阈值。 - -### 基于指数机制的维度选择机制 - -对于每个客户端Client的任意两个update $\Delta$ 和 $\Delta'$ ,其`topk`维度集合分别是 $S_{topk}$ , ${S'}_{topk}$ ,该算法任意可能的输出维度集合是 ${J}\in {\mathcal{J}}$ ,记 $\nu=|{S}_{topk}\cap {J}|$ , $\nu'=|{S'}_{topk}\cap {J}|$ 是 ${J}$ 和`topk` 集合交集的数量,算法使得以下不等式成立: - -$$ -\frac{{Pr}[{J}|\Delta]}{{Pr}[{J}|\Delta']}=\frac{{Pr}[{J}|{S}_{topk}]}{{Pr}[{J}|{S'}_{topk}]}=\frac{\frac{{exp}(\frac{\epsilon}{\phi_u}\cdot u({S}_{topk},{J}))}{\sum_{{J'}\in {\mathcal{J}}}{exp}(\frac{\epsilon}{\phi_u}\cdot u({S}_{topk}, {J'}))}}{\frac{{exp}(\frac{\epsilon}{\phi_u}\cdot u({S'}_{topk}, {J}))}{\sum_{ {J'}\in {\mathcal{J}}}{exp}(\frac{\epsilon}{\phi_u}\cdot u( {S'}_{topk},{J'}))}}=\frac{\frac{{exp}(\epsilon\cdot \unicode{x1D7D9}(\nu \geq \nu_{th}))}{\sum_{\tau=0}^{\tau=\nu_{th}-1}\omega_{\tau} + \sum_{\tau=\nu_{th}}^{\tau=h}\omega_{\tau}\cdot {exp}(\epsilon)}}{\frac{ {exp}(\epsilon\cdot \unicode{x1D7D9}(\nu' \geq\nu_{th}))}{\sum_{\tau=0}^{\tau=\nu_{th}-1}\omega_{\tau}+\sum_{\tau=\nu_{th}}^{\tau=h}\omega_{\tau}\cdot {exp}(\epsilon)}}\\= \frac{{exp}(\epsilon\cdot \unicode{x1D7D9} (\nu \geq \nu_{th}))}{ {exp}(\epsilon\cdot \unicode{x1D7D9} (\nu' \geq \nu_{th}))} \leq \frac{{exp}(\epsilon\cdot 1)}{{exp}(\epsilon\cdot 0)} = {exp}(\epsilon), -$$ - -证明该算法满足局部差分隐私。 - -### 局部差分隐私-随机响应机制 - -参与方收到服务端下发的估计值,在本地训练完成后,计算真实update的topk维度权重均值,根据magRR策略输出0或1,我们认为0或1仍然带有权重均值范围信息,则需要进一步保护。 - -随机响应机制的输入为待保护数据($b\in \{0,1\}$)和隐私参数$\epsilon$,按照一定的概率对数据进行翻转,输出$\hat{b} \in \{0,1\}$,规则如下: - -$$ -\hat{b} = \begin{cases} -b & with \quad probability \quad P \\ -1-b & with \quad probability \quad 1-P -\end{cases} -$$ - -其中$P=\frac{e^\epsilon}{1+e^\epsilon}$。 - -#### 基于随机响应机制的频率统计 - -通过随机响应的方式使敌手很难区分真实数据和扰动数据,来达到以假乱真的效果,但也会影响云侧统计任务的可用性。服务端可通过降噪的方式近似的真实统计频率值,但很难逆向推断出用户的真实输入。记$N$为一轮参与方总数,$N^T$为原始为1的总数,$N^C$为服务端收集到的1的总数,则有: - -$$ -N^T*P+(N-N^T)*(1-P)=N^C \\ -N^T=\frac{N^C-N+NP}{2P-1} -$$ - -## 准备工作 - 
-若要使用该算法,首先需要成功完成任一端云联邦场景的训练聚合过程,[实现一个端云联邦的图像分类应用(x86)](https://www.mindspore.cn/federated/docs/zh-CN/master/image_classification_application.html)详细介绍了数据集、网络模型等准备工作,以及模拟启动多客户端参与联邦学习的流程。
-
-## 算法开启脚本
-
-本地差分隐私SignDS训练目前只支持端云联邦学习场景。开启时需要在启动云侧服务时,在yaml文件中更改下列参数配置,云侧完整启动脚本可参考云侧部署,这里给出启动该算法的相关参数配置。以LeNet任务为例,yaml相关配置如下:
-
-```yaml
-encrypt:
-  encrypt_train_type: SIGNDS
-  ...
-  signds:
-    sign_k: 0.2
-    sign_eps: 100
-    sign_thr_ratio: 0.6
-    sign_global_lr: 0.1
-    sign_dim_out: 0
-```
-
-具体样例可参考[图像分类应用](https://www.mindspore.cn/federated/docs/zh-CN/master/image_classification_application.html)。
-
-云侧代码实现给出了各个参数的定义域,若取值不在定义域内,Server会报错并提示定义域。以下参数改动的前提是保持其余4个参数不变:
-
-- `sign_k`:(0, 0.25],且需满足k*inputDim>50,default=0.01。`inputDim`是模型或update的拉平长度,若不满足,端侧会给出警告。对update排序,取值占比前k的维度组成`topk`集合。减少k,意味着要从更重要的维度中以较大概率挑选,输出的维度会减少,但维度更重要,无法确定收敛性的变化;用户需观察模型update的稀疏度来确定该值,当update比较稀疏时(含有很多0),该值应取小一点。
-- `sign_eps`:(0, 100],default=100。隐私保护预算,数学符号为$\epsilon$,简写为eps。eps减少,挑选不重要维度的概率会增大,隐私保护力度增强,输出维度减少,占比不变,精度降低。
-- `sign_thr_ratio`:[0.5, 1],default=0.6。激活的维度中来自`topk`的维度占比阈值下界。增大该值会减少输出维度,但输出维度中来自`topk`的占比会增加;当过度增大该值,要求输出中更多的维度来自`topk`,为了满足要求只能减少总的输出维度,当client用户数量不够多时,精度下降。
-- `sign_global_lr`:(0, +∞),default=1。该值乘上sign来代替update,直接影响收敛快慢与精度,适度增大该值会提高收敛速度,但有可能让模型震荡、梯度爆炸。如果每个client用户本地跑更多的epoch,且增大本地训练使用的学习率,那么需要相应提高该值;如果参与聚合的client用户数目增多,那么也需要提高该值,因为重构时需要把该值聚合后再除以用户数目,只有增大该值,结果才保持不变。若参与聚合的新版本(r0.2)参与方占比不足5%,则MagRR算法的$lr_{global}$直接调整为该入参。
-- `sign_dim_out`:[0, 50],default=0。若给出非0值,client端直接使用该值,增大该值,输出的维度增多,但来自`topk`的维度占比会减少;若为0,client用户要计算出最优的输出维度数。eps不够大时,若增大该值,则会输出很多来自`non-topk`的不重要维度,影响模型收敛,精度下降;当eps足够大时,增大该值会让更多用户的重要维度信息离开本地,精度提升。
-
-## LeNet实验结果
-
-使用`3500_clients_bin`中的100个client数据集,联邦聚合600个iteration,每个client本地运行20个epoch,端侧本地训练使用的学习率为0.01,SignDS相关参数为`k=0.2,eps=100,ratio=0.6,lr=4,out=0`,Loss和Auc的变化曲线如下图所示。端侧训练结束后上传到云侧的数据长度为266084,而SignDS上传的数据长度仅为656。
-
-![loss](./images/lenet_signds_loss_auc.png)
-
-## 参考文献
-
-[1] Ligeng Zhu, Zhijian Liu, and Song Han. [Deep Leakage from Gradients](http://arxiv.org/pdf/1906.08935.pdf). NeurIPS, 2019.
-
-[2] Xue Jiang, Xuebing Zhou, and Jens Grossklags. "SignDS-FL: Local Differentially-Private Federated Learning with Sign-based Dimension Selection." ACM Transactions on Intelligent Systems and Technology, 2022.
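-
-## 附:随机响应与云侧降噪的数值示意
-
-为直观理解上文MagRR中使用的随机响应机制与云侧频率统计降噪公式,下面给出一个示意性的Python片段。它仅演示原理,并非MindSpore Federated的实际实现,其中参与方总数与隐私预算均为假设值:
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(0)
-eps = 1.0                                  # 隐私预算,假设值
-p = np.exp(eps) / (1 + np.exp(eps))        # 保持b不变的概率P
-
-n = 100000                                 # 一轮参与方总数,假设值
-true_b = rng.integers(0, 2, size=n)        # 各参与方的真实b
-
-# 端侧:以概率P保持b不变,否则翻转,满足eps-LDP
-keep = rng.random(n) < p
-reported = np.where(keep, true_b, 1 - true_b)
-
-# 云侧:按 N^T = (N^C - N + N*P) / (2P - 1) 还原1的真实数量
-n_c = reported.sum()
-n_t_est = (n_c - n + n * p) / (2 * p - 1)
-print(true_b.sum(), round(n_t_est))        # 两者应非常接近
-```
-
-可以看到,云侧无法区分单个参与方上报的$\hat{b}$是否经过翻转,但仍能以较高精度还原整体频率,用于更新$r_{est}$。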
\ No newline at end of file diff --git a/docs/federated/docs/source_zh_cn/object_detection_application_in_cross_silo.md b/docs/federated/docs/source_zh_cn/object_detection_application_in_cross_silo.md deleted file mode 100644 index f16efa1a38407cb31eac55f51fd76381b44aa3d0..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/object_detection_application_in_cross_silo.md +++ /dev/null @@ -1,264 +0,0 @@ -# 实现一个云云联邦的目标检测应用(x86) - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/object_detection_application_in_cross_silo.md) - -根据参与客户端的类型,联邦学习可分为云云联邦学习(cross-silo)和端云联邦学习(cross-device)。在云云联邦学习场景中,参与联邦学习的客户端是不同的组织(例如,医疗或金融)或地理分布的数据中心,即在多个数据孤岛上训练模型。在端云联邦学习场景中,参与的客户端为大量的移动或物联网设备。本框架将介绍如何在MindSpore Federated云云联邦框架上使用网络Fast R-CNN实现一个目标检测应用。 - -启动云云联邦的目标检测应用的完整脚本可参考[这里](https://gitee.com/mindspore/federated/tree/master/example/cross_silo_faster_rcnn)。 - -## 任务前准备 - -本教程基于MindSpore model_zoo中提供的的faster_rcnn网络部署云云联邦目标检测任务,请先根据官方[faster_rcnn教程及代码](https://gitee.com/mindspore/models/tree/master/official/cv/FasterRCNN)先了解COCO数据集、faster_rcnn网络结构、训练过程以及评估过程。由于COCO数据集已开源,请参照其[官网](https://cocodataset.org/#home)指引自行下载好数据集,并进行数据集切分(例如模拟100个客户端,可将数据集切分成100份,每份代表一个客户端所持有的数据)。 - -由于原始COCO数据集为json文件格式,云云联邦学习框架提供的目标检测脚本暂时只支持MindRecord格式输入数据,可根据以下步骤将json文件转换为MindRecord格式文件。 - -- 首先在配置文件[default_config.yaml](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_faster_rcnn/default_config.yaml)中设置以下参数: - - - 参数`mindrecord_dir` - - 用于设置生成的MindRecord格式文件保存路径,文件夹名称必须为mindrecord_{num}格式,数字num代表客户端标号0,1,2,3,...... - - ```sh - mindrecord_dir:"./datasets/coco_split/split_100/mindrecord_0" - ``` - - - 参数`instance_set` - - 用于设置原始json文件路径。 - - ```sh - instance_set: "./datasets/coco_split/split_100/train_0.json" - ``` - -- 运行脚本[generate_mindrecord.py](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_faster_rcnn/generate_mindrecord.py)即可生成`train_0.json`对应的MindRecord文件,保存在路径`mindrecord_dir`中。 - -## 启动云云联邦任务 - -### 安装MindSpore和MindSpore Federated - -包括源码和下载发布版两种方式,支持CPU、GPU、Ascend硬件平台,根据硬件平台选择安装即可。安装步骤可参考[MindSpore安装指南](https://www.mindspore.cn/install),[MindSpore Federated安装指南](https://www.mindspore.cn/federated/docs/zh-CN/master/index.html)。 - -目前联邦学习框架只支持Linux环境中部署,cross-silo联邦学习框架需要MindSpore版本号>=1.5.0。 - -### 启动任务 - -参考[示例](https://gitee.com/mindspore/federated/tree/master/example/cross_silo_faster_rcnn),启动集群。参考示例目录结构如下: - -```text -cross_silo_faster_rcnn -├── src -│ ├── FasterRcnn -│ │ ├── __init__.py // init文件 -│ │ ├── anchor_generator.py // 锚点生成器 -│ │ ├── bbox_assign_sample.py // 第一阶段采样器 -│ │ ├── bbox_assign_sample_stage2.py // 第二阶段采样器 -│ │ ├── faster_rcnn_resnet.py // Faster R-CNN网络 -│ │ ├── faster_rcnn_resnet50v1.py // 以Resnet50v1.0作为backbone的Faster R-CNN网络 -│ │ ├── fpn_neck.py // 特征金字塔网络 -│ │ ├── proposal_generator.py // 候选生成器 -│ │ ├── rcnn.py // R-CNN网络 -│ │ ├── resnet.py // 骨干网络 -│ │ ├── resnet50v1.py // Resnet50v1.0骨干网络 -│ │ ├── roi_align.py // ROI对齐网络 -│ │ └── rpn.py // 区域候选网络 -│ ├── dataset.py // 创建并处理数据集 -│ ├── lr_schedule.py // 学习率生成器 -│ ├── network_define.py // Faster R-CNN网络定义 -│ ├── util.py // 例行操作 -│ └── model_utils -│ ├── __init__.py // init文件 -│ ├── config.py // 获取.yaml配置参数 -│ ├── device_adapter.py // 获取云上id -│ ├── local_adapter.py // 获取本地id -│ └── moxing_adapter.py // 云上数据准备 -├── requirements.txt -├── mindspore_hub_conf.py -├── generate_mindrecord.py // 
将.json格式的annotations文件转化为MindRecord格式,以便读取datasets -├── default_yaml_config.yaml // 联邦训练所需配置文件 -├── default_config.yaml // 网络结构、数据集地址、fl_plan所需配置文件 -├── run_cross_silo_fasterrcnn_worker.py // 启动云云联邦worker脚本 -├── run_cross_silo_fasterrcnn_worker_distribute.py // 启动云云联邦分布式worker训练脚本 -└── test_fl_fasterrcnn.py // 客户端使用的训练脚本 -└── run_cross_silo_fasterrcnn_sched.py // 启动云云联邦scheduler脚本 -└── run_cross_silo_fasterrcnn_server.py // 启动云云联邦server脚本 -``` - -1. 注意在`test_fl_fasterrcnn.py`文件中可通过设置参数`dataset_sink_mode`来选择是否记录每个step的loss值: - - ```python - model.train(config.client_epoch_num, dataset, callbacks=cb, dataset_sink_mode=True) # 不设置dataset_sink_mode=True代表只记录每个epoch中最后一个step的loss值。 - model.train(config.client_epoch_num, dataset, callbacks=cb, dataset_sink_mode=False) # 设置dataset_sink_mode=False代表记录每个step的loss值,代码里默认为这种方式。 - ``` - -2. 在配置文件[default_config.yaml](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_faster_rcnn/default_config.yaml)中设置以下参数: - - - 参数`pre_trained` - - 用于设置预训练模型路径(.ckpt 格式)。 - - 本教程中实验的预训练模型是在ImageNet2012上训练的ResNet-50检查点。你可以使用ModelZoo中 [resnet50](https://gitee.com/mindspore/models/tree/master/official/cv/ResNet) 脚本来训练,然后使用src/convert_checkpoint.py把训练好的resnet50的权重文件转换为可加载的权重文件。 - -3. 启动redis - - ```sh - redis-server --port 2345 --save "" - ``` - -4. 启动Scheduler - - `run_sched.py`是用于启动`Scheduler`的Python脚本,并支持通过`argparse`传参修改配置。执行指令如下,代表启动本次联邦学习任务的`Scheduler`,`--yaml_config`用于设置yaml文件路径,其管理ip:port为`127.0.0.1:18019`。 - - ```sh - python run_cross_silo_fasterrcnn_sched.py --yaml_config="default_yaml_config.yaml" --scheduler_manage_address="127.0.0.1:18019" - ``` - - 具体实现详见[run_cross_silo_fasterrcnn_sched.py](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_faster_rcnn/run_cross_silo_fasterrcnn_sched.py)。 - - 打印如下代表启动成功: - - ```sh - [INFO] FEDERATED(3944,2b280497ed00,python):2022-10-10-17:11:08.154.878 [mindspore_federated/fl_arch/ccsrc/scheduler/scheduler.cc:35] Run] Scheduler started successfully. - [INFO] FEDERATED(3944,2b28c5ada700,python):2022-10-10-17:11:08.155.056 [mindspore_federated/fl_arch/ccsrc/common/communicator/http_request_handler.cc:90] Run] Start http server! - ``` - -5. 启动Server - - `run_cross_silo_fasterrcnn_server.py`是用于启动若干`Server`的Python脚本,并支持通过`argparse`传参修改配置。执行指令如下,代表启动本次联邦学习任务的`Server`,其TCP地址为`127.0.0.1`,联邦学习HTTP服务起始端口为`6668`,`Server`数量为`4`个。 - - ```sh - python run_cross_silo_fasterrcnn_server.py --yaml_config="default_yaml_config.yaml" --tcp_server_ip="127.0.0.1" --checkpoint_dir="/path/to/fl_ckpt" --local_server_num=4 --http_server_address="127.0.0.1:6668" - ``` - - 以上指令等价于启动了4个`Server`进程,每个`Server`的联邦学习服务端口分别为`6668`、`6669`、`6670`和`6671`,具体实现详见[run_cross_silo_fasterrcnn_server.py](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_faster_rcnn/run_cross_silo_fasterrcnn_server.py)。其中checkpoint_dir需要输入checkpoint所在的目录路径,server会从该路径下读取checkpoint初始化权重,checkpoint的前缀格式需要是`{fl_name}_recovery_iteration_`。 - - 打印如下代表启动成功: - - ```sh - [INFO] FEDERATED(3944,2b280497ed00,python):2022-10-10-17:11:08.154.645 [mindspore_federated/fl_arch/ccsrc/common/communicator/http_server.cc:122] Start] Start http server! - [INFO] FEDERATED(3944,2b280497ed00,python):2022-10-10-17:11:08.154.725 [mindspore_federated/fl_arch/ccsrc/common/communicator/http_request_handler.cc:85] Initialize] Ev http register handle of: [/d isableFLS, /enableFLS, /state, /queryInstance, /newInstance] success. 
- [INFO] FEDERATED(3944,2b280497ed00,python):2022-10-10-17:11:08.154.878 [mindspore_federated/fl_arch/ccsrc/scheduler/scheduler.cc:35] Run] Scheduler started successfully. - [INFO] FEDERATED(3944,2b28c5ada700,python):2022-10-10-17:11:08.155.056 [mindspore_federated/fl_arch/ccsrc/common/communicator/http_request_handler.cc:90] Run] Start http server! - ``` - -6. 启动Worker - - `run_cross_silo_femnist_worker.py`是用于启动若干`worker`的Python脚本,并支持通过`argparse`传参修改配置。执行指令如下,代表启动本次联邦学习任务的`worker`,联邦学习任务正常进行需要的`worker`数量至少为`2`个: - - ```sh - python run_cross_silo_fasterrcnn_worker.py --local_worker_num=2 --yaml_config="default_yaml_config.yaml" --pre_trained="/path/to/pre_trained" --dataset_path=/path/to/datasets/coco_split/split_100 --http_server_address=127.0.0.1:6668 - ``` - - 具体实现详见[run_cross_silo_femnist_worker.py](https://gitee.com/mindspore/federated/blob/master/example/cross_silo_faster_rcnn/run_cross_silo_fasterrcnn_worker.py)。在数据下沉模式下,云云联邦的同步频率以epoch为单位,否则同步频率以step为单位。 - - 如上指令,`--local_worker_num=2`代表启动两个客户端,且两个客户端使用的数据集分别为`datasets/coco_split/split_100/mindrecord_0`和`datasets/coco_split/split_100/mindrecord_1`,请根据`任务前准备`教程准备好对应客户端所需数据集。 - - 当执行以上三个指令之后,等待一段时间之后,进入当前目录下`worker_0`文件夹,通过指令`grep -rn "\epoch:" *`查看`worker_0`日志,可看到类似如下内容的日志信息: - - ```sh - epoch: 1 step: 1 total_loss: 0.6060338 - ``` - - 则说明云云联邦启动成功,`worker_0`正在训练,其他worker可通过类似方式查看。 - - 当前云云联邦的`worker`节点支持单机多卡&多机多卡的分布式训练方式,`run_cross_silo_fasterrcnn_worker_distributed.py`是为用户启动`worker`节点的分布式训练而提供的Python脚本,并支持通过`argparse`传参修改配置。执行指令如下,代表启动本次联邦学习任务的分布式`worker`,其中`device_num`表示`worker`集群启动的进程数目,`run_distribute`表示启动集群的分布式训练,其http起始端口为`6668`,`worker`进程数量为`4`个: - - ```sh - python run_cross_silo_fasterrcnn_worker_distributed.py --device_num=4 --run_distribute=True --dataset_path=/path/to/datasets/coco_split/split_100 --http_server_address=127.0.0.1:6668 - ``` - - 进入当前目录下`worker_distributed/log_output/`文件夹,通过指令`grep -rn "epoch" *`查看`worker`分布式集群的日志,可看到如下类似打印: - - ```sh - epoch: 1 step: 1 total_loss: 0.613467 - ``` - - 以上脚本中参数配置说明请参考[yaml配置说明](https://www.mindspore.cn/federated/docs/zh-CN/master/horizontal/federated_server_yaml.html)。 - -### 日志查看 - -成功启动任务之后,会在当前目录`cross_silo_faster_rcnn`下生成相应日志文件,日志文件目录结构如下: - -```text -cross_silo_faster_rcnn -├── scheduler -│ └── scheduler.log # 运行scheduler过程中打印日志 -├── server_0 -│ └── server.log # server_0运行过程中打印日志 -├── server_1 -│ └── server.log # server_1运行过程中打印日志 -├── server_2 -│ └── server.log # server_2运行过程中打印日志 -├── server_3 -│ └── server.log # server_3运行过程中打印日志 -├── worker_0 -│ ├── ckpt # 存放worker_0在每个联邦学习迭代结束时获取的聚合后的模型ckpt -│ │ └── mindrecord_0 -│ │ ├── mindrecord_0-fast-rcnn-0epoch.ckpt -│ │ ├── mindrecord_0-fast-rcnn-1epoch.ckpt -│ │ │ -│ │ │ ...... -│ │ │ -│ │ └── mindrecord_0-fast-rcnn-29epoch.ckpt -│ ├──loss_0.log # 记录worker_0训练过程中的每个step的loss值 -│ └── worker.log # 记录worker_0参与联邦学习任务过程中输出日志 -└── worker_1 - ├── ckpt # 存放worker_1在每个联邦学习迭代结束时获取的聚合后的模型ckpt - │ └── mindrecord_1 - │ ├── mindrecord_1-fast-rcnn-0epoch.ckpt - │ ├── mindrecord_1-fast-rcnn-1epoch.ckpt - │ │ - │ │ ...... 
-    │       │
-    │       └── mindrecord_1-fast-rcnn-29epoch.ckpt
-    ├──loss_0.log  # 记录worker_1训练过程中的每个step的loss值
-    └── worker.log  # 记录worker_1参与联邦学习任务过程中输出日志
-```
-
-### 关闭任务
-
-若想中途退出,则可用以下指令:
-
-```sh
-python finish_cross_silo_fasterrcnn.py --redis_port=2345
-```
-
-具体实现详见[finish_cloud.py](https://gitee.com/mindspore/federated/blob/master/tests/st/cross_device_cloud/finish_cloud.py)。
-
-或者等待训练任务结束之后集群会自动退出,不需要手动关闭。
-
-### 实验结果
-
-- 使用数据:
-
-  COCO数据集,拆分为100份,取前两份分别作为两个worker的数据集。
-
-- 客户端本地训练epoch数:1
-
-- 云云联邦学习总迭代数:30
-
-- 实验结果(记录客户端本地训练过程中的loss值):
-
-  进入当前目录下`worker_0`文件夹,通过指令`grep -rn "\]epoch:" *`查看`worker_0`日志,可看到每个step输出的loss值,如下所示:
-
-  ```sh
-  epoch: 1 step: 1 total_loss: 5.249325
-  epoch: 1 step: 2 total_loss: 4.0856013
-  epoch: 1 step: 3 total_loss: 2.6916502
-  epoch: 1 step: 4 total_loss: 1.3917351
-  epoch: 1 step: 5 total_loss: 0.8109232
-  epoch: 1 step: 6 total_loss: 0.99101084
-  epoch: 1 step: 7 total_loss: 1.7741735
-  epoch: 1 step: 8 total_loss: 0.9517553
-  epoch: 1 step: 9 total_loss: 1.7988946
-  epoch: 1 step: 10 total_loss: 1.0213892
-  epoch: 1 step: 11 total_loss: 1.1700443
-  .
-  .
-  .
-  ```
-
-worker_0和worker_1在30个迭代的训练过程中,统计的每个step训练loss变化柱状图如下[1]和[2]:
-
-worker_0和worker_1在30个迭代的训练过程中,统计的每个epoch平均loss(一个epoch中包含的所有step的loss之和除以step数)折线图如下[3]和[4]:
-
-![cross-silo_fastrcnn-2workers-loss.png](images/cross-silo_fastrcnn-2workers-loss.png)
diff --git a/docs/federated/docs/source_zh_cn/pairwise_encryption_training.md b/docs/federated/docs/source_zh_cn/pairwise_encryption_training.md
deleted file mode 100644
index b4f451f0b46db2c2c057b511c269b37206fd82b8..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_zh_cn/pairwise_encryption_training.md
+++ /dev/null
@@ -1,64 +0,0 @@
-# 横向联邦-安全聚合训练
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/pairwise_encryption_training.md)
-
-联邦学习过程中,用户数据仅用于本地设备训练,不需要上传至中心服务器,可以避免用户个人数据的直接泄露。然而传统联邦学习框架中,模型以明文形式上云,仍然存在间接泄露用户隐私的风险。攻击者获取到用户上传的明文模型后,可以通过重构、模型逆向等攻击方式恢复用户的个人训练数据,导致用户隐私泄露。
-
-MindSpore Federated联邦学习框架,提供了基于多方安全计算(MPC)的安全聚合算法,在本地模型上云前加上秘密扰动。在保证模型可用性的前提下,解决横向联邦学习中的隐私泄露和模型窃取问题。
-
-## 原理概述
-
-尽管差分隐私技术可以适当保护用户数据隐私,但是当参与客户端数量比较少或者高斯噪声幅值较大时,模型精度会受较大影响。为了同时满足模型保护和模型收敛这两个要求,我们提供了基于MPC的安全聚合方案。
-
-在这种训练模式下,假设参与的客户端集合为$U$,对于任意客户端Client $u$和$v$,它们会两两协商出一对随机扰动$p_{uv}$、$p_{vu}$,满足
-
-$$
-p_{uv}=\begin{cases}
--p_{vu}, &u{\neq}v\\\\
-0, &u=v
-\end{cases}
-$$
-
-于是每个客户端Client $u$ 在上传模型至服务端Server前,会在原模型权重$x_u$上加上它与其他用户协商的扰动:
-
-$$
-x_{encrypt}=x_u+\sum\limits_{v{\in}U}p_{uv}
-$$
-
-从而服务端Server聚合结果$\overline{x}$为:
-
-$$
-\begin{align}
-\overline{x}&=\sum\limits_{u{\in}U}(x_{u}+\sum\limits_{v{\in}U}p_{uv})\\\\
-&=\sum\limits_{u{\in}U}x_{u}+\sum\limits_{u{\in}U}\sum\limits_{v{\in}U}p_{uv}\\\\
-&=\sum\limits_{u{\in}U}x_{u}
-\end{align}
-$$
-
-上述过程仅介绍了聚合算法的主要思想,基于MPC的聚合方案是精度无损的,代价是通讯轮次的增加。
-
-如果您对算法的具体步骤感兴趣,可以参考原论文[1]。
-
-## 使用方式
-
-### 端云联邦场景
-
-开启安全聚合训练的方式很简单,只需要在启动云侧服务时,通过yaml文件设置`encrypt_train_type`字段为`PW_ENCRYPT`即可。
-
-此外,由于端云联邦场景下,参与训练的Worker大多是手机等不稳定的边缘计算节点,所以要考虑计算节点的掉线和密钥恢复问题。与之相关的参数有`share_secrets_ratio`、`reconstruct_secrets_threshold`和`cipher_time_window`。
-
-`share_secrets_ratio`指代公钥分发轮次、秘密分享轮次、密钥恢复轮次的客户端阈值衰减比例,取值需要小于等于1。
-
-`reconstruct_secrets_threshold`指代恢复秘密需要的碎片数量,取值需要小于参与updateModel的客户端数量(start_fl_job_threshold*update_model_ratio)。
-通常为了保证系统安全,当不考虑Server和Client合谋的情况下,`reconstruct_secrets_threshold`需要大于联邦学习客户端数量的一半;当考虑Server和Client合谋,`reconstruct_secrets_threshold`需要大于联邦学习客户端数量的2/3。 - -`cipher_time_window`指代安全聚合各通讯轮次的时长限制,主要用来保证某些客户端掉线的情况下,Server可以开始新一轮迭代。 - -### 云云联邦场景 - -在云云联邦场景下,在云侧启动脚本通过yaml文件设置`encrypt_train_type`字段为`PW_ENCRYPT`即可。 - -此外,与端云联邦不同的是,在云云联邦场景中,每个Worker都是稳定的服务器,所以不需要考虑掉线问题,因此只需要设置`cipher_time_window`这一超参。 - -## 参考文献 - -[1] Keith Bonawitz, Vladimir Ivanov, Ben Kreuter, et al. [Practical Secure Aggregationfor Privacy-Preserving Machine Learning](https://dl.acm.org/doi/pdf/10.1145/3133956.3133982). Proceedings of the 2017 ACM SIGSAC Conference on Computer and communications Security. 2017. diff --git a/docs/federated/docs/source_zh_cn/private_set_intersection.md b/docs/federated/docs/source_zh_cn/private_set_intersection.md deleted file mode 100644 index d08377ec23d169b97ac0a04f0d7872f040f00eca..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/private_set_intersection.md +++ /dev/null @@ -1,129 +0,0 @@ -# 纵向联邦-隐私集合求交 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/private_set_intersection.md) - -## 隐私保护背景 - -随着数字化转型和数据要素流通的需求提升,以及《数据安全法》、《个人信息保护法》和欧盟《通用数据保护条例》(GDPR)的施行,数据的隐私性(Privacy)越来越成为诸多场景下必要的需求。例如,当数据集合是用户的敏感信息(医疗诊断信息、交易记录、身份识别码、设备唯一标识符 OAID 等),或者是公司的秘密信息时,为了防止信息泄露,在开放状态下使用数据之前必须采用密码学或者脱敏手段来确保数据的机密性(Confidentiality),以达到数据“可用不可见”的目标。考虑两个参与方利用各自数据共同训练一个机器学习模型(例如纵向联邦学习),该任务的第一步就是需要对齐双方的样本集,也就是所谓的实体解析(Entity Resolution)过程。传统的明文求交不可避免地会泄露整个数据库的 OAID,对双方的数据私密性产生破坏,因此需要采用隐私集合求交(Private Set Intersection,PSI)技术来完成该任务。 - -PSI 是安全多方计算(MPC)协议的一种,它接收两方的数据集合作为输入,经过一系列哈希、加密、数据交换等步骤,最终向约定的输出方输出集合的交集,同时保证参与方无法获得交集以外数据的任何信息。在纵向联邦学习任务中使用 PSI 协议,符合 GDPR 提出的数据最小化(Data Minimisation)要求,即除训练过程必须的部分(交集),数据不产生非必要的暴露;从数据控制者的角度来看,业务上不得不适当共享数据,但又只想基于业务共享必须数据,不对外暴露额外数据。值得注意的是,虽然 PSI 可以直接套用已有的 MPC 协议进行计算,但是这样做往往会带来较大的计算和通信开销,不利于业务的开展。本文将介绍一种结合布隆过滤器和椭圆曲线点乘逆元抵消的技术,实现 ECDH-PSI(Elliptic Curve Diffie–Hellman key Exchange-PSI),去更好地支撑云服务和开展隐私保护集合交集计算服务。 - -## 算法流程介绍 - -ECDH-PSI 的核心思想是:一条数据先经过 Alice 加密再经过 Bob 加密,与交换加密顺序的结果相同。那么一方在不泄露自己隐私的情况下,发送用自己私钥加密的数据,另一方基于接收的加密数据再用自己私钥再加密,如果加密结果相同,则说明原始数据相同。 - -求逆的 ECDH-PSI 的核心优化点是:在面对数据量不均衡的双方求交场景时(记 Bob 为数据量少的一方,$a$,$b$ 分别为 Alice 和 Bob 的私钥,双方的原始数据映射到椭圆曲线上分别记为 $P_1$ 和 $P_2$ ,用私钥 $k$ 进行椭圆曲线的点乘加密记为 $P^k$ 或 $kP$,私钥$k$的逆元为$k^{-1}$),尽可能的减少基于数据量多的集合的加密计算。那么 Alice 执行完 $p_1^a$ 发送至 Bob 后,Bob 不再基于此执行加密计算了,而是发送 $p_2^b$ 至Alice,Alice 发送$P_2^{ba}$之后,Bob 通过点乘自己私钥的逆元完成抵消操作,即计算$P_2^{bab^{-1}}$,将其与Alice发送来的 $P_1^a$ 进行对比,如果加密结果相同,那么说明 $P_1=P_2$。求逆 ECDH-PSI 流程图如图所示,红色字样表示收到的对方数据: - -![](./images/inverse_ecdh_psi_flow.png) - -图中 $bf$ 为布隆过滤器(bloom filter, bf)的缩写。若要在一个集合中查询是否存在一个元素,基本方法是遍历一遍集合进行查询,或将集合进行排序,使用二分查找进行查询,但当数据量过大时,排序不支持并行所以十分耗时。若使用布隆过滤器,将集合中的元素通过若干哈希函数映射至一个初始全 0 比特串中的若干位,所有集合中的元素共用一个比特串。查询时,只需将待查询数据也使用相同的这些若干哈希函数处理,直接访问所有对应位是否激活为 1,全为 1 则说明命中,数据存在;反之不存在。其中碰撞的概率可通过控制哈希函数的个数来实现。相较发送整个集合和发送布隆过滤器输出的一个比特串,后者通讯开销更低;在建立布隆过滤器和使用过滤器进行大规模数据查询过程中,也可以通过并行来加速计算。 - -## 快速体验 - -### 前置需要 - -在 Python 环境中完成安装`mindspore-federated`库。 - -### 启动脚本 - -可从 [MindSpore federated ST](https://gitee.com/mindspore/federated/blob/master/tests/st/psi/run_psi.py) 获取PSI双方启动脚本,开启两个进程分别模拟两方,下面给出本机与本机通讯的启动命令: - -```python -python run_psi.py --comm_role="server" --http_server_address="127.0.0.1:8004" --remote_server_address="127.0.0.1:8005" --input_begin=1 --input_end=100 - -python run_psi.py --comm_role="client" --http_server_address="127.0.0.1:8005" 
--remote_server_address="127.0.0.1:8004" --input_begin=50 --input_end=150 -``` - -- `input_begin`与`input_end`搭配使用,生成用于求交的数据集; -- `peer_input_begin`与`peer_input_end`表示对方的数据起止范围,使`--need_check`为`True`,可通过 Python set1.intersection(set2) 求交函数得到真实结果,用于校验 PSI 的正确性; -- `--bucket_size`(可选)表示串行进行多桶求交的 for 循环次数; -- `--thread_num`(可选)表示计算所使用的并行线程数; -- 如需运行明文通讯求交,命令中加入参数`--plain_intersection=True`即可。 - -目前psi支持亿级大数据求交,可以通过设置`input_begin`、`input_end`、`peer_input_begin`、`peer_input_end`参数来指定输入数据集的大小。理论证明机器的内存资源与系统资源足够,psi支持的数据求交数目没有上限。启动命令如下所示: - -```python -python run_psi.py --comm_role="server" --http_server_address="127.0.0.1:8004" --remote_server_address="127.0.0.1:8005" --input_begin=1 --input_end=100000000 - -python run_psi.py --comm_role="client" --http_server_address="127.0.0.1:8005" --remote_server_address="127.0.0.1:8004" --input_begin=1 --input_end=100000000 -``` - -### 输出结果 - -运行脚本前,可通过设置环境变量`export GLOG_v=1`来显示`INFO`级别的日志,也可以观察协议内部各个阶段的运行情况。脚本运行结束后,会打印输出交集结果,因交集数据量可能过大,这里限制输出前20个交集结果。 - -```bash -PSI result: ['50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69'] (display limit: 20) -``` - -## 深度体验 - -### 导入模块 - -运行隐私集合求交,需要依赖联邦库的通讯模块和求交模块,导入方法如下: - -```python -from mindspore_federated.startup.vertical_federated_local import VerticalFederatedCommunicator, ServerConfig -from mindspore_federated._mindspore_federated import RunPSI -from mindspore_federated._mindspore_federated import PlainIntersection -``` - -### 数据准备 - -`RunPSI`和`PlainIntersection`对输入数据的要求都是`List(String)`格式,这里给出了通过文件读取和for循环产生数据集的方法: - -```python -def generate_input_data(input_begin_, input_end_, read_file_, file_name_): - input_data_ = [] - if read_file_: - with open(file_name_, 'r') as f: - for line in f.readlines(): - input_data_.append(line.strip()) - else: - input_data_ = [str(i) for i in range(input_begin_, input_end_)] - return input_data_ -``` - -其中入参`input_begin_`和 `input_end_`限制了for循环的数据范围,`read_file_`和`file_name_`表示是否要读取文件和文件所在路径,文件可以自行构造,每行代表一个数据即可。 - -### 通信创建 - -在调用本接口前,需要初始化纵向联邦通讯实例,操作如下: - -```python -http_server_config = ServerConfig(server_name=comm_role, server_address=http_server_address) -remote_server_config = ServerConfig(server_name=peer_comm_role, server_address=remote_server_address) -vertical_communicator = VerticalFederatedCommunicator(http_server_config=http_server_config, - remote_server_config=remote_server_config) -vertical_communicator.launch() -``` - -- `server_name`根据该进程属于`server`还是`client`来确定,`comm_role`赋值为对应的"server"或"client"即可,`peer_comm_role_`表示对方的角色。 -- `server_address`的格式为"IP:port",`http_server_address`赋值为该进程的`IP`与`port`信息,如"127.0.0.1:8004",`remote_server_address`赋值为对方的`IP`和`port`信息。 - -### 开始求交 - -安全集合求交对外接口为`RunPSI`和`PlainIntersection`,分别为密文和明文求交,入参和返回结果类型、含义均相同,这里仅介绍密文求交`RunPSI`: - -```python -result = RunPSI(input_data, comm_role, peer_comm_role, bucket_id, thread_num) -``` - -- `input_data`: (list[string]);psi一方的输入数据; -- `comm_role`: (string);通讯相关参数,"server" 或 "client"。 -- `peer_comm_role`: (string);通讯相关参数,"server" 或 "client",与 comm_role 不同。 -- `bucket_id`: (int);外部分桶,传入桶的序号;传入负数、小数或其他类型报`TypeError`错误;双进程通讯若双方该值不同,server 报错退出,client 阻塞等待。 -- `thread_num`: (int);线程数目,自然数,0 为默认值,表示使用机器最大可用线程数目减 5,其他值会限定在 1 到机器最大可使用值;传入负数、小数或其他类型报`TypeError`错误。 - -### 输出结果 - -返回结果`result`是`list[string]`格式,表示交集结果,可自行打印输出。这里给出 Python 集合求交的方法: - -```python -def compute_right_result(self_input, peer_input): - self_input_set = set(self_input) - peer_input_set = set(peer_input) - return self_input_set.intersection(peer_input_set) -``` - 
-可将上述方法的结果和`result`进行对比,检测是否一致,可校验该接口的正确性。 diff --git a/docs/federated/docs/source_zh_cn/secure_vertical_federated_learning_with_DP.md b/docs/federated/docs/source_zh_cn/secure_vertical_federated_learning_with_DP.md deleted file mode 100644 index ca40c7da4f4f8e1a603a2ff69251b5db840dbbd1..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/secure_vertical_federated_learning_with_DP.md +++ /dev/null @@ -1,157 +0,0 @@ -# 纵向联邦-基于差分隐私的标签保护 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/secure_vertical_federated_learning_with_DP.md) - -## 背景 - -纵向联邦学习(vFL)是联邦学习(FL)的一大重要分支。当几个参与方拥有同一批用户不同属性的数据时,他们便可使用 vFL 进行协同训练。在 vFL 中,拥有用户特征的参与方(简称follower 方,如下图参与方 A)会持有一个下层网络(Bottom Model),他们将特征输入下层网络,计算得到中间结果(embedding),发送给拥有标签的参与方(简称 leader 方,如下图参与方 B),leader 方使用这些embedding 和自己持有的标签来训练上层网络(上层网络),再将算得的梯度回传给各个参与方来训练下层网络。由此可见,vFL 不需要任何参与方上传自己的原始数据即可协同训练模型。 - -![image.png](./images/vfl_1.png) - -vFL框架避免了原始数据的直接上传,因此在一定程度上保护了隐私安全,然而一个半诚实或者恶意的follower方有可能从leader方回传的梯度反推出leader方的标签信息,造成隐私安全隐患。考虑到在大量vFL场景中,标签是最有价值并且最需要保护的信息,在这样的背景下,我们需要对vFL训练提供更强的隐私保证来避免隐私信息的泄露。 - -差分隐私(Differential Privacy,DP)是一种严格基于统计学/信息论的隐私定义,是目前数据分析领域对于隐私保护的黄金标准。DP核心思想是通过在计算过程中引入随机性,来淹没个体数据对最终计算结果的影响,从而保证计算结果难以反推出个体信息。DP保护能够在极强的威胁模型下保持成立,即使在以下条件下都无法被攻破: - -- 攻击者知道算法的所有细节 -- 攻击者有无限的算力 -- 攻击者关于原始数据有任意多的背景知识 - -关于DP的背景、理论和具体实现,可以参见[1]获取更细致的介绍。 - -本设计方案基于标签差分隐私(label differential privacy,label dp)[2],在纵向联邦学习训练时为 leader 参与方的标签提供差分隐私保证,使攻击者难以从回传的梯度反推出数据的标签信息。在本方案的保护下,即使follower方是半诚实或者恶意的,都能确保在训练过程中leader方的标签信息不会被泄露,缓解参与方对于数据隐私安全的担忧。 - -## 算法实现 - -MindSpore Federated采用了一种轻量级的label dp实现方式:训练时,leader参与方在使用标签数据训练之前,对一定比例的标签进行随机翻转操作。由于随机性的引入,攻击者若想反推标签,最多只能反推出随机翻转/扰动之后的标签,增加了反推出原始标签的难度,满足差分隐私保证。在实际应用时,我们可以调整隐私参数`eps`(可以理解为随机翻转标签的比例)来满足不同的场景需求: - -- 较小`eps`(<1.0)对应高隐私,低精度 -- 较大`eps`(>5.0)对应高精度,低隐私 - -![image.png](./images/label_dp.png) - -本方案具体实际实现时,分为binary标签和onehot标签两种情况,函数中会自动判断输入的是binary还是onehot标签,输出的也是同类的标签。具体算法如下: - -### binary标签保护 - -1. 根据预设的隐私参数eps,计算翻转概率$p = \frac{1}{1 + e^{eps}}$。 -2. 以概率$p$翻转每个标签。 - -### one-hot标签保护 - -1. 对于n个类的标签,计算$p_1 = \frac{e^{eps}}{n - 1 + e^{eps}}$,$p_2 = \frac{1}{n - 1 + e^{eps}}$。 -2. 根据以下概率随机扰乱标签:维持当前标签不变的概率为$p_1$;改成其他n - 1个类里的任意一个的概率都为$p_2$。 - -## 快速体验 - -我们以[Wide&Deep纵向联邦学习案例](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo)中的单进程案例为例,介绍如何在一个纵向联邦模型中加入label dp保护。 - -### 前置需要 - -1. 安装MindSpore1.8.1或其更高版本,请参考[MindSpore官网安装指引](https://www.mindspore.cn/install)。 -2. 安装MindSpore Federated及所依赖Python库 - - ```shell - cd federated - python -m pip install -r requirements_test.txt - ``` - -3. 准备criteo数据集,请参考[Wide&Deep纵向联邦学习案例](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo)。 - -### 启动脚本 - -1. 下载federated仓 - - ```bash - git clone https://gitee.com/mindspore/federated.git - ``` - -2. 进入脚本所在文件夹 - - ```bash - cd federated/example/splitnn_criteo - ``` - -3. 
运行脚本
-
-   ```bash
-   sh run_vfl_train_local_label_dp.sh
-   ```
-
-### 查看结果
-
-在训练日志`log_local_gpu.txt`中查看模型训练的loss变化:
-
-```sh
-INFO:root:epoch 0 step 100/2582 loss: 0.588637
-INFO:root:epoch 0 step 200/2582 loss: 0.561055
-INFO:root:epoch 0 step 300/2582 loss: 0.556246
-INFO:root:epoch 0 step 400/2582 loss: 0.557931
-INFO:root:epoch 0 step 500/2582 loss: 0.553283
-INFO:root:epoch 0 step 600/2582 loss: 0.549618
-INFO:root:epoch 0 step 700/2582 loss: 0.550243
-INFO:root:epoch 0 step 800/2582 loss: 0.549496
-INFO:root:epoch 0 step 900/2582 loss: 0.549224
-INFO:root:epoch 0 step 1000/2582 loss: 0.547547
-INFO:root:epoch 0 step 1100/2582 loss: 0.546989
-INFO:root:epoch 0 step 1200/2582 loss: 0.552165
-INFO:root:epoch 0 step 1300/2582 loss: 0.546926
-INFO:root:epoch 0 step 1400/2582 loss: 0.558071
-INFO:root:epoch 0 step 1500/2582 loss: 0.548258
-INFO:root:epoch 0 step 1600/2582 loss: 0.546442
-INFO:root:epoch 0 step 1700/2582 loss: 0.549062
-INFO:root:epoch 0 step 1800/2582 loss: 0.546558
-INFO:root:epoch 0 step 1900/2582 loss: 0.542755
-INFO:root:epoch 0 step 2000/2582 loss: 0.543118
-INFO:root:epoch 0 step 2100/2582 loss: 0.542587
-INFO:root:epoch 0 step 2200/2582 loss: 0.545770
-INFO:root:epoch 0 step 2300/2582 loss: 0.554520
-INFO:root:epoch 0 step 2400/2582 loss: 0.551129
-INFO:root:epoch 0 step 2500/2582 loss: 0.545622
-...
-```
-
-## 深度体验
-
-我们以[Wide&Deep纵向联邦学习案例](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo)中的单进程案例为例,介绍在纵向联邦模型中加入label dp保护的具体操作方法。
-
-### 前置需要
-
-和[快速体验](#快速体验)相同:安装MindSpore、安装MindSpore Federated、准备数据集。
-
-### 方案一:调用FLModel类中集成的label dp功能
-
-MindSpore Federated采用`FLModel`(参见[纵向联邦学习模型训练接口](https://www.mindspore.cn/federated/docs/zh-CN/master/vertical/vertical_federated_FLModel.html))和yaml文件(参见[纵向联邦学习yaml详细配置项](https://www.mindspore.cn/federated/docs/zh-CN/master/vertical/vertical_federated_yaml.html)),建模纵向联邦学习的训练过程。
-
-我们在`FLModel`类中集成了label dp功能。使用者在正常完成整个纵向联邦学习的训练过程建模后(关于vFL训练的详细介绍可以参见[纵向联邦学习模型训练 - 盘古α大模型跨域训练](https://www.mindspore.cn/federated/docs/zh-CN/master/split_pangu_alpha_application.html)),只需在标签方的yaml文件中,在`privacy`模块下加入`label_dp`子模块(若没有`privacy`模块则需使用者自行添加),并在`label_dp`模块内设定`eps`参数(差分隐私参数$\epsilon$,使用者可以根据实际需求设置此参数的值),即可让模型享受label dp保护:
-
-```yaml
-privacy:
-  label_dp:
-    eps: 1.0
-```
-
-### 方案二:直接调用LabelDP类
-
-使用者也可以直接调用`LabelDP`类,更加灵活地使用label dp功能。`LabelDP`类集成在`mindspore_federated.privacy`模块中,使用者可以先指定`eps`的值定义一个`LabelDP`对象,然后将标签组作为参数传入这个对象,对象的`__call__`函数会自动识别当前传入的是one-hot还是binary标签,输出一个经过label dp处理后的标签组。可参见以下范例:
-
-```python
-# make private a batch of binary labels
-import numpy as np
-import mindspore
-from mindspore import Tensor
-from mindspore_federated.privacy import LabelDP
-
-label_dp = LabelDP(eps=0.0)
-label = Tensor(np.zeros((5, 1)), dtype=mindspore.float32)
-dp_label = label_dp(label)
-print(dp_label)
-
-# make private a batch of one-hot labels
-label = Tensor(np.hstack((np.ones((5, 1)), np.zeros((5, 2)))), dtype=mindspore.float32)
-dp_label = label_dp(label)
-print(dp_label)
-```
-
-## 参考文献
-
-[1] Dwork C, Roth A. The algorithmic foundations of differential privacy[J]. Foundations and Trends® in Theoretical Computer Science, 2014, 9(3–4): 211-407.
-
-[2] Ghazi B, Golowich N, Kumar R, et al. Deep learning with label differential privacy[J]. Advances in Neural Information Processing Systems, 2021, 34: 27131-27145.
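-
-## 附:随机翻转机制的数值示意
-
-为直观理解上文“算法实现”中binary与one-hot两种标签的随机翻转规则,下面给出一个示意性的Python片段。它仅按上文公式演示机制本身,并非`mindspore_federated.privacy.LabelDP`的源码:
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(0)
-
-def binary_label_dp(labels, eps):
-    """binary标签:以概率 p = 1 / (1 + e^eps) 独立翻转每个标签。"""
-    p = 1.0 / (1.0 + np.exp(eps))
-    flip = rng.random(labels.shape) < p
-    return np.where(flip, 1 - labels, labels)
-
-def onehot_label_dp(labels, eps):
-    """one-hot标签:以p1保持原类,否则等概率换成其余n-1个类中的任意一类。"""
-    n = labels.shape[1]
-    p1 = np.exp(eps) / (n - 1 + np.exp(eps))
-    idx = labels.argmax(axis=1)
-    out = idx.copy()
-    for i, c in enumerate(idx):
-        if rng.random() >= p1:
-            out[i] = rng.choice([k for k in range(n) if k != c])
-    return np.eye(n)[out]
-
-print(binary_label_dp(rng.integers(0, 2, size=10).astype(float), eps=1.0))
-print(onehot_label_dp(np.eye(3)[rng.integers(0, 3, size=5)], eps=1.0))
-```
-
-其中每个“其他类”被选中的概率恰为$p_2=\frac{1}{n-1+e^{eps}}$,与上文公式一致。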
diff --git a/docs/federated/docs/source_zh_cn/secure_vertical_federated_learning_with_EmbeddingDP.md b/docs/federated/docs/source_zh_cn/secure_vertical_federated_learning_with_EmbeddingDP.md deleted file mode 100644 index ee13e08eec2a9f8b622032b2e90f86036b999354..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/secure_vertical_federated_learning_with_EmbeddingDP.md +++ /dev/null @@ -1,135 +0,0 @@ -# 纵向联邦-基于信息混淆的特征保护 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/secure_vertical_federated_learning_with_EmbeddingDP.md) - -## 背景介绍 - -纵向联邦学习(vertical Federated Learning, vFL)是一种主流且重要的联合学习范式。在vFL中,n(n≥2)个参与方拥有大量相同用户,但用户特征重叠较小。MindSpore Federated采用拆分学习(Split Learning, SL)技术实现vFL。以下图所示两方拆分学习为例,各参与方并不直接分享原始数据,而是分享经过本地模型提取的中间特征进行训练与推理,满足了原始数据不出本地的隐私要求。 - -然而,有研究表明[1],攻击者(例如参与方2)可以通过中间特征(E)还原出对应的原始数据(feature),造成隐私泄露。针对此类特征重构攻击,本教程提供一种基于信息混淆的轻量级特征保护方案[2]。 - -![image.png](./images/vfl_feature_reconstruction.png) - -## 方案详述 - -保护方案名为EmbeddingDP,总体如下图所示。对生成的中间特征E,依次施加量化(Quantization)和差分隐私(Differential Privacy, DP)等混淆操作,生成P,并将P作为中间特征发送至参与方2。混淆操作大大降低了中间特征与原始输入之间的相关性,加大了攻击难度。 - -![image.png](./images/vfl_feature_reconstruction_defense.png) - -目前,本教程支持单比特量化和基于随机响应的差分隐私保护,方案细节如下图所示。 - -1. **单比特量化(Quantization)**:对于输入向量E,单比特量化会将其中大于0的数置为1,小于等于0的数置为0,生成二值向量B。 - -2. **基于随机响应的差分隐私(DP)**:差分隐私需要配置关键参数`eps`。若未配置`eps`,则不进行差分隐私,直接将二值向量B作为待传的中间特征;若正确配置`eps`(即`eps`为非负实数),`eps`越大,混淆的概率越低,对数据影响越小,同时,隐私保护力度相对较弱。对二值向量B中的任一维度i,若B[i]=1,则以概率p保持数值不变;若B[i]=0,则以概率q翻转B[i],即令B[i]=1。其中,概率p和q依据如下公式计算。其中,e表示自然底数。 - -$$p = \frac{e^{(eps / 2)}}{e^{(eps / 2)} + 1},\quad q = \frac{1}{e^{(eps / 2)} + 1}$$ - -![image.png](./images/vfl_mnist_detail.png) - -## 特性体验 - -本特性可对一维或二维的张量数组进行处理。一维数组仅可由数字0和1组成,二维数组需由独热编码格式的一维向量组成。在[安装MindSpore与Federated](https://mindspore.cn/federated/docs/zh-CN/master/federated_install.html#%E8%8E%B7%E5%8F%96mindspore-federated)后,可应用本特性处理符合要求的张量数组,示例程序如下所示: - -```python -import mindspore as ms -from mindspore import Tensor -from mindspore.common.initializer import Normal -from mindspore_federated.privacy import EmbeddingDP - -ori_tensor = Tensor(shape=(2,3), dtype=ms.float32, init=Normal()) -print(ori_tensor) -dp_tensor = EmbeddingDP(eps=1)(ori_tensor) -print(dp_tensor) -``` - -## 应用案例 - -### 保护盘古α大模型跨域训练 - -#### 准备环节 - -下载federated代码仓,并依据教程[纵向联邦学习模型训练 - 盘古α大模型跨域训练](https://mindspore.cn/federated/docs/zh-CN/master/split_pangu_alpha_application.html#%E5%87%86%E5%A4%87%E7%8E%AF%E8%8A%82),配置运行环境与实验数据集,而后可根据需要运行单进程或多进程示例程序。 - -```bash -git clone https://gitee.com/mindspore/federated.git -``` - -#### 单进程样例 - -1. 进入样例所在目录,并执行[运行单进程样例](https://mindspore.cn/federated/docs/zh-CN/master/split_pangu_alpha_application.html#%E8%BF%90%E8%A1%8C%E5%8D%95%E8%BF%9B%E7%A8%8B%E6%A0%B7%E4%BE%8B)中第2至4步: - - ```bash - cd federated/example/splitnn_pangu_alpha - ``` - -2. 启动配置了EmbeddingDP的训练脚本: - - ```bash - sh run_pangu_train_local_embedding_dp.sh - ``` - -3. 查看训练日志`splitnn_pangu_local.txt`中的训练loss: - - ```text - 2023-02-07 01:34:00 INFO: The embedding is protected by EmbeddingDP with eps 5.000000. 
- 2023-02-07 01:35:40 INFO: epoch 0 step 10/43391 loss: 10.653997 - 2023-02-07 01:36:25 INFO: epoch 0 step 20/43391 loss: 10.570406 - 2023-02-07 01:37:11 INFO: epoch 0 step 30/43391 loss: 10.470503 - 2023-02-07 01:37:58 INFO: epoch 0 step 40/43391 loss: 10.242296 - 2023-02-07 01:38:45 INFO: epoch 0 step 50/43391 loss: 9.970814 - 2023-02-07 01:39:31 INFO: epoch 0 step 60/43391 loss: 9.735226 - 2023-02-07 01:40:16 INFO: epoch 0 step 70/43391 loss: 9.594692 - 2023-02-07 01:41:01 INFO: epoch 0 step 80/43391 loss: 9.340107 - 2023-02-07 01:41:47 INFO: epoch 0 step 90/43391 loss: 9.356388 - 2023-02-07 01:42:34 INFO: epoch 0 step 100/43391 loss: 8.797981 - ... - ``` - -#### 多进程样例 - -1. 进入样例所在目录,安装依赖包,并配置数据集: - - ```bash - cd federated/example/splitnn_pangu_alpha - python -m pip install -r requirements.txt - cp -r {dataset_dir}/wiki ./ - ``` - -2. 在服务器1启动配置了EmbeddingDP的训练脚本: - - ```bash - sh run_pangu_train_leader_embedding_dp.sh {ip1:port1} {ip2:port2} ./wiki/train ./wiki/train - ``` - - `ip1`和`port1`表示参与本地服务器(服务器1)的IP地址和端口号,`ip2`和`port2`表示对端服务器(服务器2)的IP地址和端口号,`./wiki/train`是训练数据集文件路径,`./wiki/test`是评估数据集文件路径。 - -3. 在服务器2启动另一参与方的训练脚本: - - ```bash - sh run_pangu_train_follower.sh {ip2:port2} {ip1:port1} - ``` - -4. 查看训练日志`leader_process.log`中的训练loss: - - ```text - 2023-02-07 01:39:15 INFO: config is: - 2023-02-07 01:39:15 INFO: Namespace(ckpt_name_prefix='pangu', ...) - 2023-02-07 01:39:21 INFO: The embedding is protected by EmbeddingDP with eps 5.000000. - 2023-02-07 01:41:05 INFO: epoch 0 step 10/43391 loss: 10.669225 - 2023-02-07 01:41:38 INFO: epoch 0 step 20/43391 loss: 10.571924 - 2023-02-07 01:42:11 INFO: epoch 0 step 30/43391 loss: 10.440327 - 2023-02-07 01:42:44 INFO: epoch 0 step 40/43391 loss: 10.253876 - 2023-02-07 01:43:16 INFO: epoch 0 step 50/43391 loss: 9.958257 - 2023-02-07 01:43:49 INFO: epoch 0 step 60/43391 loss: 9.704673 - 2023-02-07 01:44:21 INFO: epoch 0 step 70/43391 loss: 9.543740 - 2023-02-07 01:44:54 INFO: epoch 0 step 80/43391 loss: 9.376131 - 2023-02-07 01:45:26 INFO: epoch 0 step 90/43391 loss: 9.376905 - 2023-02-07 01:45:58 INFO: epoch 0 step 100/43391 loss: 8.766671 - ... - ``` - -## 参考文献 - -[1] Erdogan, Ege, Alptekin Kupcu, and A. Ercument Cicek. "Unsplit: Data-oblivious model inversion, model stealing, and label inference attacks against split learning." arXiv preprint arXiv:2108.09033 (2021). - -[2] Anonymous Author(s). "MistNet: Towards Private Neural Network Training with Local Differential Privacy". 
diff --git a/docs/federated/docs/source_zh_cn/secure_vertical_federated_learning_with_TEE.md b/docs/federated/docs/source_zh_cn/secure_vertical_federated_learning_with_TEE.md
deleted file mode 100644
index baa563e667ef5a0bbc7cdf5f869bc5f0560a59fd..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_zh_cn/secure_vertical_federated_learning_with_TEE.md
+++ /dev/null
@@ -1,281 +0,0 @@
-# 纵向联邦-基于可信执行环境的特征保护
-
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/secure_vertical_federated_learning_with_TEE.md)
-
-注:这是一个实验特性,未来有可能被修改或删除。
-
-## 背景
-
-纵向联邦学习(vFL)是联邦学习(FL)的一大重要分支。当不同的参与方拥有来自相同一批用户但属性不同的数据时,他们便可使用vFL进行协同训练。在vFL中,拥有属性的参与方都会持有一个下层网络(Bottom Model),他们分别将属性输入下层网络,得到中间结果(embedding),发送给拥有标签的参与方(称作leader方,如下图中的参与方B;不拥有标签的参与方则称作follower方,如下图中的参与方A)。leader方使用embedding和标签来训练上层网络,再将算得的梯度回传给各个参与方用以训练下层网络。由此可见,vFL不需要任何参与方上传自己的原始数据即可协同训练模型。
-
-![image.png](./images/vfl_1.png)
-
-由于避免了直接上传原始数据,vFL在一定程度上保护了隐私安全(这也是vFL的核心目标之一)。然而,攻击者还是有可能从上传的embedding反推出用户信息,造成隐私安全隐患。在这样的背景下,我们需要对vFL在训练时传输的embedding和梯度提供更强的隐私保证,来规避隐私安全风险。
-
-可信执行环境(Trusted Execution Environment,TEE)是一种基于硬件的可信计算方案,通过使硬件中的整个计算过程相对于外界黑盒化,来保证计算过程的数据安全。在vFL中,我们使用TEE将网络中的关键层屏蔽,可以使该层计算难以被反推,从而保证vFL训练和推理过程的数据安全。
-
-## 算法介绍
-
-![image.png](./images/vfl_with_tee.png)
-
-如图,如果参与方A将中间结果$\alpha^{(A)}$直接发给参与方B,则参与方B很有可能用中间结果反推出参与方A的原始数据$X^{(A)}$。为了降低这样的风险,参与方A将Bottom Model计算得到的中间结果$\alpha^{(A)}$先进行加密得到$E(\alpha^{(A)})$,将$E(\alpha^{(A)})$传给参与方B,参与方B将$E(\alpha^{(A)})$输入到TEE中的Cut Layer层中,然后在TEE的内部解密出$\alpha^{(A)}$进行前向传播。上述的整个过程,对于B来说都是黑盒的。
-
-反向传播梯度时也类似,Cut Layer运算出梯度$\nabla\alpha^{(A)}$,加密成$E(\nabla\alpha^{(A)})$后再由参与方B传回给参与方A,然后参与方A解密出$\nabla\alpha^{(A)}$后继续做反向传播。
-
-## 快速体验
-
-我们以[Wide&Deep纵向联邦学习案例](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo)中的单进程案例为例,给出一个配置TEE保护的范例脚本。
-
-### 前置需要&环境配置
-
-1. 环境要求:
-
-    - 处理器:需要支持Intel SGX(Intel Software Guard Extensions)功能
-    - 操作系统:openEuler 20.03、openEuler 21.03 LTS SP2或更高版本
-
-2. 安装SGX和SecGear(可以参考[secGear官网](https://gitee.com/openeuler/secGear)):
-
-   ```sh
-   sudo yum install -y cmake ocaml-dune linux-sgx-driver sgxsdk libsgx-launch libsgx-urts sgxssl
-   git clone https://gitee.com/openeuler/secGear.git
-   cd secGear
-   source /opt/intel/sgxsdk/environment && source environment
-   mkdir debug && cd debug && cmake .. && make && sudo make install
-   ```
-
-3. 安装MindSpore 1.8.1或更高版本,请参考[MindSpore官网安装指引](https://www.mindspore.cn/install)。
-
-4. 下载federated仓:
-
-   ```sh
-   git clone https://gitee.com/mindspore/federated.git
-   ```
-
-5. 下载TEE所依赖的4个库文件:[libsgx_0.so](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/tutorials-develop/federated/libsgx_0.so)、[libsecgear.so](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/tutorials-develop/federated/libsecgear.so)、[enclave.signed.so](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/tutorials-develop/federated/enclave.signed.so)和[libcsecure_channel_static.a](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/tutorials-develop/federated/libcsecure_channel_static.a),并将这4个文件放至`mindspore_federated/fl_arch/ccsrc/armour/lib`路径下(需新建文件夹)。
-
-6. 安装MindSpore Federated所依赖的Python库,请参考[Wide&Deep纵向联邦学习案例](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo)。
-
-7. 
为TEE编译安装MindSpore Federated(需要加入额外编译选项,表示是否使用SGX): - - ```sh - sh federated/build.sh -s on - pip install federated/build/packages/mindspore_federated-XXXXX.whl - ``` - -8. 准备criteo数据集,请参考[Wide&Deep纵向联邦学习案例](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo)。 - -### 启动脚本 - -1. 进入脚本所在文件夹 - - ```sh - cd federated/example/splitnn_criteo - ``` - -2. 运行脚本 - - ```sh - sh run_vfl_train_local_tee.sh - ``` - -### 查看结果 - -在训练日志`log_local_cpu_tee.txt`查看模型训练的loss变化: - -```sh -INFO:root:epoch 0 step 100/41322 wide_loss: 0.661822 deep_loss: 0.662018 -INFO:root:epoch 0 step 100/41322 wide_loss: 0.685003 deep_loss: 0.685198 -INFO:root:epoch 0 step 200/41322 wide_loss: 0.649380 deep_loss: 0.649381 -INFO:root:epoch 0 step 300/41322 wide_loss: 0.612189 deep_loss: 0.612189 -INFO:root:epoch 0 step 400/41322 wide_loss: 0.630079 deep_loss: 0.630079 -INFO:root:epoch 0 step 500/41322 wide_loss: 0.602897 deep_loss: 0.602897 -INFO:root:epoch 0 step 600/41322 wide_loss: 0.621647 deep_loss: 0.621647 -INFO:root:epoch 0 step 700/41322 wide_loss: 0.624762 deep_loss: 0.624762 -INFO:root:epoch 0 step 800/41322 wide_loss: 0.622042 deep_loss: 0.622042 -INFO:root:epoch 0 step 900/41322 wide_loss: 0.585274 deep_loss: 0.585274 -INFO:root:epoch 0 step 1000/41322 wide_loss: 0.590947 deep_loss: 0.590947 -INFO:root:epoch 0 step 1100/41322 wide_loss: 0.586775 deep_loss: 0.586775 -INFO:root:epoch 0 step 1200/41322 wide_loss: 0.597362 deep_loss: 0.597362 -INFO:root:epoch 0 step 1300/41322 wide_loss: 0.607390 deep_loss: 0.607390 -INFO:root:epoch 0 step 1400/41322 wide_loss: 0.584204 deep_loss: 0.584204 -INFO:root:epoch 0 step 1500/41322 wide_loss: 0.583618 deep_loss: 0.583618 -INFO:root:epoch 0 step 1600/41322 wide_loss: 0.573294 deep_loss: 0.573294 -INFO:root:epoch 0 step 1700/41322 wide_loss: 0.600686 deep_loss: 0.600686 -INFO:root:epoch 0 step 1800/41322 wide_loss: 0.585533 deep_loss: 0.585533 -INFO:root:epoch 0 step 1900/41322 wide_loss: 0.583466 deep_loss: 0.583466 -INFO:root:epoch 0 step 2000/41322 wide_loss: 0.560188 deep_loss: 0.560188 -INFO:root:epoch 0 step 2100/41322 wide_loss: 0.569232 deep_loss: 0.569232 -INFO:root:epoch 0 step 2200/41322 wide_loss: 0.591643 deep_loss: 0.591643 -INFO:root:epoch 0 step 2300/41322 wide_loss: 0.572473 deep_loss: 0.572473 -INFO:root:epoch 0 step 2400/41322 wide_loss: 0.582825 deep_loss: 0.582825 -INFO:root:epoch 0 step 2500/41322 wide_loss: 0.567196 deep_loss: 0.567196 -INFO:root:epoch 0 step 2600/41322 wide_loss: 0.602022 deep_loss: 0.602022 -``` - -## 深度体验 - -TEE层的正向传播、反向传播都需要调用它自己的函数而非通过MindSpore,因此在实现时和通常的vFL模型存在不同。 - -通常,vFL模型在训练的反向传播时,Top Model和Cut Layer是放在一起,由参与方B通过MindSpore一步求导、一步更新的;而含有TEE的网络在反向传播时,Top Model由参与方B基于MindSpore更新,而Cut Layer(TEE)是在接收到Top Model传回的梯度后,在它自己内部进行更新的,再将需要传回参与方A的梯度加密后传出给参与方B,整个过程都在TEE内部完成。 - -目前在MindSpore Federated中,上述功能是通过在`mindspore_federated.vfl_model.FLModel()`定义时传入`grad_network`来实现自定义的反向传播流程的。因此,要实现含有TEE的网络,用户可以在`grad_network`中定义好Top Model和Cut Layer的反向传播流程并传入`FLModel`即可,在反向传播时`FLModel`就会走用户自定义的训练流程。 - -我们以[Wide&Deep纵向联邦学习案例](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo)中的单进程案例为例,介绍在纵向联邦模型中配置TEE保护的具体操作方法。介绍的内容主要针对使用TEE时配置上和通常情况下的不同点,相同点则会略过(关于vFL训练的详细介绍可以参见[纵向联邦学习模型训练 - 盘古α大模型跨域训练](https://www.mindspore.cn/federated/docs/zh-CN/master/split_pangu_alpha_application.html))。 - -### 前置需要&环境配置 - -参照[快速体验](#快速体验)。 - -### 定义网络模型 - -#### 正向传播 - 
-和通常的vFL训练相同,使用者需要基于MindSpore提供的`nn.Cell`(参见[mindspore.nn.Cell](https://mindspore.cn/docs/zh-CN/master/api_python/nn/mindspore.nn.Cell.html#mindspore-nn-cell))来开发训练网络。不同点在于,在TEE所在的这一层,使用者需要在该类的`construct`函数中调用TEE前向传播的函数:
-
-```python
-from mindspore import nn, ops, Tensor
-from mindspore_federated._mindspore_federated import init_tee_cut_layer, backward_tee_cut_layer, \
-    encrypt_client_data, secure_forward_tee_cut_layer
-
-class TeeLayer(nn.Cell):
-    """
-    TEE layer of the leader net.
-    Args:
-        config (class): default config info.
-    """
-    def __init__(self, config):
-        super(TeeLayer, self).__init__()
-        init_tee_cut_layer(config.batch_size, 2, 2, 1, 3.5e-4, 1024.0)
-        self.concat = ops.Concat(axis=1)
-        self.reshape = ops.Reshape()
-
-    def construct(self, wide_out0, deep_out0, wide_embedding, deep_embedding):
-        """Convert and encrypt the intermediate data"""
-        local_emb = self.concat((wide_out0, deep_out0))
-        remote_emb = self.concat((wide_embedding, deep_embedding))
-        aa = remote_emb.flatten().asnumpy().tolist()
-        bb = local_emb.flatten().asnumpy().tolist()
-        enc_aa, enc_aa_len = encrypt_client_data(aa, len(aa))
-        enc_bb, enc_bb_len = encrypt_client_data(bb, len(bb))
-        tee_output = secure_forward_tee_cut_layer(remote_emb.shape[0], remote_emb.shape[1],
-                                                  local_emb.shape[1], enc_aa, enc_aa_len, enc_bb, enc_bb_len, 2)
-        tee_output = self.reshape(Tensor(tee_output), (remote_emb.shape[0], 2))
-        return tee_output
-```
-
-#### 反向传播
-
-在通常的vFL模型中,反向传播是由`FLModel`类自动配置实现的;但在含有TEE的模型中,使用者需开发一个`grad_network`来定义反向传播流程。`grad_network`也基于`nn.Cell`,包括一个`__init__`函数和一个`construct`函数。初始化时,需要传入训练使用的网络和配置,并且在`__init__`函数中定义:求导算子、Cut Layer之外网络的参数、loss函数、Cut Layer之外网络的优化器,示例如下:
-
-```python
-class LeaderGradNet(nn.Cell):
-    """
-    grad_network of the leader party.
-    Args:
-        net (class): LeaderNet, which is the net of leader party.
-        config (class): default config info.
-    """
-
-    def __init__(self, net: LeaderNet, config):
-        super().__init__()
-        self.net = net
-        self.sens = 1024.0
-        # construct中会用到的拼接与变形算子
-        self.concat = ops.Concat(axis=1)
-        self.reshape = ops.Reshape()
-
-        self.grad_op_param_sens = ops.GradOperation(get_by_list=True, sens_param=True)
-        self.grad_op_input_sens = ops.GradOperation(get_all=True, sens_param=True)
-
-        self.params_head = ParameterTuple(net.head_layer.trainable_params())
-        self.params_bottom_deep = vfl_utils.get_params_by_name(self.net.bottom_net, ['deep', 'dense'])
-        self.params_bottom_wide = vfl_utils.get_params_by_name(self.net.bottom_net, ['wide'])
-
-        self.loss_net = HeadLossNet(net.head_layer)
-        self.loss_net_l2 = L2LossNet(net.bottom_net, config)
-
-        self.optimizer_head = Adam(self.params_head, learning_rate=3.5e-4, eps=1e-8, loss_scale=self.sens)
-        self.optimizer_bottom_deep = Adam(self.params_bottom_deep, learning_rate=3.5e-4, eps=1e-8, loss_scale=self.sens)
-        self.optimizer_bottom_wide = FTRL(self.params_bottom_wide, learning_rate=5e-2, l1=1e-8, l2=1e-8,
-                                          initial_accum=1.0, loss_scale=self.sens)
-```
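-
-上面`__init__`中定义的两个求导算子,分别用于对参数求导(结果送入优化器)和对输入求导(结果可作为前级网络的`sens`回传)。下面给出一个与具体网络无关的最小示例,示意二者配合`sens`(梯度加权系数)的用法(仅为辅助理解的示例代码,并非本案例源码):
-
-```python
-import numpy as np
-import mindspore as ms
-from mindspore import nn, ops, ParameterTuple
-
-ms.set_context(mode=ms.PYNATIVE_MODE)
-net = nn.Dense(3, 2)
-params = ParameterTuple(net.trainable_params())
-x = ms.Tensor(np.ones((1, 3), np.float32))
-sens = ms.Tensor(np.ones((1, 2), np.float32))  # 梯度加权系数
-
-grad_op_param_sens = ops.GradOperation(get_by_list=True, sens_param=True)
-grad_op_input_sens = ops.GradOperation(get_all=True, sens_param=True)
-
-param_grads = grad_op_param_sens(net, params)(x, sens)  # 关于可训练参数的梯度
-input_grads = grad_op_input_sens(net)(x, sens)          # 关于输入的梯度
-```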
- """ - # data processing - id_hldr = local_data_batch['id_hldr'] - wt_hldr = local_data_batch['wt_hldr'] - label = local_data_batch['label'] - wide_embedding = remote_data_batch['wide_embedding'] - deep_embedding = remote_data_batch['deep_embedding'] - - # forward - wide_out0, deep_out0 = self.net.bottom_net(id_hldr, wt_hldr) - local_emb = self.concat((wide_out0, deep_out0)) - remote_emb = self.concat((wide_embedding, deep_embedding)) - head_input = self.net.cut_layer(wide_out0, deep_out0, wide_embedding, deep_embedding) - loss = self.loss_net(head_input, label) - - # update of head net - sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), 1024.0) - grad_head_input, _ = self.grad_op_input_sens(self.loss_net)(head_input, label, sens) - grad_head_param = self.grad_op_param_sens(self.loss_net, self.params_head)(head_input, label, sens) - self.optimizer_head(grad_head_param) - - # update of cut layer - - tmp = grad_head_input.flatten().asnumpy().tolist() - grad_input = backward_tee_cut_layer(remote_emb.shape[0], remote_emb.shape[1], local_emb.shape[1], 1, tmp) - grad_inputa = self.reshape(Tensor(grad_input[0]), remote_emb.shape) - grad_inputb = self.reshape(Tensor(grad_input[1]), local_emb.shape) - grad_cutlayer_input = (grad_inputb[:, :1], grad_inputb[:, 1:2], grad_inputa[:, :1], grad_inputa[:, 1:2]) - - # update of bottom net - grad_bottom_wide = self.grad_op_param_sens(self.net.bottom_net, - self.params_bottom_wide)(id_hldr, wt_hldr, - grad_cutlayer_input[0:2]) - self.optimizer_bottom_wide(grad_bottom_wide) - grad_bottom_deep = self.grad_op_param_sens(self.net.bottom_net, - self.params_bottom_deep)(id_hldr, wt_hldr, - grad_cutlayer_input[0:2]) - grad_bottom_l2 = self.grad_op_param_sens(self.loss_net_l2, self.params_bottom_deep)(sens) - zipped = zip(grad_bottom_deep, grad_bottom_l2) - grad_bottom_deep = tuple(map(sum, zipped)) - self.optimizer_bottom_deep(grad_bottom_deep) - - # output the gradients for follower party - scales = {} - scales['wide_loss'] = OrderedDict(zip(['wide_embedding', 'deep_embedding'], grad_cutlayer_input[2:4])) - scales['deep_loss'] = scales['wide_loss'] - return scales -``` - -#### 定义优化器 - -定义优化器时,在yaml文件中就不需定义`grad_network`已涉及到的反向传播部分了,除此之外和通常的vfl模型定义优化器就没有区别了。 - -### 构建训练脚本 - -#### 构建网络 - -与通常的vFL训练相同,用户需要使用MindSpore Federated提供的类,将自己构造好的网络封装成纵向联邦网络。详细的API文档可以参考[纵向联邦训练接口](https://gitee.com/mindspore/federated/blob/master/docs/api/api_python/vertical/vertical_federated_FLModel.rst)。不同点则在于:构建leader方网络时,需要加上`grad_network`: - -```python -from mindspore_federated import FLModel, FLYamlData -from network_config import config -from wide_and_deep import LeaderNet, LeaderLossNet, LeaderGradNet - - -leader_base_net = LeaderNet(config) -leader_train_net = LeaderLossNet(leader_base_net, config) -leader_grad_net = LeaderGradNet(leader_base_net, config) - -leader_yaml_data = FLYamlData(config.leader_yaml_path) -leader_fl_model = FLModel(yaml_data=leader_yaml_data, - network=leader_base_net, - grad_network=Leader_grad_net, - train_network=leader_train_net) -``` - -除了上述提到的内容之外,TEE训练的其他的部分都和通常的vFL训练完全一致,使用者在配置完成后便可让模型享受TEE的安全保证。 diff --git a/docs/federated/docs/source_zh_cn/sentiment_classification_application.md b/docs/federated/docs/source_zh_cn/sentiment_classification_application.md deleted file mode 100644 index 104e6f5e3bdb85b2656df72950985328549fac52..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/sentiment_classification_application.md +++ /dev/null @@ -1,563 +0,0 @@ -# 实现一个端云情感分类应用(Android) - 
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/sentiment_classification_application.md) - -通过端云协同的联邦学习建模方式,可以充分发挥端侧数据的优势,避免用户敏感数据直接上传云侧。由于用户在使用输入法时,十分重视所输入文字的隐私,且输入法的智慧功能对提升用户体验非常需要。因此,联邦学习天然适用于输入法应用场景。 - -MindSpore Federated将联邦语言模型应用到了输入法的表情图片预测功能中。联邦语言模型会根据聊天文本数据推荐出适合当前语境的表情图片。在使用联邦学习建模时,每一张表情图片会被定义为一个情感标签类别,而每个聊天短语会对应一个表情图片。MindSpore Federated将表情图片预测任务定义为联邦情感分类任务。 - -## 准备环节 - -### 环境 - -参考[服务端环境配置](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_server.html)和[客户端环境配置](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_client.html)。 - -### 数据 - -[用于训练的数据](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/supervise/client.tar.gz)包含20个用户聊天文件,其目录结构如下: - -```text -datasets/supervise/client/ - ├── 0.txt # 用户0的训练数据 - ├── 1.txt # 用户1的训练数据 - │ - │ ...... - │ - └── 19.txt # 用户19的训练数据 -``` - -[用于验证的数据](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/supervise/eval.tar.gz)包含1个聊天文件,其目录结构如下: - -```text -datasets/supervise/eval/ - └── eval.txt # 验证数据 -``` - -用于训练和验证的数据中,标签包含4类表情:`good`、`leimu`、`xiaoku`、`xin`。 - -### 模型相关文件 - -模型相关的[词典](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vocab.txt)和[词典ID映射文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vocab_map_ids.txt)的目录结构如下: - -```text -datasets/ - ├── vocab.txt # 词典 - └── vocab_map_ids.txt # 词典ID映射文件 -``` - -## 定义网络 - -联邦学习中的语言模型使用ALBERT模型[1]。客户端上的ALBERT模型包括:embedding层、encoder层和classifier层。 - -具体网络定义请参考[源码](https://gitee.com/mindspore/federated/blob/master/tests/st/network/albert.py)。 - -### 生成端侧模型文件 - -用户可以根据如下指导生成端侧模型文件,或下载已经生成好的[ALBERT端侧模型文件](https://gitee.com/link?target=https%3A%2F%2Fmindspore-website.obs.cn-north-4.myhuaweicloud.com%2Fnotebook%2Fmodels%2Falbert_supervise.mindir.ms)。 - -#### 将模型导出为MindIR格式文件 - -示例代码如下: - -```python -import argparse -import os -import random -from time import time -import numpy as np -import mindspore as ms -from mindspore.nn import AdamWeightDecay -from src.config import train_cfg, client_net_cfg -from src.utils import restore_params -from src.model import AlbertModelCLS -from src.cell_wrapper import NetworkWithCLSLoss, NetworkTrainCell - - -def parse_args(): - """ - parse args - """ - parser = argparse.ArgumentParser(description='export task') - parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU']) - parser.add_argument('--device_id', type=str, default='0') - parser.add_argument('--init_model_path', type=str, default='none') - parser.add_argument('--output_dir', type=str, default='./models/mindir/') - parser.add_argument('--seed', type=int, default=0) - return parser.parse_args() - - -def supervise_export(args_opt): - ms.set_seed(args_opt.seed), random.seed(args_opt.seed) - start = time() - # 参数配置 - os.environ['CUDA_VISIBLE_DEVICES'] = args_opt.device_id - init_model_path = args_opt.init_model_path - output_dir = args_opt.output_dir - if not os.path.exists(output_dir): - os.makedirs(output_dir) - print('Parameters setting is done! Time cost: {}'.format(time() - start)) - start = time() - - # MindSpore配置 - ms.set_context(mode=ms.GRAPH_MODE, device_target=args_opt.device_target) - print('Context setting is done! 
Time cost: {}'.format(time() - start)) - start = time() - - # 建立模型 - albert_model_cls = AlbertModelCLS(client_net_cfg) - network_with_cls_loss = NetworkWithCLSLoss(albert_model_cls) - network_with_cls_loss.set_train(True) - print('Model construction is done! Time cost: {}'.format(time() - start)) - start = time() - - # 建立优化器 - client_params = [_ for _ in network_with_cls_loss.trainable_params()] - client_decay_params = list( - filter(train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter, client_params) - ) - client_other_params = list( - filter(lambda x: not train_cfg.optimizer_cfg.AdamWeightDecay.decay_filter(x), client_params) - ) - client_group_params = [ - {'params': client_decay_params, 'weight_decay': train_cfg.optimizer_cfg.AdamWeightDecay.weight_decay}, - {'params': client_other_params, 'weight_decay': 0.0}, - {'order_params': client_params} - ] - client_optimizer = AdamWeightDecay(client_group_params, - learning_rate=train_cfg.client_cfg.learning_rate, - eps=train_cfg.optimizer_cfg.AdamWeightDecay.eps) - client_network_train_cell = NetworkTrainCell(network_with_cls_loss, optimizer=client_optimizer) - print('Optimizer construction is done! Time cost: {}'.format(time() - start)) - start = time() - - # 构造数据 - input_ids = ms.Tensor(np.zeros((train_cfg.batch_size, client_net_cfg.seq_length), np.int32)) - attention_mask = ms.Tensor(np.zeros((train_cfg.batch_size, client_net_cfg.seq_length), np.int32)) - token_type_ids = ms.Tensor(np.zeros((train_cfg.batch_size, client_net_cfg.seq_length), np.int32)) - label_ids = ms.Tensor(np.zeros((train_cfg.batch_size,), np.int32)) - print('Client data loading is done! Time cost: {}'.format(time() - start)) - start = time() - - # 读取checkpoint - if init_model_path != 'none': - init_param_dict = ms.load_checkpoint(init_model_path) - restore_params(client_network_train_cell, init_param_dict) - print('Checkpoint loading is done! Time cost: {}'.format(time() - start)) - start = time() - - # 导出 - ms.export(client_network_train_cell, input_ids, attention_mask, token_type_ids, label_ids, - file_name=os.path.join(output_dir, 'albert_supervise'), file_format='MINDIR') - print('Supervise model export process is done! Time cost: {}'.format(time() - start)) - - -if __name__ == '__main__': - total_time_start = time() - args = parse_args() - supervise_export(args) - print('All is done! Time cost: {}'.format(time() - total_time_start)) - -``` - -#### 将MindIR文件转化为联邦学习端侧框架可用的ms文件 - -参考[图像分类应用](https://www.mindspore.cn/federated/docs/zh-CN/master/image_classification_application.html)中生成端侧模型文件部分。 - -## 启动联邦学习流程 - -首先在服务端启动脚本,参考[横向云端部署](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_server.html)。 -对应云侧配置和模型权重文件参考[albert example](https://gitee.com/mindspore/federated/tree/master/example/cross_device_albert) - -以ALBERT模型的训练与推理任务为基础,整体流程为: - -1. Android新建工程; - -2. 编译MindSpore Lite AAR包; - -3. Android实例程序结构说明; - -4. 编写代码; - -5. Android工程配置依赖项; - -6. Android构建与运行。 - -### Android新建工程 - -在Android Studio中新建项目工程,并安装相应的SDK(指定SDK版本后,由Android Studio自动安装)。 - -![新建工程](./images/create_android_project.png) - -### 相关包获取 - -1. 获取MindSpore Lite AAR包 - - 参考[MindSpore Lite](https://www.mindspore.cn/lite/docs/zh-CN/master/use/downloads.html)。 - - ```text - mindspore-lite-full-{version}.aar - ``` - -2. 获取MindSpore Federated 端侧jar包。 - - 参考[横向端侧部署](https://www.mindspore.cn/federated/docs/zh-CN/master/deploy_federated_client.html)。 - - ```text - mindspore_federated/device_client/build/libs/jarAAR/mindspore-lite-java-flclient.jar - ``` - -3. 
将AAR包放置安卓工程的app/libs/目录下。 - -### Android实例程序结构说明 - -```text -app -│ ├── libs # Android库项目的二进制归档文件 -| | ├── mindspore-lite-full-{version}.aar # MindSpore Lite针对Android版本的归档文件 -| | └── mindspore-lite-java-flclient.jar # MindSpore Federated针对Android版本的归档文件 -├── src/main -│ ├── assets # 资源目录 -| | └── model # 模型目录 -| | └── albert_supervise.mindir.ms # 存放的预训练模型文件 -│ | └── albert_inference.mindir.ms # 存放的推理模型文件 -│ | └── data # 数据目录 -| | └── 0.txt # 训练数据文件 -| | └── vocab.txt # 词典文件 -| | └── vocab_map_ids.txt # 词典ID映射文件 -| | └── eval.txt # 训练结果评估文件 -| | └── eval_no_label.txt # 推理数据文件 -│ | -│ ├── java # java层应用代码 -│ │ └── ... 存放Android代码文件,相关目录可以自定义 -│ │ -│ ├── res # 存放Android相关的资源文件 -│ └── AndroidManifest.xml # Android配置文件 -│ -│ -├── build.gradle # Android工程构建配置文件 -├── download.gradle # 工程依赖文件下载 -└── ... -``` - -### 编写代码 - -1. AssetCopyer.java:该代码文件作用是把Android工程的app/src/main/assets目录下的资源文件存放到Android系统的磁盘中,以便在模型训练与推理时联邦学习框架的接口能够根据绝对路径读取到资源文件。 - - ```java - import android.content.Context; - import java.io.File; - import java.io.FileOutputStream; - import java.io.InputStream; - import java.util.logging.Logger; - public class AssetCopyer { - private static final Logger LOGGER = Logger.getLogger(AssetCopyer.class.toString()); - public static void copyAllAssets(Context context,String destination) { - LOGGER.info("destination: " + destination); - copyAssetsToDst(context,"",destination); - } - // copy assets目录下面的资源文件到Android系统的磁盘中,具体的路径可打印destination查看 - private static void copyAssetsToDst(Context context,String srcPath, String dstPath) { - try { - // 递归获取assets目录的所有的文件名 - String[] fileNames =context.getAssets().list(srcPath); - if (fileNames.length > 0) { - // 构建目标file对象 - File file = new File(dstPath); - //创建目标目录 - file.mkdirs(); - for (String fileName : fileNames) { - // copy文件到指定的磁盘 - if(!srcPath.equals("")) { - copyAssetsToDst(context,srcPath + "/" + fileName,dstPath+"/"+fileName); - }else{ - copyAssetsToDst(context, fileName,dstPath+"/"+fileName); - } - } - } else { - // 构建源文件的输入流 - InputStream is = context.getAssets().open(srcPath); - // 构建目标文件的输出流 - FileOutputStream fos = new FileOutputStream(new File(dstPath)); - // 定义1024大小的缓冲数组 - byte[] buffer = new byte[1024]; - int byteCount=0; - // 源文件写到目标文件 - while((byteCount=is.read(buffer))!=-1) { - fos.write(buffer, 0, byteCount); - } - // 刷新输出流 - fos.flush(); - // 关闭输入流 - is.close(); - // 关闭输出流 - fos.close(); - } - } catch (Exception e) { - e.printStackTrace(); - } - } - } - ``` - -2. 
FlJob.java:该代码文件作用是定义训练与推理任务的内容,具体的联邦学习接口含义请参考[联邦学习接口介绍](https://www.mindspore.cn/federated/docs/zh-CN/master/interface_description_federated_client.html)。
-
-   ```java
-   import android.annotation.SuppressLint;
-   import android.os.Build;
-   import androidx.annotation.RequiresApi;
-   import com.mindspore.flAndroid.utils.AssetCopyer;
-   import com.mindspore.flclient.FLParameter;
-   import com.mindspore.flclient.SyncFLJob;
-   import java.util.ArrayList;
-   import java.util.Arrays;
-   import java.util.HashMap;
-   import java.util.List;
-   import java.util.Map;
-   import java.util.UUID;
-   import java.util.logging.Logger;
-   // 注:RunType与BindMode的import包路径以实际使用的联邦学习SDK为准
-   public class FlJob {
-       private static final Logger LOGGER = Logger.getLogger(AssetCopyer.class.toString());
-       private final String parentPath;
-       public FlJob(String parentPath) {
-           this.parentPath = parentPath;
-       }
-       // Android的联邦学习训练任务
-       @SuppressLint("NewApi")
-       @RequiresApi(api = Build.VERSION_CODES.M)
-       public void syncJobTrain() {
-           // 构造dataMap
-           String trainTxtPath = "data/albert/supervise/client/1.txt";
-           String evalTxtPath = "data/albert/supervise/eval/eval.txt"; // 非必须,getModel之后不进行验证可不设置
-           String vocabFile = "data/albert/supervise/vocab.txt"; // 数据预处理的词典文件路径
-           String idsFile = "data/albert/supervise/vocab_map_ids.txt"; // 词典的映射id文件路径
-           Map<RunType, List<String>> dataMap = new HashMap<>();
-           List<String> trainPath = new ArrayList<>();
-           trainPath.add(trainTxtPath);
-           trainPath.add(vocabFile);
-           trainPath.add(idsFile);
-           List<String> evalPath = new ArrayList<>(); // 非必须,getModel之后不进行验证可不设置
-           evalPath.add(evalTxtPath); // 非必须,getModel之后不进行验证可不设置
-           evalPath.add(vocabFile); // 非必须,getModel之后不进行验证可不设置
-           evalPath.add(idsFile); // 非必须,getModel之后不进行验证可不设置
-           dataMap.put(RunType.TRAINMODE, trainPath);
-           dataMap.put(RunType.EVALMODE, evalPath); // 非必须,getModel之后不进行验证可不设置
-
-           String flName = "com.mindspore.flclient.demo.albert.AlbertClient"; // AlBertClient.java 包路径
-           String trainModelPath = "ms/albert/train/albert_ad_train.mindir0.ms"; // 绝对路径
-           String inferModelPath = "ms/albert/train/albert_ad_train.mindir0.ms"; // 绝对路径, 和trainModelPath保持一致
-           String sslProtocol = "TLSv1.2";
-           String deployEnv = "android";
-
-           // 端云通信url,请保证Android能够访问到server,否则会出现connection failed
-           String domainName = "http://10.*.*.*:6668";
-           boolean ifUseElb = true;
-           int serverNum = 4;
-           int threadNum = 4;
-           BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
-           int batchSize = 32;
-
-           FLParameter flParameter = FLParameter.getInstance();
-           flParameter.setFlName(flName);
-           flParameter.setDataMap(dataMap);
-           flParameter.setTrainModelPath(trainModelPath);
-           flParameter.setInferModelPath(inferModelPath);
-           flParameter.setSslProtocol(sslProtocol);
-           flParameter.setDeployEnv(deployEnv);
-           flParameter.setDomainName(domainName);
-           flParameter.setUseElb(ifUseElb);
-           flParameter.setServerNum(serverNum);
-           flParameter.setThreadNum(threadNum);
-           flParameter.setCpuBindMode(cpuBindMode);
-           flParameter.setBatchSize(batchSize);
-
-           // start FLJob
-           SyncFLJob syncFLJob = new SyncFLJob();
-           syncFLJob.flJobRun();
-       }
-       // Android的联邦学习推理任务
-       public void syncJobPredict() {
-           // 构造dataMap
-           String inferTxtPath = "data/albert/supervise/eval/eval.txt";
-           String vocabFile = "data/albert/supervise/vocab.txt";
-           String idsFile = "data/albert/supervise/vocab_map_ids.txt";
-           Map<RunType, List<String>> dataMap = new HashMap<>();
-           List<String> inferPath = new ArrayList<>();
-           inferPath.add(inferTxtPath);
-           inferPath.add(vocabFile);
-           inferPath.add(idsFile);
-           dataMap.put(RunType.INFERMODE, inferPath);
-
-           String flName = "com.mindspore.flclient.demo.albert.AlbertClient"; // AlBertClient.java 包路径
-           String inferModelPath = "ms/albert/train/albert_ad_train.mindir0.ms"; // 绝对路径, 和trainModelPath保持一致
-           int threadNum = 4;
-           BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
-           int batchSize = 32;
-
-           FLParameter flParameter = FLParameter.getInstance();
-           flParameter.setFlName(flName);
-           flParameter.setDataMap(dataMap);
-           flParameter.setInferModelPath(inferModelPath);
-           flParameter.setThreadNum(threadNum);
-           flParameter.setCpuBindMode(cpuBindMode);
-           flParameter.setBatchSize(batchSize);
-
-           // inference
-           SyncFLJob syncFLJob = new SyncFLJob();
-           int[] labels = syncFLJob.modelInference();
-           LOGGER.info("labels = " + Arrays.toString(labels));
-       }
-   }
-   ```
-
-   上面的eval_no_label.txt是指不存在标签的文件,每一行为一条语句,格式参考如下,用户可自由设置:
-
-   ```text
-   愿以吾辈之青春 护卫这盛世之中华🇨🇳
-   girls help girls
-   太美了,祝祖国繁荣昌盛!
-   中国人民站起来了
-   难道就我一个人觉得这个是plus版本?
-   被安利到啦!明天起来就看!早点睡觉莲莲
-   ```
-
-3. MainActivity.java:该代码文件作用是启动联邦学习训练与推理任务。
-
-   ```java
-   import android.os.Build;
-   import android.os.Bundle;
-   import androidx.annotation.RequiresApi;
-   import androidx.appcompat.app.AppCompatActivity;
-   import com.mindspore.flAndroid.job.FlJob;
-   import com.mindspore.flAndroid.utils.AssetCopyer;
-   @RequiresApi(api = Build.VERSION_CODES.P)
-   public class MainActivity extends AppCompatActivity {
-       private String parentPath;
-       @Override
-       protected void onCreate(Bundle savedInstanceState) {
-           super.onCreate(savedInstanceState);
-           // 获取该应用程序在Android系统中的磁盘路径
-           this.parentPath = this.getExternalFilesDir(null).getAbsolutePath();
-           // copy assets目录下面的资源文件到Android系统的磁盘中
-           AssetCopyer.copyAllAssets(this.getApplicationContext(), parentPath);
-           // 新建一个线程,启动联邦学习训练与推理任务
-           new Thread(() -> {
-               FlJob flJob = new FlJob(parentPath);
-               flJob.syncJobTrain();
-               flJob.syncJobPredict();
-           }).start();
-       }
-   }
-   ```
-
-### Android工程配置依赖项
-
-1. AndroidManifest.xml
-
-   ```xml
-   <?xml version="1.0" encoding="utf-8"?>
-   <!-- 原文此处内容在文档转换中丢失,以下为按工程惯例补全的最小示例,Activity包路径等以实际工程为准 -->
-   <manifest xmlns:android="http://schemas.android.com/apk/res/android"
-       package="com.mindspore.flAndroid">
-
-       <!-- 联邦学习端云通信需要网络权限 -->
-       <uses-permission android:name="android.permission.INTERNET" />
-
-       <application
-           android:label="flAndroid"
-           android:supportsRtl="true">
-           <activity android:name=".MainActivity">
-               <intent-filter>
-                   <action android:name="android.intent.action.MAIN" />
-                   <category android:name="android.intent.category.LAUNCHER" />
-               </intent-filter>
-           </activity>
-       </application>
-   </manifest>
-   ```
-
-2. app/build.gradle
-
-   ```text
-   plugins {
-       id 'com.android.application'
-   }
-   android {
-       // Android SDK的编译版本,建议大于27
-       compileSdkVersion 30
-       buildToolsVersion "30.0.3"
-       defaultConfig {
-           applicationId "com.mindspore.flAndroid"
-           minSdkVersion 27
-           targetSdkVersion 30
-           versionCode 1
-           versionName "1.0"
-           multiDexEnabled true
-           testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
-           ndk {
-               // 不同的手机型号,对应ndk不相同,本文使用的mate20手机是'armeabi-v7a'
-               abiFilters 'armeabi-v7a'
-           }
-       }
-       //指定ndk版本
-       ndkVersion '21.3.6528147'
-       sourceSets{
-           main {
-               // 指定jni目录
-               jniLibs.srcDirs = ['libs']
-               jni.srcDirs = []
-           }
-       }
-       compileOptions {
-           sourceCompatibility JavaVersion.VERSION_1_8
-           targetCompatibility JavaVersion.VERSION_1_8
-       }
-   }
-   dependencies {
-       //指定扫描libs目录下的AAR包
-       implementation fileTree(dir:'libs',include:['*.aar', '*.jar'])
-       implementation 'androidx.appcompat:appcompat:1.1.0'
-       implementation 'com.google.android.material:material:1.1.0'
-       implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
-       androidTestImplementation 'androidx.test.ext:junit:1.1.1'
-       androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0'
-       implementation 'com.android.support:multidex:1.0.3'
-
-       //添加联邦学习所依赖的第三方开源软件
-       implementation group: 'com.squareup.okhttp3', name: 'okhttp', version: '3.14.9'
-       implementation group: 'com.google.flatbuffers', name: 'flatbuffers-java', version: '2.0.0'
-       implementation(group: 'org.bouncycastle',name: 'bcprov-jdk15on', version: '1.68')
-   }
-   ```
-
-### Android构建与运行
-
-1. 连接Android设备,运行联邦学习训练与推理应用程序。通过USB连接Android设备调试,点击`Run 'app'`即可在你的设备上运行联邦学习任务。
-
-   ![run_app](./images/start_android_project.png)
-
-2. Android Studio连接设备调试操作,可参考。手机需开启"USB调试模式",Android Studio才能识别到手机。华为手机一般在`设置->系统和更新->开发人员选项->USB调试`中打开"USB调试模式"。
-
-3. 在Android设备上,点击"继续安装",安装完即可在APP启动之后执行ALBERT模型的联邦学习的训练与推理任务。
-
-4. 程序运行结果如下:
-
-   ```text
-   I/SyncFLJob: [model inference] inference finish
-   I/SyncFLJob: labels = [2, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 4]
-   ```
-
-## 实验结果
-
-联邦学习总迭代数为10,客户端本地训练epoch数为1,batchSize设置为16。
-
-```text
- total acc:0.44488978
- total acc:0.583166333
- total acc:0.609218437
- total acc:0.645290581
- total acc:0.667334669
- total acc:0.685370741
- total acc:0.70741483
- total acc:0.711422846
- total acc:0.719438878
- total acc:0.733466934
-```
-
-## 参考文献
-
-[1] Lan Z, Chen M, Goodman S, et al. ALBERT: A Lite BERT for Self-supervised Learning of Language Representations[J]. 2019.
diff --git a/docs/federated/docs/source_zh_cn/split_pangu_alpha_application.md b/docs/federated/docs/source_zh_cn/split_pangu_alpha_application.md
deleted file mode 100644
index debed2191573cb966a09dfb4144d51dd8242abc6..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_zh_cn/split_pangu_alpha_application.md
+++ /dev/null
@@ -1,338 +0,0 @@
-# 纵向联邦学习模型训练 - 盘古α大模型跨域训练
-
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/split_pangu_alpha_application.md)
-
-## 概述
-
-随着硬件算力的进步和网络数据规模的持续膨胀,预训练大模型已日趋成为自然语言处理、图文多模态等领域的重要研究方向。以2021年发布的中文NLP预训练大模型盘古α为例,其模型参数量达2000亿,训练过程依赖海量数据和先进计算中心,限制了其应用落地和技术演进。一种可行的解决方案是基于纵向联邦学习或拆分学习(Split Learning)技术,整合多参与方的算力和数据资源,在确保安全隐私的前提下,实现预训练大模型的跨域协同训练。
-
-MindSpore Federated提供基于拆分学习的纵向联邦学习基础功能组件。本样例以盘古α模型为例,提供了面向NLP大模型的联邦学习训练样例。
-
-![实现盘古α大模型跨域训练](./images/splitnn_pangu_alpha.png)
-
-如上图所示,该案例中,盘古α模型被依次切分为Embedding、Backbone、Head等3个子网络。其中,前级子网络Embedding和末级子网络Head部署在参与方A网络域内,包含多级Transformer模块的Backbone子网络部署在参与方B网络域内。Embedding子网络和Head子网络读取参与方A所持有的数据,主导执行盘古α模型的训练和推理任务。
-
-* 前向推理阶段,参与方A采用Embedding子网络处理原始数据后,将输出的Embedding Feature特征张量和Attention Mask特征张量传输给参与方B,作为参与方B Backbone子网络的输入。然后,参与方A读取Backbone子网络输出的Hidden State特征张量,作为参与方A Head子网络的输入,最终由Head子网络输出预测结果或损失值。
-
-* 反向传播阶段,参与方A在完成Head子网络的梯度计算和参数更新后,将Hidden State特征张量所关联的梯度张量,传输给参与方B,用于Backbone子网络的梯度计算和参数更新。然后,参与方B在完成Backbone子网络的梯度计算和参数更新后,将Embedding Feature特征张量所关联的梯度张量,传输给参与方A,用于Embedding子网络的梯度计算和参数更新。
-
-上述前向推理和反向传播过程中,参与方A和参与方B交换的特征张量和梯度张量,均采用隐私安全机制和加密算法进行处理,从而无需将参与方A所持有的数据传输给参与方B,即可实现两个参与方对网络模型的协同训练。由于Embedding子网络和Head子网络参数量较少,而Backbone子网络参数量巨大,该应用样例适用于业务方(对应参与方A)与计算中心(对应参与方B)的大模型协同训练或部署。
-
-盘古α模型原理的详细介绍,可参考[MindSpore ModelZoo - pangu_alpha](https://gitee.com/mindspore/models/tree/master/official/nlp/Pangu_alpha)、[鹏程·盘古α介绍](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha),及其[研究论文](https://arxiv.org/pdf/2104.12369.pdf)。
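-
-在进入具体操作之前,下面用一个与盘古α无关的玩具网络,示意上述三段式拆分中特征张量与梯度张量的流转方向(仅为辅助理解的假设性代码,未包含通信与加密环节,网络结构与变量名均为示例):
-
-```python
-import numpy as np
-import mindspore as ms
-from mindspore import nn, ops
-
-ms.set_context(mode=ms.PYNATIVE_MODE)
-# Embedding与Head部署在参与方A,Backbone部署在参与方B
-embedding_net, backbone_net, head_net = nn.Dense(8, 4), nn.Dense(4, 4), nn.Dense(4, 2)
-x = ms.Tensor(np.random.randn(2, 8).astype(np.float32))  # 参与方A的本地数据
-
-# 前向推理:A -> B -> A,跨域传输的是特征张量而非原始数据
-embedding_feature = embedding_net(x)             # A侧计算,发送给B
-hidden_state = backbone_net(embedding_feature)   # B侧计算,发送回A
-logits = head_net(hidden_state)                  # A侧得到预测结果或损失值
-
-# 反向传播:按链式法则逐级回传梯度张量(sens)
-grad_input = ops.GradOperation(get_all=True, sens_param=True)
-sens = ms.Tensor(np.ones((2, 2), np.float32))
-(grad_hidden,) = grad_input(head_net)(hidden_state, sens)                     # A侧求出,发送给B
-(grad_embedding,) = grad_input(backbone_net)(embedding_feature, grad_hidden)  # B侧求出,发送回A
-# 各参与方随后分别基于上述梯度更新本地子网络参数
-```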
-
-## 准备环节
-
-### 环境准备
-
-1. 参考[获取MindSpore Federated](https://mindspore.cn/federated/docs/zh-CN/master/federated_install.html),安装MindSpore 1.8.1及以上版本和MindSpore Federated。
-
-2. 下载MindSpore Federated代码,安装本应用样例依赖的Python软件包。
-
-   ```bash
-   git clone https://gitee.com/mindspore/federated.git
-   cd federated/example/splitnn_pangu_alpha/
-   python -m pip install -r requirements.txt
-   ```
-
-### 数据集准备
-
-在运行样例前,需参考[MindSpore ModelZoo - pangu_alpha - Dataset Generation](https://gitee.com/mindspore/models/tree/master/official/nlp/Pangu_alpha#dataset-generation),采用preprocess.py脚本将用于训练的原始文本语料,转换为可用于模型训练的数据集。
-
-## 定义纵向联邦学习训练过程
-
-MindSpore Federated纵向联邦学习框架采用FLModel(参见[纵向联邦学习模型训练接口](https://mindspore.cn/federated/docs/zh-CN/master/vertical/vertical_federated_FLModel.html))和yaml文件(参见[纵向联邦学习yaml详细配置项](https://mindspore.cn/federated/docs/zh-CN/master/vertical/vertical_federated_yaml.html)),建模纵向联邦学习的训练过程。
-
-### 定义网络模型
-
-1. 采用MindSpore提供的功能组件,以nn.Cell(参见[mindspore.nn.Cell](https://mindspore.cn/docs/zh-CN/master/api_python/nn/mindspore.nn.Cell.html?highlight=cell#mindspore-nn-cell))为基类,编程开发本参与方待参与纵向联邦学习的训练网络。以本应用实践中参与方A的Embedding子网络为例,[示例代码](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/src/split_pangu_alpha.py)如下:
-
-   ```python
-   class EmbeddingLossNet(nn.Cell):
-       """
-       Train net of the embedding party, or the first sub-network.
-       Args:
-           net (class): EmbeddingLayer, which is the 1st sub-network.
-           config (class): default config info.
-       """
-
-       def __init__(self, net: EmbeddingLayer, config):
-           super(EmbeddingLossNet, self).__init__(auto_prefix=False)
-
-           self.batch_size = config.batch_size
-           self.seq_length = config.seq_length
-           dp = config.parallel_config.data_parallel
-           self.eod_token = config.eod_token
-           self.net = net
-           self.slice = P.StridedSlice().shard(((dp, 1),))
-           self.not_equal = P.NotEqual().shard(((dp, 1), ()))
-           self.len = config.seq_length
-           self.slice2 = P.StridedSlice().shard(((dp, 1, 1),))
-
-       def construct(self, input_ids, position_id, attention_mask):
-           """forward process of EmbeddingLossNet"""
-           tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1))
-           embedding_table, word_table = self.net(tokens, position_id, batch_valid_length=None)
-           return embedding_table, word_table, position_id, attention_mask
-   ```
-
-2. 在yaml配置文件中,描述训练网络对应的名称、输入、输出等信息。以本应用实践中参与方A的Embedding子网络为例,[示例代码](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/embedding.yaml)如下:
-
-   ```yaml
-   train_net:
-     name: follower_loss_net
-     inputs:
-       - name: input_ids
-         source: local
-       - name: position_id
-         source: local
-       - name: attention_mask
-         source: local
-     outputs:
-       - name: embedding_table
-         destination: remote
-       - name: word_table
-         destination: remote
-       - name: position_id
-         destination: remote
-       - name: attention_mask
-         destination: remote
-   ```
-
-   其中,`name`字段为训练网络名称,将用于命名训练过程中保存的checkpoints文件。`inputs`字段为训练网络输入张量列表,`outputs`字段为训练网络输出张量列表。
-
-   `inputs`和`outputs`字段下的`name`字段,为输入/输出张量名称。输入/输出张量的名称和顺序,需要与训练网络对应Python代码中`construct`方法的输入/输出严格对应。
-
-   `inputs`字段下的`source`字段标识输入张量的数据来源,`local`代表输入张量来源于本地数据加载,`remote`代表输入张量来源于其他参与方网络传输。
-
-   `outputs`字段下的`destination`字段标识输出张量的数据去向,`local`代表输出张量仅用于本地,`remote`代表输出张量将通过网络传输给其他参与方。
-
-3. 可选的,采用类似方法建模本参与方待参与纵向联邦学习的评估网络。
-
-### 定义优化器
-
-1. 采用MindSpore提供的功能组件,编程开发用于本参与方训练网络参数更新的优化器。以本应用实践中参与方A用于Embedding子网络训练的自定义优化器为例,[示例代码](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/src/pangu_optim.py)如下:
-
-   ```python
-   class PanguAlphaAdam(TrainOneStepWithLossScaleCell):
-       """
-       Customized Adam optimizer for training of pangu_alpha in the splitnn demo system.
- """ - def __init__(self, net, optim_inst, scale_update_cell, config, yaml_data) -> None: - # 自定义优化器相关算子 - ... - - def __call__(self, *inputs, sens=None): - # 定义梯度计算和参数更新过程 - ... - ``` - - 开发者可自定义优化器类的`__init__`方法的输入输出,但优化器类的`__call__`方法的输入需仅包含`inputs`和`sens`。其中,`inputs`为`list`类型,对应训练网络的输入张量列表,其元素为`mindspore.Tensor`类型。`sens`为`dict`类型,保存用于计算训练网络参数梯度值的加权系数,其key为`str`类型的梯度加权系数标识符;value为`dict`类型,其key为`str`类型,是训练网络输出张量名称,value为`mindspore.Tensor`类型,是该输出张量对应的训练网络参数梯度值的加权系数。 - -2. 在yaml配置文件中,描述优化器对应的梯度计算、参数更新等信息。[示例代码](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/embedding.yaml)如下: - - ```yaml - opts: - - type: CustomizedAdam - grads: - - inputs: - - name: input_ids - - name: position_id - - name: attention_mask - output: - name: embedding_table - sens: hidden_states - - inputs: - - name: input_ids - - name: position_id - - name: attention_mask - output: - name: word_table - sens: word_table - params: - - name: word_embedding - - name: position_embedding - hyper_parameters: - learning_rate: 5.e-6 - eps: 1.e-8 - loss_scale: 1024.0 - ``` - - 其中,`type`字段为优化器类型,此处为开发者自定义优化器。 - - `grads`字段为优化器关联的`GradOperation`列表,优化器将使用列表中`GradOperation`算子计算输出的梯度值,更新训练网络参数。`inputs`和`output`字段为`GradOperation`算子的输入和输出张量列表,其元素分别为一个输入/输出张量名称。`sens`字段为`GradOperation`算子的梯度加权系数或灵敏度(参考[mindspore.ops.GradOperation](https://mindspore.cn/docs/zh-CN/master/api_python/ops/mindspore.ops.GradOperation.html?highlight=gradoperation))的标识符。 - - `params`字段为优化器即将更新的训练网络参数名称列表,其元素分别为一个训练网络参数名称。本示例中,自定义优化器将更新名称中包含`word_embedding`字符串和`position_embedding`字符串的网络参数。 - - `hyper_parameters`字段为优化器的超参数列表。 - -### 定义梯度加权系数计算 - -根据梯度计算的链式法则,位于全局网络后级的子网络,需要计算其输出张量相对于输入张量的梯度值,即梯度加权系数或灵敏度,传递给位于全局网络前级的子网络,用于其训练参数更新。 - -MindSpore Federated采用`GradOperation`算子,完成上述梯度加权系数或灵敏度计算过程。开发者需在yaml配置文件中,描述用于计算梯度加权系数的`GradOperation`算子。以本应用实践中参与方A的Head为例,[示例代码](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/head.yaml)如下: - -```yaml -grad_scalers: - - inputs: - - name: hidden_states - - name: input_ids - - name: word_table - - name: position_id - - name: attention_mask - output: - name: output - sens: 1024.0 -``` - -其中,`inputs`和`output`字段为`GradOperation`算子的输入和输出张量列表,其元素分别为一个输入/输出张量名称。`sens`字段为该`GradOperation`算子的梯度加权系数或灵敏度(参考[mindspore.ops.GradOperation](https://mindspore.cn/docs/zh-CN/master/api_python/ops/mindspore.ops.GradOperation.html?highlight=gradoperation)),如果为`float`或`int`型数值,则将构造一个常量张量作为梯度加权系数,如果为`str`型字符串,则将从其他参与方经网络传输的加权系数中,解析名称与其对应的张量作为加权系数。 - -### 执行训练 - -1. 
完成上述Python编程开发和yaml配置文件编写后,采用MindSpore Federated提供的`FLModel`类和`FLYamlData`类,构建纵向联邦学习流程。以本应用实践中参与方A的Embedding子网络为例,[示例代码](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/run_pangu_train_local.py)如下: - - ```python - embedding_yaml = FLYamlData('./embedding.yaml') - embedding_base_net = EmbeddingLayer(config) - embedding_eval_net = embedding_train_net = EmbeddingLossNet(embedding_base_net, config) - embedding_with_loss = _VirtualDatasetCell(embedding_eval_net) - embedding_params = embedding_with_loss.trainable_params() - embedding_group_params = set_embedding_weight_decay(embedding_params) - embedding_optim_inst = FP32StateAdamWeightDecay(embedding_group_params, lr, eps=1e-8, beta1=0.9, beta2=0.95) - embedding_optim = PanguAlphaAdam(embedding_train_net, embedding_optim_inst, update_cell, config, embedding_yaml) - - embedding_fl_model = FLModel(yaml_data=embedding_yaml, - network=embedding_train_net, - eval_network=embedding_eval_net, - optimizers=embedding_optim) - ``` - - 其中,`FLYamlData`类主要完成yaml配置文件的解析和校验,`FLModel`类主要提供纵向联邦学习训练、推理等流程的控制接口。 - -2. 调用`FLModel`类的接口方法,执行纵向联邦学习训练。以本应用实践中参与方A的Embedding子网络为例,[示例代码](https://gitee.com/mindspore/federated/blob/master/example/splitnn_pangu_alpha/run_pangu_train_local.py)如下: - - ```python - if opt.resume: - embedding_fl_model.load_ckpt() - ... - for epoch in range(50): - for step, item in enumerate(train_iter, start=1): - # forward process - step = epoch * train_size + step - embedding_out = embedding_fl_model.forward_one_step(item) - ... - # backward process - embedding_fl_model.backward_one_step(item, sens=backbone_scale) - ... - if step % 1000 == 0: - embedding_fl_model.save_ckpt() - ``` - - 其中,`forward_one_step`方法和`backward_one_step`方法分别执行一个数据batch的前向推理和反向传播操作。`load_ckpt`方法和`save_ckpt`方法分别执行checkpoints文件的加载和保存操作。 - -## 运行样例 - -本样例提供2个示例程序,均以Shell脚本拉起Python程序的形式运行。 - -1. `run_pangu_train_local.sh`:单进程示例程序,参与方A和参与方B同一进程训练,其以程序内变量的方式,直接传输特征张量和梯度张量至另一参与方。 - -2. `run_pangu_train_leader.sh`和`run_pangu_train_follower.sh`:多进程示例程序,参与方A和参与方B分别运行一个进程,其分别将特征张量和梯度张量封装为protobuf消息后,通过https通信接口传输至另一参与方。`run_pangu_train_leader.sh`和`run_pangu_train_follower.sh`可分别在两台服务器上运行,实现跨域协同训练。 - -3. 当前纵向联邦分布式训练支持https跨域加密通信,启动命令如下: - - ```bash - # 以https加密通信的方式启动leader进程: - bash run_pangu_train_leader.sh 127.0.0.1:10087 127.0.0.1:10086 /path/to/train/data_set /path/to/eval/data_set True server_cert_password client_cert_password /path/to/server_cert /path/to/client_cert /path/to/ca_cert - - # 以https加密通信的方式启动follower进程: - bash run_pangu_train_follower.sh 127.0.0.1:10086 127.0.0.1:10087 True server_cert_password client_cert_password /path/to/server_cert /path/to/client_cert /path/to/ca_cert - ``` - -### 运行单进程样例 - -`run_pangu_train_local.sh`为例,运行示例程序的步骤如下: - -1. 进入示例程序目录: - - ```bash - cd federated/example/splitnn_pangu_alpha/ - ``` - -2. 以wiki数据集为例,拷贝数据集至示例程序目录: - - ```bash - cp -r {dataset_dir}/wiki ./ - ``` - -3. 安装依赖的Python软件包: - - ```bash - python -m pip install -r requirements.txt - ``` - -4. 修改`src/utils.py`,配置checkpoint文件加载路径、训练数据集路径、评估数据集路径等参数,示例如下: - - ```python - parser.add_argument("--load_ckpt_path", type=str, default='./checkpoints', help="predict file path.") - parser.add_argument('--data_url', required=False, default='./wiki/train/', help='Location of data.') - parser.add_argument('--eval_data_url', required=False, default='./wiki/eval/', help='Location of eval data.') - ``` - -5. 执行训练脚本: - - ```bash - ./run_pangu_train_local.sh - ``` - -6. 
查看训练日志`splitnn_pangu_local.txt`中记录的训练loss信息。
-
-   ```text
-   INFO:root:epoch 0 step 10/43391 loss: 10.616087
-   INFO:root:epoch 0 step 20/43391 loss: 10.424824
-   INFO:root:epoch 0 step 30/43391 loss: 10.209235
-   INFO:root:epoch 0 step 40/43391 loss: 9.950026
-   INFO:root:epoch 0 step 50/43391 loss: 9.712448
-   INFO:root:epoch 0 step 60/43391 loss: 9.557744
-   INFO:root:epoch 0 step 70/43391 loss: 9.501564
-   INFO:root:epoch 0 step 80/43391 loss: 9.326054
-   INFO:root:epoch 0 step 90/43391 loss: 9.387547
-   INFO:root:epoch 0 step 100/43391 loss: 8.795234
-   ...
-   ```
-
-   对应的可视化结果如下图所示,其中横轴为训练步数,纵轴为loss值,红色曲线为盘古α训练loss值,蓝色曲线为本示例中基于拆分学习的盘古α训练loss值。二者loss值下降的趋势基本一致,考虑到网络参数值初始化具有随机性,可验证训练过程的正确性。
-
-   ![盘古α大模型跨域训练结果](./images/splitnn_pangu_alpha_result.png)
-
-### 运行多进程样例
-
-1. 类似单进程样例,进入示例程序目录,安装依赖的Python软件包:
-
-   ```bash
-   cd federated/example/splitnn_pangu_alpha/
-   python -m pip install -r requirements.txt
-   ```
-
-2. 拷贝数据集至服务器1的示例程序目录:
-
-   ```bash
-   cp -r {dataset_dir}/wiki ./
-   ```
-
-3. 在服务器1启动参与方A的训练脚本:
-
-   ```bash
-   ./run_pangu_train_leader.sh {ip_address_server1} {ip_address_server2} ./wiki/train ./wiki/train
-   ```
-
-   训练脚本的第1个参数是本地服务器(服务器1)的IP地址和端口号,第2个参数是对端服务器(服务器2)的IP地址和端口号,第3个参数是训练数据集文件路径,第4个参数是评估数据集文件路径,第5个参数(可选)标识是否加载已有的checkpoint文件。
-
-4. 在服务器2启动参与方B的训练脚本:
-
-   ```bash
-   ./run_pangu_train_follower.sh {ip_address_server2} {ip_address_server1}
-   ```
-
-   训练脚本的第1个参数是本地服务器(服务器2)的IP地址和端口号,第2个参数是对端服务器(服务器1)的IP地址和端口号,第3个参数(可选)标识是否加载已有的checkpoint文件。
-
-5. 查看服务器1的训练日志`leader_process.log`中记录的训练loss信息。若其loss信息与盘古α集中式训练loss值趋势一致,可验证训练过程的正确性。
\ No newline at end of file
diff --git a/docs/federated/docs/source_zh_cn/split_wnd_application.md b/docs/federated/docs/source_zh_cn/split_wnd_application.md
deleted file mode 100644
index d554a34af12bc35ad016b8e8a2bea66e76bab01b..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_zh_cn/split_wnd_application.md
+++ /dev/null
@@ -1,280 +0,0 @@
-# 纵向联邦学习模型训练 - Wide&Deep推荐应用
-
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/split_wnd_application.md)
-
-## 概述
-
-MindSpore Federated提供基于拆分学习(Split Learning)的纵向联邦学习基础功能组件。
-
-纵向FL模型训练场景包括前向传播和后向传播/参数更新两个阶段。
-
-前向传播:经数据求交模块处理各参与方数据,配准特征信息和标签信息后,Follower参与方将本地特征信息输入前级网络模型,将前级网络模型输出的特征张量,经隐私安全模块加密/加扰后,由通信模块传输给Leader参与方。Leader参与方将收到的特征张量输入后级网络模型,以后级网络模型输出的预测值和本地标签信息为损失函数输入,计算损失值。
-
-![](./images/vfl_forward.png)
-
-后向传播:Leader参与方基于损失值,计算后级网络模型的参数梯度,训练更新后级网络模型的参数,并将与特征张量关联的梯度张量,经隐私安全模块加密/加扰后,由通信模块传输给Follower参与方。Follower参与方将收到的梯度张量用于前级网络模型的参数训练更新。
-
-![](./images/vfl_backward.png)
-
-纵向FL模型推理场景:与训练场景的前向传播阶段类似,但直接以后级网络模型的预测值为输出,而无需计算损失值。
-
-## 网络和数据
-
-![](./images/splitnn_wide_and_deep.png)
-
-本样例以Wide&Deep网络和Criteo数据集为例,提供了面向推荐任务的联邦学习训练样例。如上图所示,本案例中,纵向联邦学习系统由Leader参与方和Follower参与方组成。其中,Leader参与方持有20×2维特征信息和标签信息,Follower参与方持有19×2维特征信息。Leader参与方和Follower参与方分别部署1组Wide&Deep网络,并通过交换embedding向量和梯度向量,在不泄露原始特征和标签信息的前提下,实现对网络模型的协同训练。
-
-Wide&Deep网络原理特性的详细介绍,可参考[MindSpore ModelZoo - Wide&Deep - Wide&Deep概述](https://gitee.com/mindspore/models/blob/master/official/recommend/Wide_and_Deep/README_CN.md#widedeep%E6%A6%82%E8%BF%B0)及其[研究论文](https://arxiv.org/pdf/1606.07792.pdf)。
-
-## 数据集准备
-
-本样例基于Criteo数据集进行训练和测试,在运行样例前,需参考[MindSpore ModelZoo - Wide&Deep - 快速入门](https://gitee.com/mindspore/models/blob/master/official/recommend/Wide_and_Deep/README_CN.md#%E5%BF%AB%E9%80%9F%E5%85%A5%E9%97%A8),对Criteo数据集进行预处理。
-
-1. 
克隆MindSpore ModelZoo代码。
-
-   ```shell
-   git clone https://gitee.com/mindspore/models.git
-   cd models/official/recommend/Wide_and_Deep
-   ```
-
-2. 下载数据集。
-
-   ```shell
-   mkdir -p data/origin_data && cd data/origin_data
-   wget http://go.criteo.net/criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz
-   tar -zxvf criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz
-   ```
-
-3. 使用`src/preprocess_data.py`脚本预处理数据。预处理过程可能需要一小时,生成的MindRecord数据存放在data/mindrecord路径下。预处理过程内存消耗巨大,建议使用服务器。
-
-   ```shell
-   cd ../..
-   python src/preprocess_data.py --data_path=./data/ --dense_dim=13 --slot_dim=26 --threshold=100 --train_line_count=45840617 --skip_id_convert=0
-   ```
-
-## 快速体验
-
-本样例以Shell脚本拉起Python程序的形式运行。
-
-1. 参考[MindSpore官网指引](https://www.mindspore.cn/install),安装MindSpore 1.8.1或更高版本。
-
-2. 安装MindSpore Federated所依赖的Python库。
-
-   ```shell
-   cd federated
-   python -m pip install -r requirements_test.txt
-   ```
-
-3. 拷贝[预处理](#数据集准备)后的Criteo数据集至本目录下。
-
-   ```shell
-   cd tests/example/splitnn_criteo
-   cp -rf ${DATA_ROOT_PATH}/data/mindrecord/ ./
-   ```
-
-4. 运行示例程序启动脚本。
-
-   ```shell
-   # 启动leader进程:
-   bash run_vfl_train_leader.sh 127.0.0.1:10087 127.0.0.1:10086 /path/to/data_set False
-
-   # 启动follower进程:
-   bash run_vfl_train_follower.sh 127.0.0.1:10086 127.0.0.1:10087 /path/to/data_set False
-   ```
-
-   或者
-
-   ```shell
-   # 以https加密通信的方式启动leader进程:
-   bash run_vfl_train_leader.sh 127.0.0.1:10087 127.0.0.1:10086 /path/to/data_set True server_cert_password client_cert_password /path/to/server_cert /path/to/client_cert /path/to/ca_cert
-
-   # 以https加密通信的方式启动follower进程:
-   bash run_vfl_train_follower.sh 127.0.0.1:10086 127.0.0.1:10087 /path/to/data_set True server_cert_password client_cert_password /path/to/server_cert /path/to/client_cert /path/to/ca_cert
-   ```
-
-5. 查看训练日志`log_local_gpu.txt`。
-
-   ```text
-   INFO:root:epoch 0 step 100/2582 wide_loss: 0.528141 deep_loss: 0.528339
-   INFO:root:epoch 0 step 200/2582 wide_loss: 0.499408 deep_loss: 0.499410
-   INFO:root:epoch 0 step 300/2582 wide_loss: 0.477544 deep_loss: 0.477882
-   INFO:root:epoch 0 step 400/2582 wide_loss: 0.474377 deep_loss: 0.476771
-   INFO:root:epoch 0 step 500/2582 wide_loss: 0.472926 deep_loss: 0.475157
-   INFO:root:epoch 0 step 600/2582 wide_loss: 0.464844 deep_loss: 0.467011
-   INFO:root:epoch 0 step 700/2582 wide_loss: 0.464496 deep_loss: 0.466615
-   INFO:root:epoch 0 step 800/2582 wide_loss: 0.466895 deep_loss: 0.468971
-   INFO:root:epoch 0 step 900/2582 wide_loss: 0.463155 deep_loss: 0.465299
-   INFO:root:epoch 0 step 1000/2582 wide_loss: 0.457914 deep_loss: 0.460132
-   INFO:root:epoch 0 step 1100/2582 wide_loss: 0.453361 deep_loss: 0.455767
-   INFO:root:epoch 0 step 1200/2582 wide_loss: 0.457566 deep_loss: 0.459997
-   INFO:root:epoch 0 step 1300/2582 wide_loss: 0.460841 deep_loss: 0.463281
-   INFO:root:epoch 0 step 1400/2582 wide_loss: 0.460973 deep_loss: 0.463365
-   INFO:root:epoch 0 step 1500/2582 wide_loss: 0.459204 deep_loss: 0.461563
-   INFO:root:epoch 0 step 1600/2582 wide_loss: 0.456771 deep_loss: 0.459200
-   INFO:root:epoch 0 step 1700/2582 wide_loss: 0.458479 deep_loss: 0.460963
-   INFO:root:epoch 0 step 1800/2582 wide_loss: 0.449609 deep_loss: 0.452122
-   INFO:root:epoch 0 step 1900/2582 wide_loss: 0.451775 deep_loss: 0.454225
-   INFO:root:epoch 0 step 2000/2582 wide_loss: 0.460343 deep_loss: 0.462826
-   INFO:root:epoch 0 step 2100/2582 wide_loss: 0.456814 deep_loss: 0.459201
-   INFO:root:epoch 0 step 2200/2582 wide_loss: 0.452091 deep_loss: 0.454555
-   INFO:root:epoch 0 step 2300/2582 wide_loss: 0.461522 deep_loss: 0.464001
-   INFO:root:epoch 0 step 
2400/2582 wide_loss: 0.442355 deep_loss: 0.444790 - INFO:root:epoch 0 step 2500/2582 wide_loss: 0.450675 deep_loss: 0.453242 - ... - ``` - -6. 关闭训练进程。 - - ```shell - pid=`ps -ef|grep run_vfl_train_socket |grep -v "grep" | grep -v "finish" |awk '{print $2}'` && for id in $pid; do kill -9 $id && echo "killed $id"; done - ``` - -## 深度体验 - -在启动纵向联邦学习训练之前,用户需要和使用MindSpore做普通深度学习训练一样,构造数据集迭代器和网络结构。 - -### 构造数据集 - -当前采用模拟流程,即两方读取数据源一样,但训练时,两方只使用部分的特征或标签数据,如[网络和数据](#网络和数据)所示。后续将采用[数据接入](https://www.mindspore.cn/federated/docs/zh-CN/master/data_join/data_join.html)方法两方各自导入数据。 - -```python -from run_vfl_train_local import construct_local_dataset - - -ds_train, _ = construct_local_dataset() -train_iter = ds_train.create_dict_iterator() -``` - -### 构建网络 - -Leader参与方网络: - -```python -from wide_and_deep import WideDeepModel, BottomLossNet, LeaderTopNet, LeaderTopLossNet, LeaderTopEvalNet, \ - LeaderTeeNet, LeaderTeeLossNet, LeaderTopAfterTeeNet, LeaderTopAfterTeeLossNet, LeaderTopAfterTeeEvalNet, \ - AUCMetric -from network_config import config - - -# Leader Top Net -leader_top_base_net = LeaderTopNet() -leader_top_train_net = LeaderTopLossNet(leader_top_base_net) -... -# Leader Bottom Net -leader_bottom_eval_net = leader_bottom_base_net = WideDeepModel(config, config.leader_field_size) -leader_bottom_train_net = BottomLossNet(leader_bottom_base_net, config) -``` - -Follower参与方网络: - -```python -from wide_and_deep import WideDeepModel, BottomLossNet -from network_config import config - - -follower_bottom_eval_net = follower_base_net = WideDeepModel(config, config.follower_field_size) -follower_bottom_train_net = BottomLossNet(follower_base_net, config) -``` - -### 纵向联邦通信底座 - -在训练前首先要启动通信底座,使Leader和Follower参与方组网。详细的API文档可以参考[纵向联邦通信器](https://gitee.com/mindspore/federated/blob/master/docs/api/api_python/vertical/vertical_communicator.rst)。 - -两方都需要导入纵向联邦通信器: - -```python -from mindspore_federated.startup.vertical_federated_local import VerticalFederatedCommunicator, ServerConfig -``` - -Leader参与方通信底座: - -```python -http_server_config = ServerConfig(server_name='leader', server_address=config.http_server_address) -remote_server_config = ServerConfig(server_name='follower', server_address=config.remote_server_address) -self.vertical_communicator = VerticalFederatedCommunicator(http_server_config=http_server_config, - remote_server_config=remote_server_config, - compress_configs=compress_configs) -self.vertical_communicator.launch() -``` - -Follower参与方通信底座: - -```python -http_server_config = ServerConfig(server_name='follower', server_address=config.http_server_address) -remote_server_config = ServerConfig(server_name='leader', server_address=config.remote_server_address) -self.vertical_communicator = VerticalFederatedCommunicator(http_server_config=http_server_config, - remote_server_config=remote_server_config, - compress_configs=compress_configs) -self.vertical_communicator.launch() -``` - -### 构建纵向联邦网络 - -用户需要使用MindSpore Federated提供的类,将自己构造好的网络封装成纵向联邦网络。详细的API文档可以参考[纵向联邦训练接口](https://gitee.com/mindspore/federated/blob/master/docs/api/api_python/vertical/vertical_federated_FLModel.rst)。 - -两方都需要导入纵向联邦训练接口: - -```python -from mindspore_federated import FLModel, FLYamlData -``` - -Leader参与方纵向联邦网络: - -```python -leader_bottom_yaml_data = FLYamlData(config.leader_bottom_yaml_path) -leader_top_yaml_data = FLYamlData(config.leader_top_yaml_path) -... 
-self.leader_top_fl_model = FLModel(yaml_data=leader_top_yaml_data, - network=leader_top_train_net, - metrics=self.eval_metric, - eval_network=leader_top_eval_net) -... -self.leader_bottom_fl_model = FLModel(yaml_data=leader_bottom_yaml_data, - network=leader_bottom_train_net, - eval_network=leader_bottom_eval_net) -``` - -Follower参与方纵向联邦网络: - -```python -follower_bottom_yaml_data = FLYamlData(config.follower_bottom_yaml_path) -... -self.follower_bottom_fl_model = FLModel(yaml_data=follower_bottom_yaml_data, - network=follower_bottom_train_net, - eval_network=follower_bottom_eval_net) -``` - -### 纵向训练 - -纵向训练的流程可以参考[概述](#概述)。 - -Leader参与方训练流程: - -```python -for epoch in range(config.epochs): - for step, item in enumerate(train_iter): - leader_embedding = self.leader_bottom_fl_model.forward_one_step(item) - item.update(leader_embedding) - follower_embedding = self.vertical_communicator.receive("follower") - ... - leader_out = self.leader_top_fl_model.forward_one_step(item, follower_embedding) - grad_scale = self.leader_top_fl_model.backward_one_step(item, follower_embedding) - scale_name = 'loss' - ... - grad_scale_follower = {scale_name: OrderedDict(list(grad_scale[scale_name].items())[2:])} - self.vertical_communicator.send_tensors("follower", grad_scale_follower) - grad_scale_leader = {scale_name: OrderedDict(list(grad_scale[scale_name].items())[:2])} - self.leader_bottom_fl_model.backward_one_step(item, sens=grad_scale_leader) -``` - -Follower参与方训练流程: - -```python -for _ in range(config.epochs): - for _, item in enumerate(train_iter): - follower_embedding = self.follower_bottom_fl_model.forward_one_step(item) - self.vertical_communicator.send_tensors("leader", follower_embedding) - scale = self.vertical_communicator.receive("leader") - self.follower_bottom_fl_model.backward_one_step(item, sens=scale) -``` - diff --git a/docs/federated/docs/source_zh_cn/vertical_federated_trainer.rst b/docs/federated/docs/source_zh_cn/vertical_federated_trainer.rst deleted file mode 100644 index 48a11d81317eb3c91f221f24b27c4809727c4019..0000000000000000000000000000000000000000 --- a/docs/federated/docs/source_zh_cn/vertical_federated_trainer.rst +++ /dev/null @@ -1,12 +0,0 @@ -纵向联邦训练器 -============== - -.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg - :target: https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/vertical_federated_trainer.rst - :alt: 查看源文件 - -.. 
toctree::
-   :maxdepth: 1
-
-   vertical/vertical_federated_FLModel
-   vertical/vertical_federated_yaml
\ No newline at end of file
diff --git a/docs/federated/docs/source_zh_cn/vfl_communication_compress.md b/docs/federated/docs/source_zh_cn/vfl_communication_compress.md
deleted file mode 100644
index 1cb0e0d102f1f71f9bfa6b4471e3176da988f273..0000000000000000000000000000000000000000
--- a/docs/federated/docs/source_zh_cn/vfl_communication_compress.md
+++ /dev/null
@@ -1,231 +0,0 @@
-# 纵向联邦学习通信压缩
-
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/federated/docs/source_zh_cn/vfl_communication_compress.md)
-
-纵向联邦学习通信量会影响用户体验(用户流量、通信时延、联邦学习训练效率),并受性能约束(内存、带宽、CPU占用率)限制。小的通信量对提高用户体验和减少性能瓶颈都有很大帮助,因此需要对通信量进行压缩。MindSpore Federated在纵向联邦应用场景中,实现了Leader和Follower之间的双向通信压缩。
-
-## 总体流程
-
-![image1](./images/vfl_normal_communication_compress.png)
-
-图 1 普通纵向联邦学习通信压缩流程框架图
-
-首先在Follower上进行Embedding DP(EDP)加密操作,然后进入比特打包流程。比特打包流程中会自动判断输入数据是否可以被打包:只有输入数据可以被强转为指定比特存储格式且没有精度丢失时,才会执行比特打包操作。Follower将打包后的数据发送给Leader,Leader会根据上报的数据信息判断是否需要拆包。在Leader将数据传给Follower之前,会将数据进行量化压缩;Follower收到数据后会对量化数据进行解压缩。
-
-![image2](./images/vfl_pangu_communication_compress.png)
-
-图 2 盘古纵向联邦学习通信压缩流程框架图
-
-总体流程和普通纵向联邦学习通信压缩流程一致。盘古纵向联邦相较于普通纵向联邦,每个iteration会多一轮通信,因此需要多进行一次量化压缩和解压缩流程。
-
-## 压缩方法
-
-### 比特打包压缩方法
-
-比特打包压缩方法是一种将数据序列转换为紧凑二进制表示的方法。比特打包本身属于无损压缩方法,但通常输入给比特打包的数据经过了有损压缩。
-
-以 3-bit 打包举例来讲:
-
-量化位数 bit_num = 3
-
-压缩前的存储格式为float32的数据为:
-
-data = [3, -4, 3, -2, 3, -2, -4, 0, 1, 3]
-
-首先判断使用比特打包压缩是否可以压缩:
-
-data_int = int(data)
-
-若data - data_int中的元素不都为0,则退出比特打包流程。
-
-将源数据根据bit_num转换为二进制格式:
-
-data_bin = [011, 100, 011, 110, 011, 110, 100, 000, 001, 011]
-
-注:转换前需要判断当前数据是否在bit_num所能容纳的范围内,若超过范围则退出比特打包流程。
-
-由于原生C++没有专门的二进制存储格式,需要将多个二进制数据拼接,组合成int8格式数据存储。若位数不够,则在最后一个数据上补零。组合后数据:
-
-data_int8 = [01110001, 11100111, 10100000, 00101100]
-
-再将二进制数据转换为-128 到 127 之间的整数,并强转数据类型到 int8:
-
-data_packed = [113, -25, -96, 44]
-
-最后将data_packed和bit_num传递给对端。
-
-拆包时,接收方将上述流程反过来即可。
-
-### 量化压缩方法
-
-量化压缩方法即将浮点型的通信数据定点近似为有限多个离散值。当前支持的量化压缩方法为最小最大压缩(min_max)。
-
-以 8-bit 量化举例来讲:
-
-量化位数 bit_num = 8
-
-压缩前的浮点型数据为:
-
-data = [0.03356021, -0.01842778, -0.009684053, 0.025363436, -0.027571501, 0.0077043395, 0.016391572, -0.03598478, -0.0009508357]
-
-计算最大和最小值:
-
-min_val = -0.03598478
-
-max_val = 0.03356021
-
-计算缩放系数:
-
-scale = (max_val - min_val) / (2 ^ bit_num - 1) = 0.000272725450980392
-
-将压缩前数据转换为-128 到 127 之间的整数,转换公式为 quant_data = round((data - min_val) / scale) - 2 ^ (bit_num - 1),并强转数据类型到 int8:
-
-quant_data = [127, -64, -32, 97, -97, 32, 64, -128, 0]
-
-量化编码后,发送方需要上传的参数即为 quant_data、bit_num 以及最大值 max_val 和最小值 min_val。
-
-接收方在收到 quant_data、min_val 和 max_val 后,使用反量化公式 (quant_data + 2 ^ (bit_num - 1)) * (max_val - min_val) / (2 ^ bit_num - 1) + min_val,还原出权重。
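-
-以下NumPy片段按上述步骤复现比特打包与min_max量化的两个数值例子(仅为示意性参考实现,并非MindSpore Federated源码):
-
-```python
-import numpy as np
-
-def bit_pack(data, bit_num=3):
-    """比特打包:将可无损转为bit_num位补码的整数序列拼接存为int8"""
-    bits = ''.join(format(int(v) & (2 ** bit_num - 1), '0{}b'.format(bit_num)) for v in data)
-    bits += '0' * (-len(bits) % 8)          # 位数不足8的倍数时在末尾补零
-    packed = [int(bits[i:i + 8], 2) for i in range(0, len(bits), 8)]
-    return np.array(packed, dtype=np.uint8).astype(np.int8)
-
-def min_max_quant(data, bit_num=8):
-    """min_max量化:将浮点数据定点近似为bit_num位整数"""
-    min_val, max_val = float(np.min(data)), float(np.max(data))
-    scale = (max_val - min_val) / (2 ** bit_num - 1)
-    quant = np.round((data - min_val) / scale) - 2 ** (bit_num - 1)
-    return quant.astype(np.int8), min_val, max_val
-
-def min_max_dequant(quant, min_val, max_val, bit_num=8):
-    """反量化:接收方据此还原出近似的浮点数据"""
-    scale = (max_val - min_val) / (2 ** bit_num - 1)
-    return (quant.astype(np.float64) + 2 ** (bit_num - 1)) * scale + min_val
-
-print(bit_pack([3, -4, 3, -2, 3, -2, -4, 0, 1, 3]))  # 输出 [113 -25 -96 44]
-q, lo, hi = min_max_quant(np.array([0.03356021, -0.03598478]))
-print(q, min_max_dequant(q, lo, hi))
-```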
-
-## Quick Experience
-
-To use bit packing or quantization compression, you first need to successfully complete the training and aggregation process of a vertical federated scenario, such as [Vertical Federated Learning Model Training - Wide&Deep Recommendation Application](https://www.mindspore.cn/federated/docs/zh-CN/master/split_wnd_application.html). That document details the preparation work, including the dataset and network model, and the procedure for simulating and launching federated learning.
-
-1. For installing MindSpore and MindSpore Federated and for the data preprocessing operations, refer to [Vertical Federated Learning Model Training - Wide&Deep Recommendation Application](https://www.mindspore.cn/federated/docs/zh-CN/master/split_wnd_application.html).
-
-2. Set the compression-related configurations in the [corresponding yaml files](https://gitee.com/mindspore/federated/tree/master/example/splitnn_criteo/yaml_files).
-
-   The configuration of [leader_top.yaml](https://gitee.com/mindspore/federated/blob/master/example/splitnn_criteo/yaml_files/leader_top.yaml) is as follows:
-
-   ```yaml
-   role: leader
-   model: # define the net of vFL party
-     train_net:
-       name: leader_loss_net
-       inputs:
-         - name: leader_wide_embedding
-           source: local
-         - name: leader_deep_embedding
-           source: local
-         - name: follower_wide_embedding
-           source: remote
-           compress_type: min_max
-           bit_num: 6
-         - name: follower_deep_embedding
-           source: remote
-           compress_type: min_max
-           bit_num: 6
-   ...
-   ```
-
-   The configuration of [follower_bottom.yaml](https://gitee.com/mindspore/federated/blob/master/example/splitnn_criteo/yaml_files/follower_bottom.yaml) is as follows:
-
-   ```yaml
-   role: follower
-   model: # define the net of vFL party
-     train_net:
-       name: follower_loss_net
-       inputs:
-         - name: id_hldr0
-           source: local
-         - name: wt_hldr0
-           source: local
-       outputs:
-         - name: follower_wide_embedding
-           destination: remote
-           compress_type: min_max
-           bit_num: 6
-         - name: follower_deep_embedding
-           destination: remote
-           compress_type: min_max
-           bit_num: 6
-         - name: follower_l2_regu
-           destination: local
-   ...
-   ```
-
-3. Modify the hyperparameters as needed (see the note after this list on how `bit_num` trades accuracy against traffic):
-
-   - compress_type: compression type, string; either "min_max" or "bit_pack".
-   - bit_num: number of bits, int; valid range is [1, 8].
-
-4. Run the startup scripts of the example program.
-
-   ```shell
-   # Start the leader process:
-   bash run_vfl_train_leader.sh 127.0.0.1:1984 127.0.0.1:1230 ./mindrecord/ False
-   # Start the follower process:
-   bash run_vfl_train_follower.sh 127.0.0.1:1230 127.0.0.1:1984 ./mindrecord/ False
-   ```
-
-5. Check the training log `vfl_train_leader.log`. The loss converges normally.
-
-   ```text
-   epoch 0 step 0 loss: 0.693124
-   epoch 0 step 100 loss: 0.512151
-   epoch 0 step 200 loss: 0.493524
-   epoch 0 step 300 loss: 0.473054
-   epoch 0 step 400 loss: 0.466222
-   epoch 0 step 500 loss: 0.464252
-   epoch 0 step 600 loss: 0.469296
-   epoch 0 step 700 loss: 0.451647
-   epoch 0 step 800 loss: 0.457797
-   epoch 0 step 900 loss: 0.457930
-   epoch 0 step 1000 loss: 0.461664
-   epoch 0 step 1100 loss: 0.460415
-   epoch 0 step 1200 loss: 0.466883
-   epoch 0 step 1300 loss: 0.455919
-   epoch 0 step 1400 loss: 0.466984
-   epoch 0 step 1500 loss: 0.454486
-   epoch 0 step 1600 loss: 0.458730
-   epoch 0 step 1700 loss: 0.451275
-   epoch 0 step 1800 loss: 0.445938
-   epoch 0 step 1900 loss: 0.458323
-   epoch 0 step 2000 loss: 0.446709
-   ...
-   ```
-
-6. Stop the training processes.
-
-   ```shell
-   pid=`ps -ef|grep run_vfl_train_ |grep -v "grep" | grep -v "finish" |awk '{print $2}'` && for id in $pid; do kill -9 $id && echo "killed $id"; done
-   ```
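-
-A note on choosing `bit_num`: for min_max quantization the round-trip error is at most scale / 2 = (max_val - min_val) / (2 * (2 ^ bit_num - 1)), so each extra bit roughly halves the error while enlarging the payload. A quick sanity check, reusing the illustrative `min_max_quant` / `min_max_dequant` helpers sketched in the Compression Methods section (these are assumptions of this document, not framework APIs):
-
-```python
-import numpy as np
-
-data = np.random.uniform(-0.04, 0.04, 1024)
-for bit_num in (2, 4, 6, 8):
-    quant, min_val, max_val = min_max_quant(data, bit_num)
-    err = np.abs(min_max_dequant(quant, bit_num, min_val, max_val) - data).max()
-    bound = (max_val - min_val) / (2 * (2 ** bit_num - 1))
-    print(bit_num, err, err <= bound + 1e-12)   # the observed error stays within the bound
-```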
-
-## Deep Experience
-
-### Obtaining the Compression Configuration
-
-Users can obtain the communication compression configuration through the packaged interfaces. [Detailed yaml configuration items for model training](https://www.mindspore.cn/federated/docs/zh-CN/master/vertical/vertical_federated_yaml.html) describes the relevant startup parameters, and [Model training interfaces](https://www.mindspore.cn/federated/docs/zh-CN/master/vertical/vertical_federated_FLModel.html) provides the interface for obtaining the compression configuration. An example is as follows:
-
-```python
-# parse yaml files
-leader_top_yaml_data = FLYamlData(config.leader_top_yaml_path)
-
-# Leader Top Net
-leader_top_base_net = LeaderTopNet()
-leader_top_train_net = LeaderTopLossNet(leader_top_base_net)
-leader_top_fl_model = FLModel(
-    yaml_data=leader_top_yaml_data,
-    network=leader_top_train_net
-)
-
-# get compress config
-compress_configs = leader_top_fl_model.get_compress_configs()
-```
-
-### Setting the Compression Configuration
-
-Users can set the communication compression configuration on the communicator through the packaged [vertical federated learning communicator](https://www.mindspore.cn/federated/docs/zh-CN/master/vertical/vertical_communicator.html) interface, as follows:
-
-```python
-# build vertical communicator
-http_server_config = ServerConfig(server_name='leader', server_address=config.http_server_address)
-remote_server_config = ServerConfig(server_name='follower', server_address=config.remote_server_address)
-vertical_communicator = VerticalFederatedCommunicator(
-    http_server_config=http_server_config,
-    remote_server_config=remote_server_config,
-    compress_configs=compress_configs
-)
-vertical_communicator.launch()
-```
-
-Once the communication compression configuration is set, the vertical federated framework automatically compresses the communication content in the backend.
diff --git a/docs/golden_stick/docs/source_en/conf.py b/docs/golden_stick/docs/source_en/conf.py
index 916aac459b3d895e8c68a03bbfe43a87d5d60f07..5b28d6a31ed3862667bcdcb41de7bf6429613b1f 100644
--- a/docs/golden_stick/docs/source_en/conf.py
+++ b/docs/golden_stick/docs/source_en/conf.py
@@ -177,6 +177,30 @@ try:
 except:
     pass
 
+re_url = r"(((gitee.com/mindspore/docs)|(github.com/mindspore-ai/(mindspore|docs))|" + \
+         r"(mindspore.cn/(docs|tutorials|lite))|(obs.dualstack.cn-north-4.myhuaweicloud)|" + \
+         r"(mindspore-website.obs.cn-north-4.myhuaweicloud))[\w\d/_.-]*?)/(master)"
+
+re_url2 = r"(gitee.com/mindspore/mindspore[\w\d/_.-]*?)/(master)"
+
+re_url3 = r"(((gitee.com/mindspore/golden-stick)|(mindspore.cn/golden_stick))[\w\d/_.-]*?)/(master)"
+
+re_url4 = r"(((gitee.com/mindspore/mindformers)|(mindspore.cn/mindformers))[\w\d/_.-]*?)/(dev)"
+
+for cur, _, files in os.walk(os.path.join(base_path, 'mindspore_gs')):
+    for i in files:
+        if i.endswith('.py'):
+            with open(os.path.join(cur, i), 'r+', encoding='utf-8') as f:
+                content = f.read()
+                new_content = re.sub(re_url, r'\1/r2.6.0', content)
+                new_content = re.sub(re_url2, r'\1/v2.6.0', new_content)
+                new_content = re.sub(re_url3, r'\1/r1.1.0', new_content)
+                new_content = re.sub(re_url4, r'\1/r1.5.0', new_content)
+                if new_content != content:
+                    f.seek(0)
+                    f.truncate()
+                    f.write(new_content)
+
 import mindspore_gs
 
 # Copy source files of chinese python api from golden-stick repository.
diff --git a/docs/golden_stick/docs/source_en/index.rst b/docs/golden_stick/docs/source_en/index.rst
index d9d7c07c4912e17173fb8d772a477f6e8dfe74e6..908bcaebb9941c3df951b13fd841de7f212535da 100644
--- a/docs/golden_stick/docs/source_en/index.rst
+++ b/docs/golden_stick/docs/source_en/index.rst
@@ -1,126 +1,3 @@
 MindSpore Golden Stick
 =============================
 
-MindSpore Golden Stick is a model compression algorithm tool that reduces computing power, memory, and power consumption during AI deployment, enabling AI deployment in all scenarios.
-
-MindSpore Golden Stick is jointly designed and developed by Huawei's Noah team and Huawei's MindSpore team. Its architecture, shown in the figure below, is divided into five parts:
-
-.. raw:: html
-
-
-1. The underlying MindSpore Rewrite module provides the ability to modify the front-end network. Based on the interfaces provided by this module, algorithm developers can add, delete, query, and modify the nodes and topology relationships of a MindSpore front-end network according to specific rules;
-
-2. Based on MindSpore Rewrite, MindSpore Golden Stick provides various types of algorithms, such as the SimQAT algorithm, the SLB quantization algorithm, and the SCOP pruning algorithm;
-
-3. At the upper level of the algorithms, MindSpore Golden Stick also plans advanced technologies such as AMC (AutoML for Model Compression), NAS (Neural Architecture Search), and HAQ (Hardware-aware Automated Quantization);
-
-4. To make it easier for developers to analyze and debug algorithms, MindSpore Golden Stick provides tools such as a visualization tool, a profiler tool, and a summary tool;
-
-5. In the outermost layer, MindSpore Golden Stick encapsulates a concise set of user interfaces.
-
-.. note::
-    The architecture diagram is the overall picture of MindSpore Golden Stick, including both the features implemented in the current version and the capabilities planned in the roadmap. Please refer to the release notes for the features available in the current version.
-
-Code repository address: 
-
-Design Guidelines
----------------------------------------
-
-In addition to providing rich model compression algorithms, an important design concept of MindSpore Golden Stick is to give users a unified and concise experience across the wide variety of model compression algorithms in the industry and to reduce the cost of applying them. MindSpore Golden Stick implements this philosophy through two initiatives:
-
-1. Unified algorithm interface design to reduce user application costs:
-
-   There are many types of model compression algorithms, such as quantization-aware training algorithms, pruning algorithms, matrix decomposition algorithms, and knowledge distillation algorithms. Within each type there are also various specific algorithms; for example, LSQ and PACT are both quantization-aware training algorithms. Different algorithms are often applied in different ways, which increases the learning cost for users. MindSpore Golden Stick sorts out and abstracts the algorithm application process and provides a unified set of algorithm application interfaces to minimize that learning cost. This also facilitates exploring advanced technologies such as AMC, NAS, and HAQ on top of the algorithm ecosystem.
-
-2. Front-end network modification capabilities to reduce algorithm development costs:
-
-   Model compression algorithms are often designed or optimized for specific network structures. For example, quantization-aware algorithms often insert fake-quantization nodes on the Conv2d, Conv2d + BatchNorm2d, or Conv2d + BatchNorm2d + Relu structures in a network. MindSpore Golden Stick provides the ability to modify the front-end network through APIs. Based on this ability, algorithm developers can formulate general network transform rules to implement the algorithm logic without re-implementing it for every specific network. In addition, MindSpore Golden Stick also provides debugging capabilities, including network dump, layer-wise profiling, algorithm effect analysis, and visualization tools, aiming to help algorithm developers improve development and research efficiency and to help users find algorithms that meet their needs.
-
-General Process of Applying the MindSpore Golden Stick
-------------------------------------------------------
-
-.. raw:: html
-
-
-1. Training
-
-   Applying MindSpore Golden Stick during network training has little impact on the original training script logic. As shown in the highlighted part of the preceding figure, only the following two steps need to be added:
-
-   - **Optimize the network using MindSpore Golden Stick:** In the original training process, after the original network is defined and before the network is trained, use MindSpore Golden Stick to optimize the network structure. Generally, this step is implemented by calling the `apply` API of MindSpore Golden Stick. For details, see `Applying the SimQAT Algorithm `_ .
-
-   - **Register the MindSpore Golden Stick callback:** Register the callback of MindSpore Golden Stick into the model to be trained. Generally, this step calls the `callback` function of MindSpore Golden Stick to obtain the corresponding callback object and registers that object into the model.
-
-2. Deployment
-
-   - **Network conversion:** A network compressed by MindSpore Golden Stick may require additional steps to convert the compression-related structures from training mode to deployment mode, facilitating model export and deployment. For example, in the quantization-aware training scenario, the fake-quantization nodes in the network usually need to be eliminated and converted into operator attributes.
-
-.. note::
-
-   - For details about how to apply MindSpore Golden Stick, see the detailed description and sample code in each algorithm section.
-   - For details about the "ms.export" step in the process, see `Exporting MINDIR Model `_ .
-   - For details about the "MindSpore infer" step in the process, see `MindSpore Inference Runtime `_ .
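-
-The two training-phase steps and the deployment-phase conversion can be condensed into a minimal sketch. The import path, the algorithm class, and the ``create_network`` / ``loss_fn`` / ``optimizer`` / ``dataset`` / ``dummy_input`` placeholders below are illustrative assumptions rather than the authoritative sample code; see each algorithm section for exact usage:
-
-.. code-block:: python
-
-    import mindspore as ms
-    from mindspore import Model
-    # illustrative import path; pick the algorithm you need
-    from mindspore_gs.quantization.simulated_quantization import SimulatedQuantizationAwareTraining as SimQAT
-
-    net = create_network()          # user-defined original network (placeholder)
-    algo = SimQAT()                 # any Golden Stick algorithm object
-    net = algo.apply(net)           # step 1: optimize the network before training
-
-    model = Model(net, loss_fn, optimizer, metrics={"Accuracy"})
-    # step 2: obtain the algorithm callbacks and register them into the model
-    model.train(1, dataset, callbacks=algo.callbacks())
-
-    # deployment: convert the compressed structures to deployment mode, then export
-    net = algo.convert(net)
-    ms.export(net, dummy_input, file_name="net", file_format="MINDIR")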
-
-Roadmap
----------------------------------------
-
-The current release of MindSpore Golden Stick provides a stable API, a linear quantization algorithm, a nonlinear quantization algorithm, and a structured pruning algorithm. More algorithms and better network support will be provided in future versions, and debugging capabilities will also be provided in subsequent versions. As the algorithms become richer, MindSpore Golden Stick will also explore capabilities such as AMC, HAQ, and NAS, so stay tuned.
-
-.. toctree::
-   :glob:
-   :maxdepth: 1
-   :caption: Installation and Deployment
-
-   install
-
-.. 
toctree:: - :glob: - :maxdepth: 1 - :caption: Quantization Aware Training Algorithms - - quantization/overview - quantization/simqat - quantization/slb - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: Post Training Quantization Algorithms - - ptq/ptq - ptq/round_to_nearest - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: Pruning Algorithms - - pruner/overview - pruner/scop - pruner/lrp - pruner/lrp_tutorial - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: Model Deployment - - deployment/overview - deployment/convert - -.. toctree:: - :maxdepth: 1 - :caption: API References - - mindspore_gs - mindspore_gs.common - mindspore_gs.quantization - mindspore_gs.ptq - mindspore_gs.pruner - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: RELEASE NOTES - - RELEASE diff --git a/docs/golden_stick/docs/source_zh_cn/conf.py b/docs/golden_stick/docs/source_zh_cn/conf.py index 5e968e8503472ed79cdeaae2e78d08ebfeacb640..9c228da5c4af291f17386ce6b7b644dc719106ed 100644 --- a/docs/golden_stick/docs/source_zh_cn/conf.py +++ b/docs/golden_stick/docs/source_zh_cn/conf.py @@ -211,6 +211,7 @@ for root,dirs,files in os.walk(src_dir_api): if os.path.exists(os.path.join(f'./{outer_dir_name}', file)): os.remove(os.path.join(f'./{outer_dir_name}', file)) shutil.copy(os.path.join(root, file), os.path.join(f'./{outer_dir_name}', file)) + copy_list.append(os.path.join(f'./{outer_dir_name}', file)) break else: if not os.path.exists('.' + root.split(copy_path)[-1]): @@ -218,6 +219,7 @@ for root,dirs,files in os.walk(src_dir_api): if os.path.exists('.' + root.split(copy_path)[-1] + '/'+file): os.remove('.' + root.split(copy_path)[-1] + '/'+file) shutil.copy(os.path.join(root, file), '.' + root.split(copy_path)[-1]+'/'+file) + copy_list.append('.' + root.split(copy_path)[-1]+'/'+file) readme_path = os.path.join(os.getenv("GS_PATH"), 'README_CN.md') @@ -345,6 +347,16 @@ docs_branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if vers re_view = f"\n.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/{docs_branch}/" + \ f"resource/_static/logo_source.svg\n :target: https://gitee.com/mindspore/{copy_repo}/blob/{branch}/" +re_url = r"(((gitee.com/mindspore/docs)|(github.com/mindspore-ai/(mindspore|docs))|" + \ + r"(mindspore.cn/(docs|tutorials|lite))|(obs.dualstack.cn-north-4.myhuaweicloud)|" + \ + r"(mindspore-website.obs.cn-north-4.myhuaweicloud))[\w\d/_.-]*?)/(master)" + +re_url2 = r"(gitee.com/mindspore/mindspore[\w\d/_.-]*?)/(master)" + +re_url3 = r"(((gitee.com/mindspore/golden-stick)|(mindspore.cn/golden_stick))[\w\d/_.-]*?)/(master)" + +re_url4 = r"(((gitee.com/mindspore/mindformers)|(mindspore.cn/mindformers))[\w\d/_.-]*?)/(dev)" + for cur, _, files in os.walk(moment_dir): for i in files: flag_copy = 0 @@ -360,10 +372,14 @@ for cur, _, files in os.walk(moment_dir): new_content = content if '.. include::' in content and '.. 
automodule::' in content:
                 continue
-            if 'autosummary::' not in content and "\n=====" in content:
+            if 'autosummary::' not in content and "\n======" in content:
                 re_view_ = re_view + copy_path + cur.split(moment_dir)[-1] + '/' + i + \
                            '\n    :alt: 查看源文件\n\n'
                 new_content = re.sub('([=]{5,})\n', r'\1\n' + re_view_, content, 1)
+            new_content = re.sub(re_url, r'\1/r2.6.0', new_content)
+            new_content = re.sub(re_url2, r'\1/v2.6.0', new_content)
+            new_content = re.sub(re_url3, r'\1/r1.1.0', new_content)
+            new_content = re.sub(re_url4, r'\1/r1.5.0', new_content)
             if new_content != content:
                 f.seek(0)
                 f.truncate()
diff --git a/docs/golden_stick/docs/source_zh_cn/index.rst b/docs/golden_stick/docs/source_zh_cn/index.rst
index e3aee133410bb44a7bbf40f3849e26455d6f863d..371c590604e6d9d15110bf3c4172764acad70204 100644
--- a/docs/golden_stick/docs/source_zh_cn/index.rst
+++ b/docs/golden_stick/docs/source_zh_cn/index.rst
@@ -1,125 +1,3 @@
 MindSpore Golden Stick Documentation
 =====================================
 
-MindSpore Golden Stick is a model compression algorithm tool that reduces the compute, memory, and power consumption of AI deployment, enabling AI deployment in all scenarios.
-
-MindSpore Golden Stick is jointly designed and developed by Huawei's Noah team and Huawei's MindSpore team. Its architecture, shown in the figure below, is divided into five parts:
-
-.. raw:: html
-
-
-1. The underlying MindSpore Rewrite module provides the ability to modify the front-end network. Based on the interfaces provided by this module, algorithm developers can add, delete, query, and modify the nodes and topology relationships of a MindSpore front-end network according to specific rules;
-
-2. Based on the foundational capability of MindSpore Rewrite, MindSpore Golden Stick provides various types of algorithms, such as the SimQAT algorithm, the SLB quantization algorithm, and the SCOP pruning algorithm;
-
-3. At the upper level of the algorithms, MindSpore Golden Stick also plans advanced technologies such as AMC (automatic model compression), NAS (neural architecture search), and HAQ (hardware-aware automated quantization);
-
-4. To make it easier for developers to analyze and debug algorithms, MindSpore Golden Stick provides tools such as a visualization tool, a profiler tool (layer-wise analysis), and a summary tool (compression effect analysis);
-
-5. In the outermost layer, MindSpore Golden Stick encapsulates a concise set of user interfaces.
-
-.. note::
-    The architecture diagram is the overall picture of MindSpore Golden Stick, including both the features implemented in the current version and the capabilities planned in the roadmap. Refer to the release notes of the corresponding version for the features that are actually available.
-
-Code repository address: 
-
-Design Guidelines
----------------------------------------
-
-In addition to providing rich model compression algorithms, an important design concept of MindSpore Golden Stick is to give users a unified and concise experience across the wide variety of model compression algorithms in the industry and to reduce the cost of applying them. MindSpore Golden Stick implements this philosophy through two initiatives:
-
-1. Unified algorithm interface design to reduce user application costs
-
-   There are many types of model compression algorithms, such as quantization-aware training algorithms, pruning algorithms, matrix decomposition algorithms, and knowledge distillation algorithms. Within each type there are also various specific algorithms; for example, LSQ and PACT are both quantization-aware training algorithms. Different algorithms are often applied in different ways, which increases the learning cost for users. MindSpore Golden Stick sorts out and abstracts the algorithm application process and provides a unified set of algorithm application interfaces to minimize that learning cost. This also facilitates exploring advanced technologies such as AMC (automatic model compression), NAS (neural architecture search), and HAQ (hardware-aware automated quantization) on top of the algorithm ecosystem.
-
-2. Front-end network modification capabilities to reduce algorithm development costs
-
-   Model compression algorithms are often designed or optimized for specific network structures. For example, quantization-aware algorithms often insert fake-quantization nodes on the Conv2d, Conv2d + BatchNorm2d, or Conv2d + BatchNorm2d + Relu structures in a network. MindSpore Golden Stick provides the ability to modify the front-end network through interfaces. Based on this ability, algorithm developers can formulate general network transform rules to implement the algorithm logic without re-implementing it for every specific network. In addition, MindSpore Golden Stick will also provide debugging capabilities, including network dump, layer-wise profiling, algorithm effect analysis, and visualization, aiming to help algorithm developers improve development and research efficiency and to help users find algorithms that meet their needs.
-
-General Process of Applying MindSpore Golden Stick Algorithms
--------------------------------------------------------------
-
-.. raw:: html
-
-
-1. Training phase
-
-   Applying a MindSpore Golden Stick algorithm during network training has little impact on the original training script logic. As shown in the yellow part of the figure above, only two extra steps are needed:
-
-   - **Apply a MindSpore Golden Stick algorithm to optimize the network:** In the original training process, after the original network is defined and before the network is trained, apply a MindSpore Golden Stick algorithm to optimize the network structure. Generally, this step calls the `apply` interface of MindSpore Golden Stick; see `Applying the SimQAT Algorithm `_.
-
-   - **Register the MindSpore Golden Stick callback logic:** Register the callback logic of the MindSpore Golden Stick algorithm into the model to be trained. Generally, this step calls the `callback` of MindSpore Golden Stick to obtain the corresponding callback object and registers it into the model.
-
-2. Deployment phase
-
-   - **Network conversion:** A network compressed by MindSpore Golden Stick may require additional steps to convert the compression-related structures from training mode to deployment mode, facilitating further model export and deployment. For example, in the quantization-aware scenario, the fake-quantization nodes in the network often need to be eliminated and converted into operator attributes.
-
-.. note::
-
-   - Details about applying MindSpore Golden Stick algorithms, together with sample code, can be found in each algorithm section.
-   - For the "ms.export" step in the process, see `Exporting MindIR Files `_ .
-   - For the "MindSpore inference optimization tool and runtime" step in the process, see `MindSpore Inference `_ .
-
-Roadmap
-----------
-
-The initial version of MindSpore Golden Stick contains a stable API and provides a linear quantization algorithm, a nonlinear quantization algorithm, and a structured pruning algorithm. More algorithms and better network support will be provided later, and debugging capabilities will be added in subsequent versions. As the algorithms become richer, MindSpore Golden Stick will also explore capabilities such as AMC, HAQ, and NAS, so stay tuned.
-
-.. toctree::
-   :glob:
-   :maxdepth: 1
-   :caption: Installation and Deployment
-
-   install
-
-.. toctree::
-   :glob:
-   :maxdepth: 1
-   :caption: Quantization-Aware Training Algorithms
-
-   quantization/overview
-   quantization/simqat
-   quantization/slb
-
-.. toctree::
-   :glob:
-   :maxdepth: 1
-   :caption: Post-Training Quantization Algorithms
-
-   ptq/overview
-   ptq/ptq
-   ptq/round_to_nearest
-
-.. toctree::
-   :glob:
-   :maxdepth: 1
-   :caption: Pruning Algorithms
-
-   pruner/overview
-   pruner/scop
-
-.. toctree::
-   :glob:
-   :maxdepth: 1
-   :caption: Model Deployment
-
-   deployment/overview
-   deployment/convert
-
-.. toctree::
-   :maxdepth: 1
-   :caption: API References
-
-   mindspore_gs
-   mindspore_gs.common
-   mindspore_gs.quantization
-   mindspore_gs.ptq
-   mindspore_gs.pruner
-
-.. toctree::
-   :glob:
-   :maxdepth: 1
-   :caption: RELEASE NOTES
-
-   RELEASE
diff --git a/docs/graphlearning/docs/Makefile b/docs/graphlearning/docs/Makefile
deleted file mode 100644
index 1eff8952707bdfa503c8d60c1e9a903053170ba2..0000000000000000000000000000000000000000
--- a/docs/graphlearning/docs/Makefile
+++ /dev/null
@@ -1,20 +0,0 @@
-# Minimal makefile for Sphinx documentation
-#
-
-# You can set these variables from the command line, and also
-# from the environment for the first two.
-SPHINXOPTS    ?=
-SPHINXBUILD   ?= sphinx-build
-SOURCEDIR     = source_zh_cn
-BUILDDIR      = build_zh_cn
-
-# Put it first so that "make" without argument is like "make help".
-help:
-	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-
-.PHONY: help Makefile
-
-# Catch-all target: route all unknown targets to Sphinx using the new
-# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
-%: Makefile
-	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/graphlearning/docs/_ext/customdocumenter.txt b/docs/graphlearning/docs/_ext/customdocumenter.txt
deleted file mode 100644
index 2d37ae41f6772a21da2a7dc5c7bff75128e68330..0000000000000000000000000000000000000000
--- a/docs/graphlearning/docs/_ext/customdocumenter.txt
+++ /dev/null
@@ -1,245 +0,0 @@
-import re
-import os
-from sphinx.ext.autodoc import Documenter
-
-
-class CustomDocumenter(Documenter):
-
-    def document_members(self, all_members: bool = False) -> None:
-        """Generate reST for member documentation.
-
-        If *all_members* is True, do all members, else those given by
-        *self.options.members*.
-        """
-        # set current namespace for finding members
-        self.env.temp_data['autodoc:module'] = self.modname
-        if self.objpath:
-            self.env.temp_data['autodoc:class'] = self.objpath[0]
-
-        want_all = all_members or self.options.inherited_members or \
-            self.options.members is ALL
-        # find out which members are documentable
-        members_check_module, members = self.get_object_members(want_all)
-
-        # **** exclude APIs that already have hand-written Chinese pages ****
-        file_path = os.path.join(self.env.app.srcdir, self.env.docname+'.rst')
-        exclude_re = re.compile(r'(.. py:class::|.. py:function::)\s+(.*?)(\(|\n)')
-        includerst_re = re.compile(r'.. 
include::\s+(.*?)\n') - with open(file_path, 'r', encoding='utf-8') as f: - content = f.read() - excluded_members = exclude_re.findall(content) - if excluded_members: - excluded_members = [i[1].split('.')[-1] for i in excluded_members] - rst_included = includerst_re.findall(content) - if rst_included: - for i in rst_included: - include_path = os.path.join(os.path.dirname(file_path), i) - if os.path.exists(include_path): - with open(include_path, 'r', encoding='utf8') as g: - content_ = g.read() - excluded_member_ = exclude_re.findall(content_) - if excluded_member_: - excluded_member_ = [j[1].split('.')[-1] for j in excluded_member_] - excluded_members.extend(excluded_member_) - - if excluded_members: - if self.options.exclude_members: - self.options.exclude_members |= set(excluded_members) - else: - self.options.exclude_members = excluded_members - - # remove members given by exclude-members - if self.options.exclude_members: - members = [ - (membername, member) for (membername, member) in members - if ( - self.options.exclude_members is ALL or - membername not in self.options.exclude_members - ) - ] - - # document non-skipped members - memberdocumenters = [] # type: List[Tuple[Documenter, bool]] - for (mname, member, isattr) in self.filter_members(members, want_all): - classes = [cls for cls in self.documenters.values() - if cls.can_document_member(member, mname, isattr, self)] - if not classes: - # don't know how to document this member - continue - # prefer the documenter with the highest priority - classes.sort(key=lambda cls: cls.priority) - # give explicitly separated module name, so that members - # of inner classes can be documented - full_mname = self.modname + '::' + \ - '.'.join(self.objpath + [mname]) - documenter = classes[-1](self.directive, full_mname, self.indent) - memberdocumenters.append((documenter, isattr)) - member_order = self.options.member_order or \ - self.env.config.autodoc_member_order - if member_order == 'groupwise': - # sort by group; relies on stable sort to keep items in the - # same group sorted alphabetically - memberdocumenters.sort(key=lambda e: e[0].member_order) - elif member_order == 'bysource' and self.analyzer: - # sort by source order, by virtue of the module analyzer - tagorder = self.analyzer.tagorder - - def keyfunc(entry: Tuple[Documenter, bool]) -> int: - fullname = entry[0].name.split('::')[1] - return tagorder.get(fullname, len(tagorder)) - memberdocumenters.sort(key=keyfunc) - - for documenter, isattr in memberdocumenters: - documenter.generate( - all_members=True, real_modname=self.real_modname, - check_module=members_check_module and not isattr) - - # reset current objects - self.env.temp_data['autodoc:module'] = None - self.env.temp_data['autodoc:class'] = None - - def generate(self, more_content: Any = None, real_modname: str = None, - check_module: bool = False, all_members: bool = False) -> None: - """Generate reST for the object given by *self.name*, and possibly for - its members. - - If *more_content* is given, include that content. If *real_modname* is - given, use that module name to find attribute docs. If *check_module* is - True, only generate if the object is defined in the module name it is - imported from. If *all_members* is True, document all members. 
- """ - if not self.parse_name(): - # need a module to import - logger.warning( - __('don\'t know which module to import for autodocumenting ' - '%r (try placing a "module" or "currentmodule" directive ' - 'in the document, or giving an explicit module name)') % - self.name, type='autodoc') - return - - # now, import the module and get object to document - if not self.import_object(): - return - - # If there is no real module defined, figure out which to use. - # The real module is used in the module analyzer to look up the module - # where the attribute documentation would actually be found in. - # This is used for situations where you have a module that collects the - # functions and classes of internal submodules. - self.real_modname = real_modname or self.get_real_modname() # type: str - - # try to also get a source code analyzer for attribute docs - try: - self.analyzer = ModuleAnalyzer.for_module(self.real_modname) - # parse right now, to get PycodeErrors on parsing (results will - # be cached anyway) - self.analyzer.find_attr_docs() - except PycodeError as err: - logger.debug('[autodoc] module analyzer failed: %s', err) - # no source file -- e.g. for builtin and C modules - self.analyzer = None - # at least add the module.__file__ as a dependency - if hasattr(self.module, '__file__') and self.module.__file__: - self.directive.filename_set.add(self.module.__file__) - else: - self.directive.filename_set.add(self.analyzer.srcname) - - # check __module__ of object (for members not given explicitly) - if check_module: - if not self.check_module(): - return - - # document members, if possible - self.document_members(all_members) - - -class ModuleDocumenter(CustomDocumenter): - """ - Specialized Documenter subclass for modules. - """ - objtype = 'module' - content_indent = '' - titles_allowed = True - - option_spec = { - 'members': members_option, 'undoc-members': bool_option, - 'noindex': bool_option, 'inherited-members': bool_option, - 'show-inheritance': bool_option, 'synopsis': identity, - 'platform': identity, 'deprecated': bool_option, - 'member-order': identity, 'exclude-members': members_set_option, - 'private-members': bool_option, 'special-members': members_option, - 'imported-members': bool_option, 'ignore-module-all': bool_option - } # type: Dict[str, Callable] - - def __init__(self, *args: Any) -> None: - super().__init__(*args) - merge_members_option(self.options) - - @classmethod - def can_document_member(cls, member: Any, membername: str, isattr: bool, parent: Any - ) -> bool: - # don't document submodules automatically - return False - - def resolve_name(self, modname: str, parents: Any, path: str, base: Any - ) -> Tuple[str, List[str]]: - if modname is not None: - logger.warning(__('"::" in automodule name doesn\'t make sense'), - type='autodoc') - return (path or '') + base, [] - - def parse_name(self) -> bool: - ret = super().parse_name() - if self.args or self.retann: - logger.warning(__('signature arguments or return annotation ' - 'given for automodule %s') % self.fullname, - type='autodoc') - return ret - - def add_directive_header(self, sig: str) -> None: - Documenter.add_directive_header(self, sig) - - sourcename = self.get_sourcename() - - # add some module-specific options - if self.options.synopsis: - self.add_line(' :synopsis: ' + self.options.synopsis, sourcename) - if self.options.platform: - self.add_line(' :platform: ' + self.options.platform, sourcename) - if self.options.deprecated: - self.add_line(' :deprecated:', sourcename) - - def 
get_object_members(self, want_all: bool) -> Tuple[bool, List[Tuple[str, object]]]: - if want_all: - if (self.options.ignore_module_all or not - hasattr(self.object, '__all__')): - # for implicit module members, check __module__ to avoid - # documenting imported objects - return True, get_module_members(self.object) - else: - memberlist = self.object.__all__ - # Sometimes __all__ is broken... - if not isinstance(memberlist, (list, tuple)) or not \ - all(isinstance(entry, str) for entry in memberlist): - logger.warning( - __('__all__ should be a list of strings, not %r ' - '(in module %s) -- ignoring __all__') % - (memberlist, self.fullname), - type='autodoc' - ) - # fall back to all members - return True, get_module_members(self.object) - else: - memberlist = self.options.members or [] - ret = [] - for mname in memberlist: - try: - ret.append((mname, safe_getattr(self.object, mname))) - except AttributeError: - logger.warning( - __('missing attribute mentioned in :members: or __all__: ' - 'module %s, attribute %s') % - (safe_getattr(self.object, '__name__', '???'), mname), - type='autodoc' - ) - return False, ret diff --git a/docs/graphlearning/docs/_ext/myautosummary.py b/docs/graphlearning/docs/_ext/myautosummary.py deleted file mode 100644 index ce61d3e218c85ee713db7dd15ad547cce5f9da30..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/_ext/myautosummary.py +++ /dev/null @@ -1,522 +0,0 @@ -"""Customized autosummary directives for sphinx.""" - -import importlib -import inspect -import os -import re -from typing import List, Tuple -from docutils.nodes import Node -from sphinx.locale import __ -from sphinx.ext.autosummary import Autosummary, posixpath, addnodes, logger, Matcher, autosummary_toc, get_import_prefixes_from_env -from sphinx.ext.autosummary import mock, StringList, ModuleType, get_documenter, ModuleAnalyzer, PycodeError, mangle_signature -from sphinx.ext.autosummary import import_by_name, extract_summary, autosummary_table, nodes, switch_source_input, rst -from sphinx.ext.autodoc.directive import DocumenterBridge, Options - -class MsAutosummary(Autosummary): - """ - Inherited from sphinx's autosummary, add titles and a column for the generated table. - """ - - def init(self): - """ - init method - """ - self.find_doc_name = "" - self.third_title = "" - self.default_doc = "" - - def extract_env_summary(self, doc: List[str]) -> str: - """Extract env summary from docstring.""" - env_sum = self.default_doc - for i, piece in enumerate(doc): - if piece.startswith(self.find_doc_name): - env_sum = doc[i+1][4:] - return env_sum - - def run(self): - """ - run method - """ - self.init() - self.bridge = DocumenterBridge(self.env, self.state.document.reporter, - Options(), self.lineno, self.state) - - names = [x.strip().split()[0] for x in self.content - if x.strip() and re.search(r'^[~a-zA-Z_]', x.strip()[0])] - items = self.get_items(names) - teble_nodes = self.get_table(items) - - if 'toctree' in self.options: - dirname = posixpath.dirname(self.env.docname) - - tree_prefix = self.options['toctree'].strip() - docnames = [] - excluded = Matcher(self.config.exclude_patterns) - for item in items: - docname = posixpath.join(tree_prefix, item[3]) - docname = posixpath.normpath(posixpath.join(dirname, docname)) - if docname not in self.env.found_docs: - location = self.state_machine.get_source_and_line(self.lineno) - if excluded(self.env.doc2path(docname, None)): - msg = __('autosummary references excluded document %r. 
Ignored.') - else: - msg = __('autosummary: stub file not found %r. ' - 'Check your autosummary_generate setting.') - logger.warning(msg, item[3], location=location) - continue - docnames.append(docname) - - if docnames: - tocnode = addnodes.toctree() - tocnode['includefiles'] = docnames - tocnode['entries'] = [(None, docn) for docn in docnames] - tocnode['maxdepth'] = -1 - tocnode['glob'] = None - teble_nodes.append(autosummary_toc('', '', tocnode)) - return teble_nodes - - def get_items(self, names: List[str]) -> List[Tuple[str, str, str, str, str]]: - """Try to import the given names, and return a list of - ``[(name, signature, summary_string, real_name, env_summary), ...]``. - """ - prefixes = get_import_prefixes_from_env(self.env) - items = [] # type: List[Tuple[str, str, str, str, str]] - max_item_chars = 50 - - for name in names: - display_name = name - if name.startswith('~'): - name = name[1:] - display_name = name.split('.')[-1] - try: - with mock(self.config.autosummary_mock_imports): - real_name, obj, parent, modname = import_by_name(name, prefixes=prefixes) - except ImportError: - logger.warning(__('failed to import %s'), name) - items.append((name, '', '', name, '')) - continue - - self.bridge.result = StringList() # initialize for each documenter - full_name = real_name - if not isinstance(obj, ModuleType): - # give explicitly separated module name, so that members - # of inner classes can be documented - full_name = modname + '::' + full_name[len(modname) + 1:] - # NB. using full_name here is important, since Documenters - # handle module prefixes slightly differently - doccls = get_documenter(self.env.app, obj, parent) - documenter = doccls(self.bridge, full_name) - - if not documenter.parse_name(): - logger.warning(__('failed to parse name %s'), real_name) - items.append((display_name, '', '', real_name, '')) - continue - if not documenter.import_object(): - logger.warning(__('failed to import object %s'), real_name) - items.append((display_name, '', '', real_name, '')) - continue - if documenter.options.members and not documenter.check_module(): - continue - - # try to also get a source code analyzer for attribute docs - try: - documenter.analyzer = ModuleAnalyzer.for_module( - documenter.get_real_modname()) - # parse right now, to get PycodeErrors on parsing (results will - # be cached anyway) - documenter.analyzer.find_attr_docs() - except PycodeError as err: - logger.debug('[autodoc] module analyzer failed: %s', err) - # no source file -- e.g. for builtin and C modules - documenter.analyzer = None - - # -- Grab the signature - - try: - sig = documenter.format_signature(show_annotation=False) - except TypeError: - # the documenter does not support ``show_annotation`` option - sig = documenter.format_signature() - - if not sig: - sig = '' - else: - max_chars = max(10, max_item_chars - len(display_name)) - sig = mangle_signature(sig, max_chars=max_chars) - - # -- Grab the summary - - documenter.add_content(None) - summary = extract_summary(self.bridge.result.data[:], self.state.document) - env_sum = self.extract_env_summary(self.bridge.result.data[:]) - items.append((display_name, sig, summary, real_name, env_sum)) - - return items - - def get_table(self, items: List[Tuple[str, str, str, str, str]]) -> List[Node]: - """Generate a proper list of table nodes for autosummary:: directive. - - *items* is a list produced by :meth:`get_items`. 
- """ - table_spec = addnodes.tabular_col_spec() - table_spec['spec'] = r'\X{1}{2}\X{1}{2}' - - table = autosummary_table('') - real_table = nodes.table('', classes=['longtable']) - table.append(real_table) - group = nodes.tgroup('', cols=3) - real_table.append(group) - group.append(nodes.colspec('', colwidth=10)) - group.append(nodes.colspec('', colwidth=70)) - group.append(nodes.colspec('', colwidth=30)) - body = nodes.tbody('') - group.append(body) - - def append_row(*column_texts: str) -> None: - row = nodes.row('', color="red") - source, line = self.state_machine.get_source_and_line() - for text in column_texts: - node = nodes.paragraph('') - vl = StringList() - vl.append(text, '%s:%d:' % (source, line)) - with switch_source_input(self.state, vl): - self.state.nested_parse(vl, 0, node) - try: - if isinstance(node[0], nodes.paragraph): - node = node[0] - except IndexError: - pass - row.append(nodes.entry('', node)) - body.append(row) - - # add table's title - append_row("**API Name**", "**Description**", self.third_title) - for name, sig, summary, real_name, env_sum in items: - qualifier = 'obj' - if 'nosignatures' not in self.options: - col1 = ':%s:`%s <%s>`\\ %s' % (qualifier, name, real_name, rst.escape(sig)) - else: - col1 = ':%s:`%s <%s>`' % (qualifier, name, real_name) - col2 = summary - col3 = env_sum - append_row(col1, col2, col3) - - return [table_spec, table] - - -class MsNoteAutoSummary(MsAutosummary): - """ - Inherited from MsAutosummary. Add a third column about `Note` to the table. - """ - - def init(self): - """ - init method - """ - self.find_doc_name = ".. note::" - self.third_title = "**Note**" - self.default_doc = "None" - - def extract_env_summary(self, doc: List[str]) -> str: - """Extract env summary from docstring.""" - env_sum = self.default_doc - for piece in doc: - if piece.startswith(self.find_doc_name): - env_sum = piece[10:] - return env_sum - - -class MsPlatformAutoSummary(MsAutosummary): - """ - Inherited from MsAutosummary. Add a third column about `Supported Platforms` to the table. - """ - def init(self): - """ - init method - """ - self.find_doc_name = "Supported Platforms:" - self.third_title = "**{}**".format(self.find_doc_name[:-1]) - self.default_doc = "To Be Developed" - -class MsCnAutoSummary(Autosummary): - """Overwrite MsPlatformAutosummary for chinese python api.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.table_head = () - self.find_doc_name = "" - self.third_title = "" - self.default_doc = "" - self.third_name_en = "" - - def get_third_column_en(self, doc): - """Get the third column for en.""" - third_column = self.default_doc - for i, piece in enumerate(doc): - if piece.startswith(self.third_name_en): - try: - if "eprecated" in doc[i+1][4:]: - third_column = "弃用" - else: - third_column = doc[i+1][4:] - except IndexError: - third_column = '' - return third_column - - def get_summary_re(self, display_name: str): - return re.compile(rf'\.\. 
\w+:\w+::\s+{display_name}.*?\n\n\s+(.*?)[。\n]') - - def run(self) -> List[Node]: - self.bridge = DocumenterBridge(self.env, self.state.document.reporter, - Options(), self.lineno, self.state) - - names = [x.strip().split()[0] for x in self.content - if x.strip() and re.search(r'^[~a-zA-Z_]', x.strip()[0])] - items = self.get_items(names) - #pylint: disable=redefined-outer-name - nodes = self.get_table(items) - - dirname = posixpath.dirname(self.env.docname) - - tree_prefix = self.options['toctree'].strip() - docnames = [] - names = [i[0] for i in items] - for name in names: - docname = posixpath.join(tree_prefix, name) - docname = posixpath.normpath(posixpath.join(dirname, docname)) - if docname not in self.env.found_docs: - continue - - docnames.append(docname) - - if docnames: - tocnode = addnodes.toctree() - tocnode['includefiles'] = docnames - tocnode['entries'] = [(None, docn) for docn in docnames] - tocnode['maxdepth'] = -1 - tocnode['glob'] = None - - nodes.append(autosummary_toc('', '', tocnode)) - - return nodes - - def get_items(self, names: List[str]) -> List[Tuple[str, str, str, str]]: - """Try to import the given names, and return a list of - ``[(name, signature, summary_string, real_name), ...]``. - """ - prefixes = get_import_prefixes_from_env(self.env) - doc_path = os.path.dirname(self.state.document.current_source) - items = [] # type: List[Tuple[str, str, str, str]] - max_item_chars = 50 - origin_rst_files = self.env.config.rst_files - all_rst_files = self.env.found_docs - generated_files = all_rst_files.difference(origin_rst_files) - - for name in names: - display_name = name - if name.startswith('~'): - name = name[1:] - display_name = name.split('.')[-1] - - dir_name = self.options['toctree'] - spec_path = os.path.join('api_python', dir_name, display_name) - file_path = os.path.join(doc_path, dir_name, display_name+'.rst') - if os.path.exists(file_path) and spec_path not in generated_files: - summary_re_tag = re.compile(rf'\.\. \w+:\w+::\s+{display_name}.*?\n\s+:.*?:\n\n\s+(.*?)[。\n]') - summary_re_line = re.compile(rf'\.\. 
\w+:\w+::\s+{display_name}(?:.|\n|)+?\n\n\s+(.*?)[。\n]') - summary_re = self.get_summary_re(display_name) - content = '' - with open(file_path, 'r', encoding='utf-8') as f: - content = f.read() - if content: - summary_str = summary_re.findall(content) - summary_str_tag = summary_re_tag.findall(content) - summary_str_line = summary_re_line.findall(content) - if summary_str: - if re.findall("[::,,。.;;]", summary_str[0][-1]): - logger.warning(f"{display_name}接口的概述格式需调整") - summary_str = summary_str[0] + '。' - elif summary_str_tag: - if re.findall("[::,,。.;;]", summary_str_tag[0][-1]): - logger.warning(f"{display_name}接口的概述格式需调整") - summary_str = summary_str_tag[0] + '。' - elif summary_str_line: - if re.findall("[::,,。.;;]", summary_str_line[0][-1]): - logger.warning(f"{display_name}接口的概述格式需调整") - summary_str = summary_str_line[0] + '。' - else: - summary_str = '' - if not self.table_head: - items.append((display_name, summary_str)) - else: - third_str = self.get_third_column(display_name, content) - if third_str: - third_str = third_str[0] - else: - third_str = '' - - items.append((display_name, summary_str, third_str)) - else: - try: - with mock(self.config.autosummary_mock_imports): - real_name, obj, parent, modname = import_by_name(name, prefixes=prefixes) - except ImportError: - logger.warning(__('failed to import %s'), name) - items.append((name, '', '')) - continue - - self.bridge.result = StringList() # initialize for each documenter - full_name = real_name - if not isinstance(obj, ModuleType): - # give explicitly separated module name, so that members - # of inner classes can be documented - full_name = modname + '::' + full_name[len(modname) + 1:] - # NB. using full_name here is important, since Documenters - # handle module prefixes slightly differently - doccls = get_documenter(self.env.app, obj, parent) - documenter = doccls(self.bridge, full_name) - - if not documenter.parse_name(): - logger.warning(__('failed to parse name %s'), real_name) - items.append((display_name, '', '')) - continue - if not documenter.import_object(): - logger.warning(__('failed to import object %s'), real_name) - items.append((display_name, '', '')) - continue - if documenter.options.members and not documenter.check_module(): - continue - - # try to also get a source code analyzer for attribute docs - try: - documenter.analyzer = ModuleAnalyzer.for_module( - documenter.get_real_modname()) - # parse right now, to get PycodeErrors on parsing (results will - # be cached anyway) - documenter.analyzer.find_attr_docs() - except PycodeError as err: - logger.debug('[autodoc] module analyzer failed: %s', err) - # no source file -- e.g. 
for builtin and C modules - documenter.analyzer = None - - # -- Grab the signature - - try: - sig = documenter.format_signature(show_annotation=False) - except TypeError: - # the documenter does not support ``show_annotation`` option - sig = documenter.format_signature() - - if not sig: - sig = '' - else: - max_chars = max(10, max_item_chars - len(display_name)) - sig = mangle_signature(sig, max_chars=max_chars) - - # -- Grab the summary and third_colum - - documenter.add_content(None) - summary = extract_summary(self.bridge.result.data[:], self.state.document) - if self.table_head: - third_colum = self.get_third_column_en(self.bridge.result.data[:]) - items.append((display_name, summary, third_colum)) - else: - items.append((display_name, summary)) - - - return items - - def get_table(self, items: List[Tuple[str, str, str]]) -> List[Node]: - """Generate a proper list of table nodes for autosummary:: directive. - - *items* is a list produced by :meth:`get_items`. - """ - table_spec = addnodes.tabular_col_spec() - table = autosummary_table('') - real_table = nodes.table('', classes=['longtable']) - table.append(real_table) - - if not self.table_head: - table_spec['spec'] = r'\X{1}{2}\X{1}{2}' - group = nodes.tgroup('', cols=2) - real_table.append(group) - group.append(nodes.colspec('', colwidth=10)) - group.append(nodes.colspec('', colwidth=90)) - else: - table_spec['spec'] = r'\X{1}{2}\X{1}{2}\X{1}{2}' - group = nodes.tgroup('', cols=3) - real_table.append(group) - group.append(nodes.colspec('', colwidth=10)) - group.append(nodes.colspec('', colwidth=60)) - group.append(nodes.colspec('', colwidth=30)) - body = nodes.tbody('') - group.append(body) - - def append_row(*column_texts: str) -> None: - row = nodes.row('') - source, line = self.state_machine.get_source_and_line() - for text in column_texts: - node = nodes.paragraph('') - vl = StringList() - vl.append(text, '%s:%d:' % (source, line)) - with switch_source_input(self.state, vl): - self.state.nested_parse(vl, 0, node) - try: - if isinstance(node[0], nodes.paragraph): - node = node[0] - except IndexError: - pass - row.append(nodes.entry('', node)) - body.append(row) - append_row(*self.table_head) - if not self.table_head: - for name, summary in items: - qualifier = 'obj' - col1 = ':%s:`%s <%s>`' % (qualifier, name, name) - col2 = summary - append_row(col1, col2) - else: - for name, summary, other in items: - qualifier = 'obj' - col1 = ':%s:`%s <%s>`' % (qualifier, name, name) - col2 = summary - col3 = other - append_row(col1, col2, col3) - return [table_spec, table] - -def get_api(fullname): - """Get the api module.""" - try: - module_name, api_name = ".".join(fullname.split('.')[:-1]), fullname.split('.')[-1] - # pylint: disable=unused-variable - module_import = importlib.import_module(module_name) - except ModuleNotFoundError: - module_name, api_name = ".".join(fullname.split('.')[:-2]), ".".join(fullname.split('.')[-2:]) - module_import = importlib.import_module(module_name) - # pylint: disable=eval-used - api = eval(f"module_import.{api_name}") - return api - -class MsCnPlatformAutoSummary(MsCnAutoSummary): - """definition of cnmsplatformautosummary.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.table_head = ('**接口名**', '**概述**', '**支持平台**') - self.third_name_en = "Supported Platforms:" - - def get_third_column(self, name=None, content=None): - """Get the`Supported Platforms`.""" - if not name: - return [] - try: - api_doc = inspect.getdoc(get_api(name)) - platform_str = re.findall(r'Supported 
Platforms:\n\s+(.*?)\n\n', api_doc) - if ['deprecated'] == platform_str: - return ["弃用"] - if not platform_str: - platform_str_leak = re.findall(r'Supported Platforms:\n\s+(.*)', api_doc) - if platform_str_leak: - return platform_str_leak - return ["``Ascend`` ``GPU`` ``CPU``"] - return platform_str - except: #pylint: disable=bare-except - return [] diff --git a/docs/graphlearning/docs/_ext/overwriteautosummary_generate.txt b/docs/graphlearning/docs/_ext/overwriteautosummary_generate.txt deleted file mode 100644 index 4b0a1b1dd2b410ecab971b13da9993c90d65ef0d..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/_ext/overwriteautosummary_generate.txt +++ /dev/null @@ -1,707 +0,0 @@ -""" - sphinx.ext.autosummary.generate - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - Usable as a library or script to generate automatic RST source files for - items referred to in autosummary:: directives. - - Each generated RST file contains a single auto*:: directive which - extracts the docstring of the referred item. - - Example Makefile rule:: - - generate: - sphinx-autogen -o source/generated source/*.rst - - :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. - :license: BSD, see LICENSE for details. -""" - -import argparse -import importlib -import inspect -import locale -import os -import pkgutil -import pydoc -import re -import sys -import warnings -from gettext import NullTranslations -from os import path -from typing import Any, Dict, List, NamedTuple, Sequence, Set, Tuple, Type, Union - -from jinja2 import TemplateNotFound -from jinja2.sandbox import SandboxedEnvironment - -import sphinx.locale -from sphinx import __display_version__, package_dir -from sphinx.application import Sphinx -from sphinx.builders import Builder -from sphinx.config import Config -from sphinx.deprecation import RemovedInSphinx50Warning -from sphinx.ext.autodoc import Documenter -from sphinx.ext.autodoc.importer import import_module -from sphinx.ext.autosummary import (ImportExceptionGroup, get_documenter, import_by_name, - import_ivar_by_name) -from sphinx.locale import __ -from sphinx.pycode import ModuleAnalyzer, PycodeError -from sphinx.registry import SphinxComponentRegistry -from sphinx.util import logging, rst, split_full_qualified_name, get_full_modname -from sphinx.util.inspect import getall, safe_getattr -from sphinx.util.osutil import ensuredir -from sphinx.util.template import SphinxTemplateLoader - -logger = logging.getLogger(__name__) - - -class DummyApplication: - """Dummy Application class for sphinx-autogen command.""" - - def __init__(self, translator: NullTranslations) -> None: - self.config = Config() - self.registry = SphinxComponentRegistry() - self.messagelog: List[str] = [] - self.srcdir = "/" - self.translator = translator - self.verbosity = 0 - self._warncount = 0 - self.warningiserror = False - - self.config.add('autosummary_context', {}, True, None) - self.config.add('autosummary_filename_map', {}, True, None) - self.config.add('autosummary_ignore_module_all', True, 'env', bool) - self.config.add('docs_branch', '', True, None) - self.config.add('branch', '', True, None) - self.config.add('cst_module_name', '', True, None) - self.config.add('copy_repo', '', True, None) - self.config.add('giturl', '', True, None) - self.config.add('repo_whl', '', True, None) - self.config.init_values() - - def emit_firstresult(self, *args: Any) -> None: - pass - - -class AutosummaryEntry(NamedTuple): - name: str - path: str - template: str - recursive: bool - - -def setup_documenters(app: Any) 
-> None: - from sphinx.ext.autodoc import (AttributeDocumenter, ClassDocumenter, DataDocumenter, - DecoratorDocumenter, ExceptionDocumenter, - FunctionDocumenter, MethodDocumenter, ModuleDocumenter, - NewTypeAttributeDocumenter, NewTypeDataDocumenter, - PropertyDocumenter) - documenters: List[Type[Documenter]] = [ - ModuleDocumenter, ClassDocumenter, ExceptionDocumenter, DataDocumenter, - FunctionDocumenter, MethodDocumenter, NewTypeAttributeDocumenter, - NewTypeDataDocumenter, AttributeDocumenter, DecoratorDocumenter, PropertyDocumenter, - ] - for documenter in documenters: - app.registry.add_documenter(documenter.objtype, documenter) - - -def _simple_info(msg: str) -> None: - warnings.warn('_simple_info() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - print(msg) - - -def _simple_warn(msg: str) -> None: - warnings.warn('_simple_warn() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - print('WARNING: ' + msg, file=sys.stderr) - - -def _underline(title: str, line: str = '=') -> str: - if '\n' in title: - raise ValueError('Can only underline single lines') - return title + '\n' + line * len(title) - - -class AutosummaryRenderer: - """A helper class for rendering.""" - - def __init__(self, app: Union[Builder, Sphinx], template_dir: str = None) -> None: - if isinstance(app, Builder): - warnings.warn('The first argument for AutosummaryRenderer has been ' - 'changed to Sphinx object', - RemovedInSphinx50Warning, stacklevel=2) - if template_dir: - warnings.warn('template_dir argument for AutosummaryRenderer is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - - system_templates_path = [os.path.join(package_dir, 'ext', 'autosummary', 'templates')] - loader = SphinxTemplateLoader(app.srcdir, app.config.templates_path, - system_templates_path) - - self.env = SandboxedEnvironment(loader=loader) - self.env.filters['escape'] = rst.escape - self.env.filters['e'] = rst.escape - self.env.filters['underline'] = _underline - - if isinstance(app, (Sphinx, DummyApplication)): - if app.translator: - self.env.add_extension("jinja2.ext.i18n") - self.env.install_gettext_translations(app.translator) - elif isinstance(app, Builder): - if app.app.translator: - self.env.add_extension("jinja2.ext.i18n") - self.env.install_gettext_translations(app.app.translator) - - def exists(self, template_name: str) -> bool: - """Check if template file exists.""" - warnings.warn('AutosummaryRenderer.exists() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - try: - self.env.get_template(template_name) - return True - except TemplateNotFound: - return False - - def render(self, template_name: str, context: Dict) -> str: - """Render a template file.""" - try: - template = self.env.get_template(template_name) - except TemplateNotFound: - try: - # objtype is given as template_name - template = self.env.get_template('autosummary/%s.rst' % template_name) - except TemplateNotFound: - # fallback to base.rst - template = self.env.get_template('autosummary/base.rst') - - return template.render(context) - - -# -- Generating output --------------------------------------------------------- - - -class ModuleScanner: - def __init__(self, app: Any, obj: Any) -> None: - self.app = app - self.object = obj - - def get_object_type(self, name: str, value: Any) -> str: - return get_documenter(self.app, value, self.object).objtype - - def is_skipped(self, name: str, value: Any, objtype: str) -> bool: - try: - return self.app.emit_firstresult('autodoc-skip-member', objtype, - name, value, False, {}) - except 
Exception as exc: - logger.warning(__('autosummary: failed to determine %r to be documented, ' - 'the following exception was raised:\n%s'), - name, exc, type='autosummary') - return False - - def scan(self, imported_members: bool) -> List[str]: - members = [] - for name in members_of(self.object, self.app.config): - try: - value = safe_getattr(self.object, name) - except AttributeError: - value = None - - objtype = self.get_object_type(name, value) - if self.is_skipped(name, value, objtype): - continue - - try: - if inspect.ismodule(value): - imported = True - elif safe_getattr(value, '__module__') != self.object.__name__: - imported = True - else: - imported = False - except AttributeError: - imported = False - - respect_module_all = not self.app.config.autosummary_ignore_module_all - if imported_members: - # list all members up - members.append(name) - elif imported is False: - # list not-imported members - members.append(name) - elif '__all__' in dir(self.object) and respect_module_all: - # list members that have __all__ set - members.append(name) - - return members - - -def members_of(obj: Any, conf: Config) -> Sequence[str]: - """Get the members of ``obj``, possibly ignoring the ``__all__`` module attribute - - Follows the ``conf.autosummary_ignore_module_all`` setting.""" - - if conf.autosummary_ignore_module_all: - return dir(obj) - else: - return getall(obj) or dir(obj) - - -def generate_autosummary_content(name: str, obj: Any, parent: Any, - template: AutosummaryRenderer, template_name: str, - imported_members: bool, app: Any, - recursive: bool, context: Dict, - modname: str = None, qualname: str = None) -> str: - doc = get_documenter(app, obj, parent) - - def skip_member(obj: Any, name: str, objtype: str) -> bool: - try: - return app.emit_firstresult('autodoc-skip-member', objtype, name, - obj, False, {}) - except Exception as exc: - logger.warning(__('autosummary: failed to determine %r to be documented, ' - 'the following exception was raised:\n%s'), - name, exc, type='autosummary') - return False - - def get_class_members(obj: Any) -> Dict[str, Any]: - members = sphinx.ext.autodoc.get_class_members(obj, [qualname], safe_getattr) - return {name: member.object for name, member in members.items()} - - def get_module_members(obj: Any) -> Dict[str, Any]: - members = {} - for name in members_of(obj, app.config): - try: - members[name] = safe_getattr(obj, name) - except AttributeError: - continue - return members - - def get_all_members(obj: Any) -> Dict[str, Any]: - if doc.objtype == "module": - return get_module_members(obj) - elif doc.objtype == "class": - return get_class_members(obj) - return {} - - def get_members(obj: Any, types: Set[str], include_public: List[str] = [], - imported: bool = True) -> Tuple[List[str], List[str]]: - items: List[str] = [] - public: List[str] = [] - - all_members = get_all_members(obj) - for name, value in all_members.items(): - documenter = get_documenter(app, value, obj) - if documenter.objtype in types: - # skip imported members if expected - if imported or getattr(value, '__module__', None) == obj.__name__: - skipped = skip_member(value, name, documenter.objtype) - if skipped is True: - pass - elif skipped is False: - # show the member forcedly - items.append(name) - public.append(name) - else: - items.append(name) - if name in include_public or not name.startswith('_'): - # considers member as public - public.append(name) - return public, items - - def get_module_attrs(members: Any) -> Tuple[List[str], List[str]]: - """Find module attributes 
with docstrings.""" - attrs, public = [], [] - try: - analyzer = ModuleAnalyzer.for_module(name) - attr_docs = analyzer.find_attr_docs() - for namespace, attr_name in attr_docs: - if namespace == '' and attr_name in members: - attrs.append(attr_name) - if not attr_name.startswith('_'): - public.append(attr_name) - except PycodeError: - pass # give up if ModuleAnalyzer fails to parse code - return public, attrs - - def get_modules(obj: Any) -> Tuple[List[str], List[str]]: - items: List[str] = [] - for _, modname, _ispkg in pkgutil.iter_modules(obj.__path__): - fullname = name + '.' + modname - try: - module = import_module(fullname) - if module and hasattr(module, '__sphinx_mock__'): - continue - except ImportError: - pass - - items.append(fullname) - public = [x for x in items if not x.split('.')[-1].startswith('_')] - return public, items - - ns: Dict[str, Any] = {} - ns.update(context) - - if doc.objtype == 'module': - scanner = ModuleScanner(app, obj) - ns['members'] = scanner.scan(imported_members) - ns['functions'], ns['all_functions'] = \ - get_members(obj, {'function'}, imported=imported_members) - ns['classes'], ns['all_classes'] = \ - get_members(obj, {'class'}, imported=imported_members) - ns['exceptions'], ns['all_exceptions'] = \ - get_members(obj, {'exception'}, imported=imported_members) - ns['attributes'], ns['all_attributes'] = \ - get_module_attrs(ns['members']) - ispackage = hasattr(obj, '__path__') - if ispackage and recursive: - ns['modules'], ns['all_modules'] = get_modules(obj) - elif doc.objtype == 'class': - ns['members'] = dir(obj) - ns['inherited_members'] = \ - set(dir(obj)) - set(obj.__dict__.keys()) - ns['methods'], ns['all_methods'] = \ - get_members(obj, {'method'}, ['__init__']) - ns['attributes'], ns['all_attributes'] = \ - get_members(obj, {'attribute', 'property'}) - - if modname is None or qualname is None: - modname, qualname = split_full_qualified_name(name) - - if doc.objtype in ('method', 'attribute', 'property'): - ns['class'] = qualname.rsplit(".", 1)[0] - - if doc.objtype in ('class',): - shortname = qualname - else: - shortname = qualname.rsplit(".", 1)[-1] - - ns['fullname'] = name - ns['module'] = modname - ns['objname'] = qualname - ns['name'] = shortname - - ns['objtype'] = doc.objtype - ns['underline'] = len(name) * '=' - - if template_name: - return template.render(template_name, ns) - else: - return template.render(doc.objtype, ns) - - -def generate_autosummary_docs(sources: List[str], output_dir: str = None, - suffix: str = '.rst', base_path: str = None, - builder: Builder = None, template_dir: str = None, - imported_members: bool = False, app: Any = None, - overwrite: bool = True, encoding: str = 'utf-8') -> None: - - if builder: - warnings.warn('builder argument for generate_autosummary_docs() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - - if template_dir: - warnings.warn('template_dir argument for generate_autosummary_docs() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - - showed_sources = list(sorted(sources)) - if len(showed_sources) > 20: - showed_sources = showed_sources[:10] + ['...'] + showed_sources[-10:] - logger.info(__('[autosummary] generating autosummary for: %s') % - ', '.join(showed_sources)) - - if output_dir: - logger.info(__('[autosummary] writing to %s') % output_dir) - - if base_path is not None: - sources = [os.path.join(base_path, filename) for filename in sources] - - template = AutosummaryRenderer(app) - - # read - items = find_autosummary_in_files(sources) - - # keep track of new 
files - new_files = [] - - if app: - filename_map = app.config.autosummary_filename_map - else: - filename_map = {} - - # write - for entry in sorted(set(items), key=str): - if entry.path is None: - # The corresponding autosummary:: directive did not have - # a :toctree: option - continue - - path = output_dir or os.path.abspath(entry.path) - ensuredir(path) - - try: - name, obj, parent, modname = import_by_name(entry.name, grouped_exception=True) - qualname = name.replace(modname + ".", "") - except ImportExceptionGroup as exc: - try: - # try to import as an instance attribute - name, obj, parent, modname = import_ivar_by_name(entry.name) - qualname = name.replace(modname + ".", "") - except ImportError as exc2: - if exc2.__cause__: - exceptions: List[BaseException] = exc.exceptions + [exc2.__cause__] - else: - exceptions = exc.exceptions + [exc2] - - errors = list(set("* %s: %s" % (type(e).__name__, e) for e in exceptions)) - logger.warning(__('[autosummary] failed to import %s.\nPossible hints:\n%s'), - entry.name, '\n'.join(errors)) - continue - - context: Dict[str, Any] = {} - if app: - context.update(app.config.autosummary_context) - - content = generate_autosummary_content(name, obj, parent, template, entry.template, - imported_members, app, entry.recursive, context, - modname, qualname) - try: - py_source_rel = get_full_modname(modname, qualname).replace('.', '/') + '.py' - except: - logger.warning(name) - py_source_rel = '' - - re_view = f"\n.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/{app.config.docs_branch}/" + \ - f"resource/_static/logo_source_en.svg\n :target: " + app.config.giturl + \ - f"{app.config.copy_repo}/blob/{app.config.branch}/" + app.config.repo_whl + \ - py_source_rel.split(app.config.cst_module_name)[-1] + '\n :alt: View Source On Gitee\n\n' - - if re_view not in content and py_source_rel: - content = re.sub('([=]{5,})\n', r'\1\n' + re_view, content, 1) - filename = os.path.join(path, filename_map.get(name, name) + suffix) - if os.path.isfile(filename): - with open(filename, encoding=encoding) as f: - old_content = f.read() - - if content == old_content: - continue - elif overwrite: # content has changed - with open(filename, 'w', encoding=encoding) as f: - f.write(content) - new_files.append(filename) - else: - with open(filename, 'w', encoding=encoding) as f: - f.write(content) - new_files.append(filename) - - # descend recursively to new files - if new_files: - generate_autosummary_docs(new_files, output_dir=output_dir, - suffix=suffix, base_path=base_path, - builder=builder, template_dir=template_dir, - imported_members=imported_members, app=app, - overwrite=overwrite) - - -# -- Finding documented entries in files --------------------------------------- - -def find_autosummary_in_files(filenames: List[str]) -> List[AutosummaryEntry]: - """Find out what items are documented in source/*.rst. - - See `find_autosummary_in_lines`. - """ - documented: List[AutosummaryEntry] = [] - for filename in filenames: - with open(filename, encoding='utf-8', errors='ignore') as f: - lines = f.read().splitlines() - documented.extend(find_autosummary_in_lines(lines, filename=filename)) - return documented - - -def find_autosummary_in_docstring(name: str, module: str = None, filename: str = None - ) -> List[AutosummaryEntry]: - """Find out what items are documented in the given object's docstring. - - See `find_autosummary_in_lines`. 
- """ - if module: - warnings.warn('module argument for find_autosummary_in_docstring() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - - try: - real_name, obj, parent, modname = import_by_name(name, grouped_exception=True) - lines = pydoc.getdoc(obj).splitlines() - return find_autosummary_in_lines(lines, module=name, filename=filename) - except AttributeError: - pass - except ImportExceptionGroup as exc: - errors = list(set("* %s: %s" % (type(e).__name__, e) for e in exc.exceptions)) - print('Failed to import %s.\nPossible hints:\n%s' % (name, '\n'.join(errors))) - except SystemExit: - print("Failed to import '%s'; the module executes module level " - "statement and it might call sys.exit()." % name) - return [] - - -def find_autosummary_in_lines(lines: List[str], module: str = None, filename: str = None - ) -> List[AutosummaryEntry]: - """Find out what items appear in autosummary:: directives in the - given lines. - - Returns a list of (name, toctree, template) where *name* is a name - of an object and *toctree* the :toctree: path of the corresponding - autosummary directive (relative to the root of the file name), and - *template* the value of the :template: option. *toctree* and - *template* ``None`` if the directive does not have the - corresponding options set. - """ - autosummary_re = re.compile(r'^(\s*)\.\.\s+(ms[a-z]*)?autosummary::\s*') - automodule_re = re.compile( - r'^\s*\.\.\s+automodule::\s*([A-Za-z0-9_.]+)\s*$') - module_re = re.compile( - r'^\s*\.\.\s+(current)?module::\s*([a-zA-Z0-9_.]+)\s*$') - autosummary_item_re = re.compile(r'^\s+(~?[_a-zA-Z][a-zA-Z0-9_.]*)\s*.*?') - recursive_arg_re = re.compile(r'^\s+:recursive:\s*$') - toctree_arg_re = re.compile(r'^\s+:toctree:\s*(.*?)\s*$') - template_arg_re = re.compile(r'^\s+:template:\s*(.*?)\s*$') - - documented: List[AutosummaryEntry] = [] - - recursive = False - toctree: str = None - template = None - current_module = module - in_autosummary = False - base_indent = "" - - for line in lines: - if in_autosummary: - m = recursive_arg_re.match(line) - if m: - recursive = True - continue - - m = toctree_arg_re.match(line) - if m: - toctree = m.group(1) - if filename: - toctree = os.path.join(os.path.dirname(filename), - toctree) - continue - - m = template_arg_re.match(line) - if m: - template = m.group(1).strip() - continue - - if line.strip().startswith(':'): - continue # skip options - - m = autosummary_item_re.match(line) - if m: - name = m.group(1).strip() - if name.startswith('~'): - name = name[1:] - if current_module and \ - not name.startswith(current_module + '.'): - name = "%s.%s" % (current_module, name) - documented.append(AutosummaryEntry(name, toctree, template, recursive)) - continue - - if not line.strip() or line.startswith(base_indent + " "): - continue - - in_autosummary = False - - m = autosummary_re.match(line) - if m: - in_autosummary = True - base_indent = m.group(1) - recursive = False - toctree = None - template = None - continue - - m = automodule_re.search(line) - if m: - current_module = m.group(1).strip() - # recurse into the automodule docstring - documented.extend(find_autosummary_in_docstring( - current_module, filename=filename)) - continue - - m = module_re.match(line) - if m: - current_module = m.group(2) - continue - - return documented - - -def get_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - usage='%(prog)s [OPTIONS] ...', - epilog=__('For more information, visit .'), - description=__(""" -Generate ReStructuredText using autosummary directives. 
- -sphinx-autogen is a frontend to sphinx.ext.autosummary.generate. It generates -the reStructuredText files from the autosummary directives contained in the -given input files. - -The format of the autosummary directive is documented in the -``sphinx.ext.autosummary`` Python module and can be read using:: - - pydoc sphinx.ext.autosummary -""")) - - parser.add_argument('--version', action='version', dest='show_version', - version='%%(prog)s %s' % __display_version__) - - parser.add_argument('source_file', nargs='+', - help=__('source files to generate rST files for')) - - parser.add_argument('-o', '--output-dir', action='store', - dest='output_dir', - help=__('directory to place all output in')) - parser.add_argument('-s', '--suffix', action='store', dest='suffix', - default='rst', - help=__('default suffix for files (default: ' - '%(default)s)')) - parser.add_argument('-t', '--templates', action='store', dest='templates', - default=None, - help=__('custom template directory (default: ' - '%(default)s)')) - parser.add_argument('-i', '--imported-members', action='store_true', - dest='imported_members', default=False, - help=__('document imported members (default: ' - '%(default)s)')) - parser.add_argument('-a', '--respect-module-all', action='store_true', - dest='respect_module_all', default=False, - help=__('document exactly the members in module __all__ attribute. ' - '(default: %(default)s)')) - - return parser - - -def main(argv: List[str] = sys.argv[1:]) -> None: - sphinx.locale.setlocale(locale.LC_ALL, '') - sphinx.locale.init_console(os.path.join(package_dir, 'locale'), 'sphinx') - translator, _ = sphinx.locale.init([], None) - - app = DummyApplication(translator) - logging.setup(app, sys.stdout, sys.stderr) # type: ignore - setup_documenters(app) - args = get_parser().parse_args(argv) - - if args.templates: - app.config.templates_path.append(path.abspath(args.templates)) - app.config.autosummary_ignore_module_all = not args.respect_module_all # type: ignore - - generate_autosummary_docs(args.source_file, args.output_dir, - '.' + args.suffix, - imported_members=args.imported_members, - app=app) - - -if __name__ == '__main__': - main() diff --git a/docs/graphlearning/docs/_ext/overwriteobjectiondirective.txt b/docs/graphlearning/docs/_ext/overwriteobjectiondirective.txt deleted file mode 100644 index 8a58bf71191f77ca22097ea9de244c9df5c3d4fb..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/_ext/overwriteobjectiondirective.txt +++ /dev/null @@ -1,368 +0,0 @@ -""" - sphinx.directives - ~~~~~~~~~~~~~~~~~ - - Handlers for additional ReST directives. - - :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. - :license: BSD, see LICENSE for details. 
-""" - -import re -import inspect -import importlib -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Tuple, TypeVar, cast - -from docutils import nodes -from docutils.nodes import Node -from docutils.parsers.rst import directives, roles - -from sphinx import addnodes -from sphinx.addnodes import desc_signature -from sphinx.deprecation import RemovedInSphinx50Warning, deprecated_alias -from sphinx.util import docutils, logging -from sphinx.util.docfields import DocFieldTransformer, Field, TypedField -from sphinx.util.docutils import SphinxDirective -from sphinx.util.typing import OptionSpec - -if TYPE_CHECKING: - from sphinx.application import Sphinx - - -# RE to strip backslash escapes -nl_escape_re = re.compile(r'\\\n') -strip_backslash_re = re.compile(r'\\(.)') - -T = TypeVar('T') -logger = logging.getLogger(__name__) - -def optional_int(argument: str) -> int: - """ - Check for an integer argument or None value; raise ``ValueError`` if not. - """ - if argument is None: - return None - else: - value = int(argument) - if value < 0: - raise ValueError('negative value; must be positive or zero') - return value - -def get_api(fullname): - try: - module_name, api_name= ".".join(fullname.split('.')[:-1]), fullname.split('.')[-1] - module_import = importlib.import_module(module_name) - except ModuleNotFoundError: - module_name, api_name = ".".join(fullname.split('.')[:-2]), ".".join(fullname.split('.')[-2:]) - module_import = importlib.import_module(module_name) - api = eval(f"module_import.{api_name}") - return api - -def get_example(name: str): - try: - api_doc = inspect.getdoc(get_api(name)) - example_str = re.findall(r'Examples:\n([\w\W]*?)(\n\n|$)', api_doc) - if not example_str: - return [] - example_str = re.sub(r'\n\s+', r'\n', example_str[0][0]) - example_str = example_str.strip() - example_list = example_str.split('\n') - return ["", "**样例:**", ""] + example_list + [""] - except: - return [] - -def get_platforms(name: str): - try: - api_doc = inspect.getdoc(get_api(name)) - example_str = re.findall(r'Supported Platforms:\n\s+(.*?)\n\n', api_doc) - if not example_str: - example_str_leak = re.findall(r'Supported Platforms:\n\s+(.*)', api_doc) - if example_str_leak: - example_str = example_str_leak[0].strip() - example_list = example_str.split('\n') - example_list = [' ' + example_list[0]] - return ["", "支持平台:"] + example_list + [""] - return [] - example_str = example_str[0].strip() - example_list = example_str.split('\n') - example_list = [' ' + example_list[0]] - return ["", "支持平台:"] + example_list + [""] - except: - return [] - -class ObjectDescription(SphinxDirective, Generic[T]): - """ - Directive to describe a class, function or similar object. Not used - directly, but subclassed (in domain-specific directives) to add custom - behavior. - """ - - has_content = True - required_arguments = 1 - optional_arguments = 0 - final_argument_whitespace = True - option_spec: OptionSpec = { - 'noindex': directives.flag, - } # type: Dict[str, DirectiveOption] - - # types of doc fields that this directive handles, see sphinx.util.docfields - doc_field_types: List[Field] = [] - domain: str = None - objtype: str = None - indexnode: addnodes.index = None - - # Warning: this might be removed in future version. Don't touch this from extensions. 
- _doc_field_type_map = {} # type: Dict[str, Tuple[Field, bool]] - - def get_field_type_map(self) -> Dict[str, Tuple[Field, bool]]: - if self._doc_field_type_map == {}: - self._doc_field_type_map = {} - for field in self.doc_field_types: - for name in field.names: - self._doc_field_type_map[name] = (field, False) - - if field.is_typed: - typed_field = cast(TypedField, field) - for name in typed_field.typenames: - self._doc_field_type_map[name] = (field, True) - - return self._doc_field_type_map - - def get_signatures(self) -> List[str]: - """ - Retrieve the signatures to document from the directive arguments. By - default, signatures are given as arguments, one per line. - - Backslash-escaping of newlines is supported. - """ - lines = nl_escape_re.sub('', self.arguments[0]).split('\n') - if self.config.strip_signature_backslash: - # remove backslashes to support (dummy) escapes; helps Vim highlighting - return [strip_backslash_re.sub(r'\1', line.strip()) for line in lines] - else: - return [line.strip() for line in lines] - - def handle_signature(self, sig: str, signode: desc_signature) -> Any: - """ - Parse the signature *sig* into individual nodes and append them to - *signode*. If ValueError is raised, parsing is aborted and the whole - *sig* is put into a single desc_name node. - - The return value should be a value that identifies the object. It is - passed to :meth:`add_target_and_index()` unchanged, and otherwise only - used to skip duplicates. - """ - raise ValueError - - def add_target_and_index(self, name: Any, sig: str, signode: desc_signature) -> None: - """ - Add cross-reference IDs and entries to self.indexnode, if applicable. - - *name* is whatever :meth:`handle_signature()` returned. - """ - return # do nothing by default - - def before_content(self) -> None: - """ - Called before parsing content. Used to set information about the current - directive context on the build environment. - """ - pass - - def transform_content(self, contentnode: addnodes.desc_content) -> None: - """ - Called after creating the content through nested parsing, - but before the ``object-description-transform`` event is emitted, - and before the info-fields are transformed. - Can be used to manipulate the content. - """ - pass - - def after_content(self) -> None: - """ - Called after parsing content. Used to reset information about the - current directive context on the build environment. - """ - pass - - def check_class_end(self, content): - for i in content: - if not i.startswith('.. include::') and i != "\n" and i != "": - return False - return True - - def extend_items(self, rst_file, start_num, num): - ls = [] - for i in range(1, num+1): - ls.append((rst_file, start_num+i)) - return ls - - def run(self) -> List[Node]: - """ - Main directive entry function, called by docutils upon encountering the - directive. - - This directive is meant to be quite easily subclassable, so it delegates - to several additional methods. 
What it does: - - * find out if called as a domain-specific directive, set self.domain - * create a `desc` node to fit all description inside - * parse standard options, currently `noindex` - * create an index node if needed as self.indexnode - * parse all given signatures (as returned by self.get_signatures()) - using self.handle_signature(), which should either return a name - or raise ValueError - * add index entries using self.add_target_and_index() - * parse the content and handle doc fields in it - """ - if ':' in self.name: - self.domain, self.objtype = self.name.split(':', 1) - else: - self.domain, self.objtype = '', self.name - self.indexnode = addnodes.index(entries=[]) - - node = addnodes.desc() - node.document = self.state.document - node['domain'] = self.domain - # 'desctype' is a backwards compatible attribute - node['objtype'] = node['desctype'] = self.objtype - node['noindex'] = noindex = ('noindex' in self.options) - if self.domain: - node['classes'].append(self.domain) - node['classes'].append(node['objtype']) - - self.names: List[T] = [] - signatures = self.get_signatures() - for sig in signatures: - # add a signature node for each signature in the current unit - # and add a reference target for it - signode = addnodes.desc_signature(sig, '') - self.set_source_info(signode) - node.append(signode) - try: - # name can also be a tuple, e.g. (classname, objname); - # this is strictly domain-specific (i.e. no assumptions may - # be made in this base class) - name = self.handle_signature(sig, signode) - except ValueError: - # signature parsing failed - signode.clear() - signode += addnodes.desc_name(sig, sig) - continue # we don't want an index entry here - if name not in self.names: - self.names.append(name) - if not noindex: - # only add target and index entry if this is the first - # description of the object with this name in this desc block - self.add_target_and_index(name, sig, signode) - - contentnode = addnodes.desc_content() - node.append(contentnode) - if self.names: - # needed for association of version{added,changed} directives - self.env.temp_data['object'] = self.names[0] - self.before_content() - try: - example = get_example(self.names[0][0]) - platforms = get_platforms(self.names[0][0]) - except Exception as e: - example = '' - platforms = '' - logger.warning(f'Error API names in {self.arguments[0]}.') - logger.warning(f'{e}') - extra = platforms + example - if extra: - if self.objtype == "method": - self.content.data.extend(extra) - else: - index_num = 0 - for num, i in enumerate(self.content.data): - if i.startswith('.. py:method::') or self.check_class_end(self.content.data[num:]): - index_num = num - break - if index_num: - count = len(self.content.data) - for i in extra: - self.content.data.insert(index_num-count, i) - else: - self.content.data.extend(extra) - try: - self.content.items.extend(self.extend_items(self.content.items[0][0], self.content.items[-1][1], len(extra))) - except Exception as e: - logger.warning(f'{e}') - self.state.nested_parse(self.content, self.content_offset, contentnode) - self.transform_content(contentnode) - self.env.app.emit('object-description-transform', - self.domain, self.objtype, contentnode) - DocFieldTransformer(self).transform_all(contentnode) - self.env.temp_data['object'] = None - self.after_content() - return [self.indexnode, node] - - -class DefaultRole(SphinxDirective): - """ - Set the default interpreted text role. Overridden from docutils. 
- """ - - optional_arguments = 1 - final_argument_whitespace = False - - def run(self) -> List[Node]: - if not self.arguments: - docutils.unregister_role('') - return [] - role_name = self.arguments[0] - role, messages = roles.role(role_name, self.state_machine.language, - self.lineno, self.state.reporter) - if role: - docutils.register_role('', role) - self.env.temp_data['default_role'] = role_name - else: - literal_block = nodes.literal_block(self.block_text, self.block_text) - reporter = self.state.reporter - error = reporter.error('Unknown interpreted text role "%s".' % role_name, - literal_block, line=self.lineno) - messages += [error] - - return cast(List[nodes.Node], messages) - - -class DefaultDomain(SphinxDirective): - """ - Directive to (re-)set the default domain for this source file. - """ - - has_content = False - required_arguments = 1 - optional_arguments = 0 - final_argument_whitespace = False - option_spec = {} # type: Dict - - def run(self) -> List[Node]: - domain_name = self.arguments[0].lower() - # if domain_name not in env.domains: - # # try searching by label - # for domain in env.domains.values(): - # if domain.label.lower() == domain_name: - # domain_name = domain.name - # break - self.env.temp_data['default_domain'] = self.env.domains.get(domain_name) - return [] - -def setup(app: "Sphinx") -> Dict[str, Any]: - app.add_config_value("strip_signature_backslash", False, 'env') - directives.register_directive('default-role', DefaultRole) - directives.register_directive('default-domain', DefaultDomain) - directives.register_directive('describe', ObjectDescription) - # new, more consistent, name - directives.register_directive('object', ObjectDescription) - - app.add_event('object-description-transform') - - return { - 'version': 'builtin', - 'parallel_read_safe': True, - 'parallel_write_safe': True, - } - diff --git a/docs/graphlearning/docs/_ext/overwriteviewcode.txt b/docs/graphlearning/docs/_ext/overwriteviewcode.txt deleted file mode 100644 index 172780ec56b3ed90e7b0add617257a618cf38ee0..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/_ext/overwriteviewcode.txt +++ /dev/null @@ -1,378 +0,0 @@ -""" - sphinx.ext.viewcode - ~~~~~~~~~~~~~~~~~~~ - - Add links to module code in Python object descriptions. - - :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. - :license: BSD, see LICENSE for details. -""" - -import posixpath -import traceback -import warnings -from os import path -from typing import Any, Dict, Generator, Iterable, Optional, Set, Tuple, cast - -from docutils import nodes -from docutils.nodes import Element, Node - -import sphinx -from sphinx import addnodes -from sphinx.application import Sphinx -from sphinx.builders import Builder -from sphinx.builders.html import StandaloneHTMLBuilder -from sphinx.deprecation import RemovedInSphinx50Warning -from sphinx.environment import BuildEnvironment -from sphinx.locale import _, __ -from sphinx.pycode import ModuleAnalyzer -from sphinx.transforms.post_transforms import SphinxPostTransform -from sphinx.util import get_full_modname, logging, status_iterator -from sphinx.util.nodes import make_refnode - - -logger = logging.getLogger(__name__) - - -OUTPUT_DIRNAME = '_modules' - - -class viewcode_anchor(Element): - """Node for viewcode anchors. - - This node will be processed in the resolving phase. - For viewcode supported builders, they will be all converted to the anchors. - For not supported builders, they will be removed. 
- """ - - -def _get_full_modname(app: Sphinx, modname: str, attribute: str) -> Optional[str]: - try: - return get_full_modname(modname, attribute) - except AttributeError: - # sphinx.ext.viewcode can't follow class instance attribute - # then AttributeError logging output only verbose mode. - logger.verbose('Didn\'t find %s in %s', attribute, modname) - return None - except Exception as e: - # sphinx.ext.viewcode follow python domain directives. - # because of that, if there are no real modules exists that specified - # by py:function or other directives, viewcode emits a lot of warnings. - # It should be displayed only verbose mode. - logger.verbose(traceback.format_exc().rstrip()) - logger.verbose('viewcode can\'t import %s, failed with error "%s"', modname, e) - return None - - -def is_supported_builder(builder: Builder) -> bool: - if builder.format != 'html': - return False - elif builder.name == 'singlehtml': - return False - elif builder.name.startswith('epub') and not builder.config.viewcode_enable_epub: - return False - else: - return True - - -def doctree_read(app: Sphinx, doctree: Node) -> None: - env = app.builder.env - if not hasattr(env, '_viewcode_modules'): - env._viewcode_modules = {} # type: ignore - - def has_tag(modname: str, fullname: str, docname: str, refname: str) -> bool: - entry = env._viewcode_modules.get(modname, None) # type: ignore - if entry is False: - return False - - code_tags = app.emit_firstresult('viewcode-find-source', modname) - if code_tags is None: - try: - analyzer = ModuleAnalyzer.for_module(modname) - analyzer.find_tags() - except Exception: - env._viewcode_modules[modname] = False # type: ignore - return False - - code = analyzer.code - tags = analyzer.tags - else: - code, tags = code_tags - - if entry is None or entry[0] != code: - entry = code, tags, {}, refname - env._viewcode_modules[modname] = entry # type: ignore - _, tags, used, _ = entry - if fullname in tags: - used[fullname] = docname - return True - - return False - - for objnode in list(doctree.findall(addnodes.desc)): - if objnode.get('domain') != 'py': - continue - names: Set[str] = set() - for signode in objnode: - if not isinstance(signode, addnodes.desc_signature): - continue - modname = signode.get('module') - fullname = signode.get('fullname') - try: - if fullname and modname==None: - if fullname.split('.')[-1].lower() == fullname.split('.')[-1] and fullname.split('.')[-2].lower() != fullname.split('.')[-2]: - modname = '.'.join(fullname.split('.')[:-2]) - fullname = '.'.join(fullname.split('.')[-2:]) - else: - modname = '.'.join(fullname.split('.')[:-1]) - fullname = fullname.split('.')[-1] - fullname_new = fullname - except Exception: - logger.warning(f'error_modename:{modname}') - logger.warning(f'error_fullname:{fullname}') - refname = modname - if env.config.viewcode_follow_imported_members: - new_modname = app.emit_firstresult( - 'viewcode-follow-imported', modname, fullname, - ) - if not new_modname: - new_modname = _get_full_modname(app, modname, fullname) - modname = new_modname - # logger.warning(f'new_modename:{modname}') - if not modname: - continue - # fullname = signode.get('fullname') - # if fullname and modname==None: - fullname = fullname_new - if not has_tag(modname, fullname, env.docname, refname): - continue - if fullname in names: - # only one link per name, please - continue - names.add(fullname) - pagename = posixpath.join(OUTPUT_DIRNAME, modname.replace('.', '/')) - signode += viewcode_anchor(reftarget=pagename, refid=fullname, refdoc=env.docname) - - 
-def env_merge_info(app: Sphinx, env: BuildEnvironment, docnames: Iterable[str], - other: BuildEnvironment) -> None: - if not hasattr(other, '_viewcode_modules'): - return - # create a _viewcode_modules dict on the main environment - if not hasattr(env, '_viewcode_modules'): - env._viewcode_modules = {} # type: ignore - # now merge in the information from the subprocess - for modname, entry in other._viewcode_modules.items(): # type: ignore - if modname not in env._viewcode_modules: # type: ignore - env._viewcode_modules[modname] = entry # type: ignore - else: - if env._viewcode_modules[modname]: # type: ignore - used = env._viewcode_modules[modname][2] # type: ignore - for fullname, docname in entry[2].items(): - if fullname not in used: - used[fullname] = docname - - -def env_purge_doc(app: Sphinx, env: BuildEnvironment, docname: str) -> None: - modules = getattr(env, '_viewcode_modules', {}) - - for modname, entry in list(modules.items()): - if entry is False: - continue - - code, tags, used, refname = entry - for fullname in list(used): - if used[fullname] == docname: - used.pop(fullname) - - if len(used) == 0: - modules.pop(modname) - - -class ViewcodeAnchorTransform(SphinxPostTransform): - """Convert or remove viewcode_anchor nodes depends on builder.""" - default_priority = 100 - - def run(self, **kwargs: Any) -> None: - if is_supported_builder(self.app.builder): - self.convert_viewcode_anchors() - else: - self.remove_viewcode_anchors() - - def convert_viewcode_anchors(self) -> None: - for node in self.document.findall(viewcode_anchor): - anchor = nodes.inline('', _('[源代码]'), classes=['viewcode-link']) - refnode = make_refnode(self.app.builder, node['refdoc'], node['reftarget'], - node['refid'], anchor) - node.replace_self(refnode) - - def remove_viewcode_anchors(self) -> None: - for node in list(self.document.findall(viewcode_anchor)): - node.parent.remove(node) - - -def missing_reference(app: Sphinx, env: BuildEnvironment, node: Element, contnode: Node - ) -> Optional[Node]: - # resolve our "viewcode" reference nodes -- they need special treatment - if node['reftype'] == 'viewcode': - warnings.warn('viewcode extension is no longer use pending_xref node. ' - 'Please update your extension.', RemovedInSphinx50Warning) - return make_refnode(app.builder, node['refdoc'], node['reftarget'], - node['refid'], contnode) - - return None - - -def get_module_filename(app: Sphinx, modname: str) -> Optional[str]: - """Get module filename for *modname*.""" - source_info = app.emit_firstresult('viewcode-find-source', modname) - if source_info: - return None - else: - try: - filename, source = ModuleAnalyzer.get_module_source(modname) - return filename - except Exception: - return None - - -def should_generate_module_page(app: Sphinx, modname: str) -> bool: - """Check generation of module page is needed.""" - module_filename = get_module_filename(app, modname) - if module_filename is None: - # Always (re-)generate module page when module filename is not found. - return True - - builder = cast(StandaloneHTMLBuilder, app.builder) - basename = modname.replace('.', '/') + builder.out_suffix - page_filename = path.join(app.outdir, '_modules/', basename) - - try: - if path.getmtime(module_filename) <= path.getmtime(page_filename): - # generation is not needed if the HTML page is newer than module file. 
- return False - except IOError: - pass - - return True - - -def collect_pages(app: Sphinx) -> Generator[Tuple[str, Dict[str, Any], str], None, None]: - env = app.builder.env - if not hasattr(env, '_viewcode_modules'): - return - if not is_supported_builder(app.builder): - return - highlighter = app.builder.highlighter # type: ignore - urito = app.builder.get_relative_uri - - modnames = set(env._viewcode_modules) # type: ignore - - for modname, entry in status_iterator( - sorted(env._viewcode_modules.items()), # type: ignore - __('highlighting module code... '), "blue", - len(env._viewcode_modules), # type: ignore - app.verbosity, lambda x: x[0]): - if not entry: - continue - if not should_generate_module_page(app, modname): - continue - - code, tags, used, refname = entry - # construct a page name for the highlighted source - pagename = posixpath.join(OUTPUT_DIRNAME, modname.replace('.', '/')) - # highlight the source using the builder's highlighter - if env.config.highlight_language in ('python3', 'default', 'none'): - lexer = env.config.highlight_language - else: - lexer = 'python' - highlighted = highlighter.highlight_block(code, lexer, linenos=False) - # split the code into lines - lines = highlighted.splitlines() - # split off wrap markup from the first line of the actual code - before, after = lines[0].split('<pre>') - lines[0:1] = [before + '<pre>', after] - # nothing to do for the last line; it always starts with </pre> anyway - # now that we have code lines (starting at index 1), insert anchors for - # the collected tags (HACK: this only works if the tag boundaries are - # properly nested!) - maxindex = len(lines) - 1 - for name, docname in used.items(): - type, start, end = tags[name] - backlink = urito(pagename, docname) + '#' + refname + '.' + name - lines[start] = ( - '<div class="viewcode-block" id="%s"><a class="viewcode-back" ' - 'href="%s">%s</a>' % (name, backlink, _('[文档]')) + - lines[start]) - lines[min(end, maxindex)] += '</div>' - # try to find parents (for submodules) - parents = [] - parent = modname - while '.' in parent: - parent = parent.rsplit('.', 1)[0] - if parent in modnames: - parents.append({ - 'link': urito(pagename, - posixpath.join(OUTPUT_DIRNAME, parent.replace('.', '/'))), - 'title': parent}) - parents.append({'link': urito(pagename, posixpath.join(OUTPUT_DIRNAME, 'index')), - 'title': _('Module code')}) - parents.reverse() - # putting it all together - context = { - 'parents': parents, - 'title': modname, - 'body': (_('<h1>Source code for %s</h1>') % modname + - '\n'.join(lines)), - } - yield (pagename, context, 'page.html') - - if not modnames: - return - - html = ['\n'] - # the stack logic is needed for using nested lists for submodules - stack = [''] - for modname in sorted(modnames): - if modname.startswith(stack[-1]): - stack.append(modname + '.') - html.append('<ul>') - else: - stack.pop() - while not modname.startswith(stack[-1]): - stack.pop() - html.append('</ul>') - stack.append(modname + '.') - html.append('<li><a href="%s">%s</a></li>\n' % ( - urito(posixpath.join(OUTPUT_DIRNAME, 'index'), - posixpath.join(OUTPUT_DIRNAME, modname.replace('.', '/'))), - modname)) - html.append('</ul>' * (len(stack) - 1)) - context = { - 'title': _('Overview: module code'), - 'body': (_('<h1>All modules for which code is available</h1>
    ') + - ''.join(html)), - } - - yield (posixpath.join(OUTPUT_DIRNAME, 'index'), context, 'page.html') - - -def setup(app: Sphinx) -> Dict[str, Any]: - app.add_config_value('viewcode_import', None, False) - app.add_config_value('viewcode_enable_epub', False, False) - app.add_config_value('viewcode_follow_imported_members', True, False) - app.connect('doctree-read', doctree_read) - app.connect('env-merge-info', env_merge_info) - app.connect('env-purge-doc', env_purge_doc) - app.connect('html-collect-pages', collect_pages) - app.connect('missing-reference', missing_reference) - # app.add_config_value('viewcode_include_modules', [], 'env') - # app.add_config_value('viewcode_exclude_modules', [], 'env') - app.add_event('viewcode-find-source') - app.add_event('viewcode-follow-imported') - app.add_post_transform(ViewcodeAnchorTransform) - return { - 'version': sphinx.__display_version__, - 'env_version': 1, - 'parallel_read_safe': True - } diff --git a/docs/graphlearning/docs/requirements.txt b/docs/graphlearning/docs/requirements.txt deleted file mode 100644 index feb4cb9d69dd2d5fbb771cba62f1f68bcb0dd908..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -sphinx == 4.4.0 -docutils == 0.17.1 -myst-parser == 0.18.1 -sphinx_rtd_theme == 1.0.0 -numpy -IPython -jieba \ No newline at end of file diff --git a/docs/graphlearning/docs/source_en/_templates/classtemplate.rst b/docs/graphlearning/docs/source_en/_templates/classtemplate.rst deleted file mode 100644 index 37a8e95499c8343ad3f8e02d5c9095215fd9010a..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_en/_templates/classtemplate.rst +++ /dev/null @@ -1,27 +0,0 @@ -.. role:: hidden - :class: hidden-section - -.. currentmodule:: {{ module }} - -{% if objname in [] %} -{{ fullname | underline }} - -.. autofunction:: {{ fullname }} - -{% elif objname[0].istitle() %} -{{ fullname | underline }} - -.. autoclass:: {{ name }} - :exclude-members: construct - :members: - -{% else %} -{{ fullname | underline }} - -.. autofunction:: {{ fullname }} - -{% endif %} - -.. - autogenerated from _templates/classtemplate.rst - note it does not have :inherited-members: diff --git a/docs/graphlearning/docs/source_en/batched_graph_training_GIN.md b/docs/graphlearning/docs/source_en/batched_graph_training_GIN.md deleted file mode 100644 index 7d55dba9288aaac445b33446e1f6bf7bde2ecbe5..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_en/batched_graph_training_GIN.md +++ /dev/null @@ -1,278 +0,0 @@ -# Batched Graph Training Network - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/graphlearning/docs/source_en/batched_graph_training_GIN.md) -   - -## Overview - -In this example, it will show how to classify the social network with Graph Isomorphism Network. - -GIN is inspired by the close connection between GNNs and the Weisfeiler-Lehman (WL) graph isomorphism test, a powerful test known to distinguish a broad class of graphs. GNN can have as large discriminative power as the WL test if the GNN’s aggregation scheme is highly expressive and can model injective functions. - -IMDB-BINARY is a movie collaboration dataset that consists of the ego-networks of 1,000 actors/actresses who played roles in movies in IMDB. 
In each graph, nodes represent actors/actress, and there is an edge between them if they appear in the same movie. These graphs are derived from the Action and Romance genres. -Get batched graph data from the IMDB-BINARY dataset. Each graph is a movie composed of actors. The GIN is used to classify the graphs and predict the genres of the movie. - -In the batched graph, multiple graphs can be trained at the same time, and the number of nodes/edges of each graph is different. mindspore_gl integrates the sub graph in the batch into a whole graph, and adds a virtual graph to unify the graph data to reduce memory consumption and speed up calculation. - -> Download the complete sample code here: [GIN](https://gitee.com/mindspore/graphlearning/tree/master/model_zoo/gin). - -## GIN Principles - -Paper: [How Powerful are Graph Neural Networks?](https://arxiv.org/pdf/1810.00826.pdf) - -## Defining a Network Model - -GINConv parses graph `g` into `BatchedGraph`, and `BatchedGraph` can support more graph operations than `Graph`. The input data is the whole graph, but when updating the node features of each subgraph, it can still find the corresponding neighbor nodes according to its own nodes, and will not connect to the nodes of other subgraphs. - -mindspore_gl.nn implements GINConv, which can be directly imported for use. The code for implementing a multi-layer GinNet network using GINConv, batch normalization, and pooling is as follows: - -```python -class GinNet(GNNCell): - """GIN net""" - def __init__(self, - num_layers, - num_mlp_layers, - input_dim, - hidden_dim, - output_dim, - final_dropout=0.1, - learn_eps=False, - graph_pooling_type='sum', - neighbor_pooling_type='sum' - ): - super().__init__() - self.final_dropout = final_dropout - self.num_layers = num_layers - self.graph_pooling_type = graph_pooling_type - self.neighbor_pooling_type = neighbor_pooling_type - self.learn_eps = learn_eps - - self.mlps = nn.CellList() - self.convs = nn.CellList() - self.batch_norms = nn.CellList() - - if self.graph_pooling_type not in ('sum', 'avg'): - raise SyntaxError("graph pooling type not supported.") - for layer in range(num_layers - 1): - if layer == 0: - self.mlps.append(MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim)) - else: - self.mlps.append(MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim)) - self.convs.append(GINConv(ApplyNodeFunc(self.mlps[layer]), learn_eps=self.learn_eps, - aggregation_type=self.neighbor_pooling_type)) - self.batch_norms.append(nn.BatchNorm1d(hidden_dim)) - - self.linears_prediction = nn.CellList() - for layer in range(num_layers): - if layer == 0: - self.linears_prediction.append(nn.Dense(input_dim, output_dim)) - else: - self.linears_prediction.append(nn.Dense(hidden_dim, output_dim)) - - def construct(self, x, edge_weight, g: BatchedGraph): - """construct function""" - hidden_rep = [x] - h = x - for layer in range(self.num_layers - 1): - h = self.convs[layer](h, edge_weight, g) - h = self.batch_norms[layer](h) - h = nn.ReLU()(h) - hidden_rep.append(h) - - score_over_layer = 0 - for layer, h in enumerate(hidden_rep): - if self.graph_pooling_type == 'sum': - pooled_h = g.sum_nodes(h) - else: - pooled_h = g.avg_nodes(h) - score_over_layer = score_over_layer + nn.Dropout(p=1.0 - self.final_dropout)( - self.linears_prediction[layer](pooled_h)) - return score_over_layer -``` - -For details about GINConv implementation, see the [API](https://gitee.com/mindspore/graphlearning/blob/master/mindspore_gl/nn/conv/ginconv.py) code of mindspore_gl.nn.GINConv. 
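For orientation, here is a minimal instantiation sketch. It is not part of the tutorial's model_zoo script: the hyperparameter values are illustrative assumptions, and only the `GinNet` signature above and the dataset's `node_feat_size` attribute come from the tutorial (the `dataset` object is constructed in the next section).

```python
# A hedged sketch: hyperparameter values are invented, not model_zoo defaults.
net = GinNet(num_layers=5,
             num_mlp_layers=2,
             input_dim=dataset.node_feat_size,  # dataset built in the next section
             hidden_dim=32,
             output_dim=2,                      # IMDB-BINARY is a two-class task
             final_dropout=0.1,
             learn_eps=False,
             graph_pooling_type='sum',
             neighbor_pooling_type='sum')
```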
- -## Constructing a Dataset - -Load the IMDB-BINARY dataset from mindspore_gl.dataset; the method is similar to that of [GCN](https://www.mindspore.cn/graphlearning/docs/en/master/full_training_of_GCN.html#constructing-a-dataset). Then use mindspore_gl.dataloader.RandomBatchSampler to define a sampler that returns the sampling indices. -MultiHomoGraphDataset obtains data from the dataset according to the sampling indices, packages the data into batches, and serves as the dataset generator. -After building the generator, invoke mindspore.dataset.GeneratorDataset to construct a dataloader. - -```python -dataset = IMDBBinary(arguments.data_path) -train_batch_sampler = RandomBatchSampler(dataset.train_graphs, batch_size=arguments.batch_size) -train_multi_graph_dataset = MultiHomoGraphDataset(dataset, arguments.batch_size, len(list(train_batch_sampler))) -test_batch_sampler = RandomBatchSampler(dataset.val_graphs, batch_size=arguments.batch_size) -test_multi_graph_dataset = MultiHomoGraphDataset(dataset, arguments.batch_size, len(list(test_batch_sampler))) - -train_dataloader = ds.GeneratorDataset(train_multi_graph_dataset, ['row', 'col', 'node_count', 'edge_count', - 'node_map_idx', 'edge_map_idx', 'graph_mask', - 'batched_label', 'batched_node_feat', - 'batched_edge_feat'], - sampler=train_batch_sampler) - -test_dataloader = ds.GeneratorDataset(test_multi_graph_dataset, ['row', 'col', 'node_count', 'edge_count', - 'node_map_idx', 'edge_map_idx', 'graph_mask', - 'batched_label', 'batched_node_feat', - 'batched_edge_feat'], - sampler=test_batch_sampler) -``` - -mindspore_gl.graph.BatchHomoGraph merges multiple subgraphs into one whole graph. During model training, all graphs in the batch are computed as one whole graph. - -To reduce the number of computational graphs that must be generated and to speed up calculation, the generator unifies the data of each batch to the same size when returning data. - -Assume the number of nodes is `node_size` and the number of edges is `edge_size`, chosen so that the total node and edge counts of all graph data in the batch are at most `node_size * batch` and `edge_size * batch`. -A new virtual graph is created in the batch so that the node and edge totals in the batch equal exactly `node_size * batch` and `edge_size * batch`. -When calculating the loss, this virtual graph does not participate in the calculation. - -Call mindspore_gl.graph.PadArray2d to define the node-feature and edge-feature padding operations, which set the node and edge features of the virtual graph to 0. -Call mindspore_gl.graph.PadHomoGraph to define the padding of nodes and edges in the graph structure, so that the number of nodes in the batch equals `node_size * batch` and the number of edges equals `edge_size * batch`.
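Before the full generator class that follows, a small standalone sketch may help make the padding step concrete. The sizes here are invented; the class names and keyword arguments mirror the tutorial code, but treat the import path as an assumption.

```python
import numpy as np
from mindspore_gl.graph import PadArray2d, PadMode, PadDirection

# Invented budget: 50 nodes per graph * batch of 2 = 100 feature rows.
pad_node_feat = PadArray2d(dtype=np.float32, mode=PadMode.CONST,
                           direction=PadDirection.COL,
                           size=(100, 16), fill_value=0)

feat = np.random.rand(73, 16).astype(np.float32)  # 73 real nodes in this batch
padded = pad_node_feat(feat)
print(padded.shape)  # (100, 16): rows 73..99 belong to the virtual graph
```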
- -```python -class MultiHomoGraphDataset(Dataset): - """MultiHomoGraph Dataset""" - def __init__(self, dataset, batch_size, length, mode=PadMode.CONST, node_size=50, edge_size=350): - self._dataset = dataset - self._batch_size = batch_size - self._length = length - self.batch_fn = BatchHomoGraph() - self.batched_edge_feat = None - node_size *= batch_size - edge_size *= batch_size - if mode == PadMode.CONST: - self.node_feat_pad_op = PadArray2d(dtype=np.float32, mode=PadMode.CONST, direction=PadDirection.COL, - size=(node_size, dataset.node_feat_size), fill_value=0) - self.edge_feat_pad_op = PadArray2d(dtype=np.float32, mode=PadMode.CONST, direction=PadDirection.COL, - size=(edge_size, dataset.edge_feat_size), fill_value=0) - self.graph_pad_op = PadHomoGraph(n_edge=edge_size, n_node=node_size, mode=PadMode.CONST) - else: - self.node_feat_pad_op = PadArray2d(dtype=np.float32, mode=PadMode.AUTO, direction=PadDirection.COL, - fill_value=0) - self.edge_feat_pad_op = PadArray2d(dtype=np.float32, mode=PadMode.AUTO, direction=PadDirection.COL, - fill_value=0) - self.graph_pad_op = PadHomoGraph(mode=PadMode.AUTO) - - # For Padding - self.train_mask = np.array([True] * (self._batch_size + 1)) - self.train_mask[-1] = False - - def __getitem__(self, batch_graph_idx): - graph_list = [] - feature_list = [] - for idx in range(batch_graph_idx.shape[0]): - graph_list.append(self._dataset[batch_graph_idx[idx]]) - feature_list.append(self._dataset.graph_node_feat(batch_graph_idx[idx])) - - # Batch Graph - batch_graph = self.batch_fn(graph_list) - - # Pad Graph - batch_graph = self.graph_pad_op(batch_graph) - - # Batch Node Feat - batched_node_feat = np.concatenate(feature_list) - - # Pad NodeFeat - batched_node_feat = self.node_feat_pad_op(batched_node_feat) - batched_label = self._dataset.graph_label[batch_graph_idx] - - # Pad Label - batched_label = np.append(batched_label, batched_label[-1] * 0) - - # Get Edge Feat - if self.batched_edge_feat is None or self.batched_edge_feat.shape[0] < batch_graph.edge_count: - del self.batched_edge_feat - self.batched_edge_feat = np.ones([batch_graph.edge_count, 1], dtype=np.float32) - - # Trigger Node_Map_Idx/Edge_Map_Idx Computation, Because It Is Lazily Computed - _ = batch_graph.batch_meta.node_map_idx - _ = batch_graph.batch_meta.edge_map_idx - - np_graph_mask = [1] * (self._batch_size + 1) - np_graph_mask[-1] = 0 - constant_graph_mask = ms.Tensor(np_graph_mask, dtype=ms.int32) - batchedgraphfiled = self.get_batched_graph_field(batch_graph, constant_graph_mask) - row, col, node_count, edge_count, node_map_idx, edge_map_idx, graph_mask = batchedgraphfiled.get_batched_graph() - return row, col, node_count, edge_count, node_map_idx, edge_map_idx, graph_mask, batched_label,\ - batched_node_feat, self.batched_edge_feat[:batch_graph.edge_count, :] -``` - -## Defining a Loss Function - -Since this is a classification task, the cross entropy can be used as the loss function, and the implementation method is similar to that of [GCN](https://www.mindspore.cn/graphlearning/docs/en/master/full_training_of_GCN.html#defining-a-loss-function). - -Different from GCN, this tutorial is for graph classification. Therefore, when parsing batch graphs, the mindspore_gl.BatchedGraph interface is invoked. - -The last value in `g.graph_mask` is the mask of the virtual graph, which is 0. Therefore, the last loss value is also 0. 
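The `LossNet` below implements exactly this masking. As a quick numeric illustration first (all values invented), multiplying the per-graph losses by `graph_mask` discards the virtual graph's contribution:

```python
import numpy as np

per_graph_loss = np.array([0.7, 0.4, 0.9, 0.3])  # last entry: virtual graph
graph_mask = np.array([1, 1, 1, 0])              # virtual graph is masked out
print(np.sum(per_graph_loss * graph_mask))       # 2.0: the trailing 0.3 is dropped
```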
- -```python -class LossNet(GNNCell): - """ LossNet definition """ - def __init__(self, net): - super().__init__() - self.net = net - self.loss_fn = ms.nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='none') - - def construct(self, node_feat, edge_weight, target, g: BatchedGraph): - predict = self.net(node_feat, edge_weight, g) - target = ops.Squeeze()(target) - loss = self.loss_fn(predict, target) - loss = ops.ReduceSum()(loss * g.graph_mask) - return loss -``` - -## Network Training and Validation - -### Setting Environment Variables - -The method of setting environment variables is similar to that of setting [GCN](https://www.mindspore.cn/graphlearning/docs/en/master/full_training_of_GCN.html#setting-environment-variables). - -### Defining a Training Network - -Instantiation of the model body GinNet and LossNet and optimizer. -Input the LossNet instance and optimizer to mindspore.nn.TrainOneStepCell to construct a single-step training network train_net. -The implementation method is similar to that of the [GCN](https://www.mindspore.cn/graphlearning/docs/en/master/full_training_of_GCN.html#defining-a-training-network). - -### Network Training and Validation - -Because the graph is trained in batch, the API invoked during graph composition is mindspore_gl.BatchedGraphField, which is different from mindspore_gl.GraphField. It added the parameters of `node_map_idx`, `edge_map_idx`, and `graph_mask`. -The `graph_mask` is the mask information of each graph in the batch. The last graph is the virtual graph. Therefore, in the `graph_mask`, the last value is 0 and the rest is 1. - -```python -from mindspore_gl import BatchedGraph, BatchedGraphField - -for data in train_dataloader: - row, col, node_count, edge_count, node_map_idx, edge_map_idx, graph_mask, label, node_feat, edge_feat = data - batch_homo = BatchedGraphField(row, col, node_count, edge_count, node_map_idx, edge_map_idx, graph_mask) - output = net(node_feat, edge_feat, *batch_homo.get_batched_graph()).asnumpy() -``` - -## Executing Jobs and Viewing Results - -### Running Process - -After running the program, translate the code and start training. - -### Execution Results - -Run the [trainval_imdb_binary.py](https://gitee.com/mindspore/graphlearning/blob/master/model_zoo/gin/trainval_imdb_binary.py) script to start training. - -```bash -cd model_zoo/gin -python trainval_imdb_binary.py --data_path={path} -``` - -`{path}` indicates the dataset storage path. - -The training result is as follows: - -```bash -... -Epoch 52, Time 3.547 s, Train loss 0.49981827, Train acc 0.74219, Test acc 0.594 -Epoch 53, Time 3.599 s, Train loss 0.5046462, Train acc 0.74219, Test acc 0.656 -Epoch 54, Time 3.505 s, Train loss 0.49653444, Train acc 0.74777, Test acc 0.766 -Epoch 55, Time 3.468 s, Train loss 0.49411067, Train acc 0.74219, Test acc 0.750 -``` - -The best accuracy verified on IMDBBinary: 0.766 diff --git a/docs/graphlearning/docs/source_en/conf.py b/docs/graphlearning/docs/source_en/conf.py deleted file mode 100644 index e7628fe1714c4f11ff964e8ffc7c385425bdb755..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_en/conf.py +++ /dev/null @@ -1,217 +0,0 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. 
For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import os -import shutil -import IPython -import re -import sys -from sphinx.ext import autodoc as sphinx_autodoc -import sphinx.ext.autosummary.generate as g -sys.path.append(os.path.abspath('../_ext')) - -import mindspore_gl - -# -- Project information ----------------------------------------------------- - -project = 'MindSpore' -copyright = 'MindSpore' -author = 'MindSpore' - -# The full version, including alpha/beta/rc tags -release = 'master' - - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -myst_enable_extensions = ["dollarmath", "amsmath"] - - -myst_heading_anchors = 5 -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'myst_parser', - 'sphinx.ext.mathjax', - 'IPython.sphinxext.ipython_console_highlighting' -] - -source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', -} - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -mathjax_path = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/mathjax/MathJax-3.2.2/es5/tex-mml-chtml.js' - -mathjax_options = { - 'async':'async' -} - -smartquotes_action = 'De' - -exclude_patterns = [] - -pygments_style = 'sphinx' - -autodoc_inherit_docstrings = False - -autosummary_generate = True - -autosummary_generate_overwrite = False - -html_static_path = ['_static'] - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'sphinx_rtd_theme' - -import sphinx_rtd_theme -layout_target = os.path.join(os.path.dirname(sphinx_rtd_theme.__file__), 'layout.html') -layout_src = '../../../../resource/_static/layout.html' -if os.path.exists(layout_target): - os.remove(layout_target) -shutil.copy(layout_src, layout_target) - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = { - 'python': ('https://docs.python.org/', '../../../../resource/python_objects.inv'), - 'numpy': ('https://docs.scipy.org/doc/numpy/', '../../../../resource/numpy_objects.inv'), -} - -# overwriteautosummary_generate add view source for api. -with open('../_ext/overwriteautosummary_generate.txt', 'r', encoding="utf8") as f: - exec(f.read(), g.__dict__) - -# Modify default signatures for autodoc. 
-autodoc_source_path = os.path.abspath(sphinx_autodoc.__file__) -autodoc_source_re = re.compile(r'stringify_signature\(.*?\)') -get_param_func_str = r"""\ -import re -import inspect as inspect_ - -def get_param_func(func): - try: - source_code = inspect_.getsource(func) - if func.__doc__: - source_code = source_code.replace(func.__doc__, '') - all_params_str = re.findall(r"def [\w_\d\-]+\(([\S\s]*?)(\):|\) ->.*?:)", source_code) - all_params = re.sub("(self|cls)(,|, )?", '', all_params_str[0][0].replace("\n", "").replace("'", "\"")) - return all_params - except: - return '' - -def get_obj(obj): - if isinstance(obj, type): - return obj.__init__ - - return obj -""" - -with open(autodoc_source_path, "r+", encoding="utf8") as f: - code_str = f.read() - code_str = autodoc_source_re.sub('"(" + get_param_func(get_obj(self.object)) + ")"', code_str, count=0) - exec(get_param_func_str, sphinx_autodoc.__dict__) - exec(code_str, sphinx_autodoc.__dict__) - -# Copy source files of chinese python api from mindscience repository. -from sphinx.util import logging -logger = logging.getLogger(__name__) - -gl_dir_msg = os.path.join(os.getenv("GL_PATH"), 'docs/api_python_en') - -present_path = os.path.dirname(__file__) - -for i in os.listdir(gl_dir_msg): - if os.path.isfile(os.path.join(gl_dir_msg,i)): - if os.path.exists('./'+i): - os.remove('./'+i) - shutil.copy(os.path.join(gl_dir_msg,i),'./'+i) - else: - if os.path.exists('./'+i): - shutil.rmtree('./'+i) - shutil.copytree(os.path.join(gl_dir_msg,i),'./'+i) - -# get params for add view source -import json - -if os.path.exists('../../../../tools/generate_html/version.json'): - with open('../../../../tools/generate_html/version.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) -elif os.path.exists('../../../../tools/generate_html/daily_dev.json'): - with open('../../../../tools/generate_html/daily_dev.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) -elif os.path.exists('../../../../tools/generate_html/daily.json'): - with open('../../../../tools/generate_html/daily.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) - -if os.getenv("GL_PATH").split('/')[-1]: - copy_repo = os.getenv("GL_PATH").split('/')[-1] -else: - copy_repo = os.getenv("GL_PATH").split('/')[-2] - -branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if version_inf[i]['name'] == copy_repo][0] -docs_branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if version_inf[i]['name'] == 'tutorials'][0] -cst_module_name = 'mindspore_gl' -repo_whl = 'mindspore_gl' -giturl = 'https://gitee.com/mindspore/' - -sys.path.append(os.path.abspath('../../../../resource/sphinx_ext')) -# import anchor_mod -import nbsphinx_mod - -sys.path.append(os.path.abspath('../../../../resource/search')) -import search_code - -sys.path.append(os.path.abspath('../../../../resource/custom_directives')) -from custom_directives import IncludeCodeDirective -from myautosummary import MsPlatformAutoSummary - -def setup(app): - app.add_directive('msplatformautosummary', MsPlatformAutoSummary) - app.add_directive('includecode', IncludeCodeDirective) - app.add_config_value('docs_branch', '', True) - app.add_config_value('branch', '', True) - app.add_config_value('cst_module_name', '', True) - app.add_config_value('copy_repo', '', True) - app.add_config_value('giturl', '', True) - app.add_config_value('repo_whl', '', True) - -src_release = os.path.join(os.getenv("GL_PATH"), 'RELEASE.md') -des_release = "./RELEASE.md" -with open(src_release, "r", 
encoding="utf-8") as f: - data = f.read() -if len(re.findall("\n## (.*?)\n",data)) > 1: - content = re.findall("(## [\s\S\n]*?)\n## ", data) -else: - content = re.findall("(## [\s\S\n]*)", data) -#result = content[0].replace('# MindSpore', '#', 1) -with open(des_release, "w", encoding="utf-8") as p: - p.write("# Release Notes"+"\n\n") - p.write(content[0]) \ No newline at end of file diff --git a/docs/graphlearning/docs/source_en/faq.md b/docs/graphlearning/docs/source_en/faq.md deleted file mode 100644 index bf67a69f71a7ef40a95ce7590ed25eefe0b9ac22..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_en/faq.md +++ /dev/null @@ -1,37 +0,0 @@ -# FAQ - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/graphlearning/docs/source_en/faq.md) - -**Q: What should I do if the error message `OSError: could not get source code` is displayed during the execution of the GNNCell command?** - -A: MindSpore Graph Learning parses vertex-centeric programming code through source-to-source translation. The inspect function is called to obtain source code of the translation module. Model definition code needs to be placed in a Python file. Otherwise, an error message is displayed, indicating that the source code cannot be found. - -
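-For illustration, here is a minimal sketch of a compliant layout; the file name `my_net.py` and the model body are assumptions for this example, reusing the vertex-centric style shown in the tutorials:
-
-```python
-# my_net.py: model definitions must live in a Python source file that
-# inspect can read, not in an interactive shell or a notebook cell.
-import mindspore
-from mindspore_gl import Graph
-from mindspore_gl.nn import GNNCell
-
-
-class MyNet(GNNCell):
-    def construct(self, x, g: Graph):
-        g.set_vertex_attr({'x': x})
-        for v in g.dst_vertex:
-            v.x = g.sum([u.x for u in v.innbs])
-        return [v.x for v in g.dst_vertex]
-```
-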
-
-**Q: What should I do if the error message `AttributeError: None of backend from {mindspore} is identified. Backend must be imported as a global variable.` is displayed when running GNNCell?**
-
-A: MindSpore Graph Learning parses vertex-centric programming code through source-to-source translation. In the file that defines the GNNCell, the network execution backend is looked up through global variables. You therefore need to import MindSpore at the top of the GNNCell definition file; otherwise, an error is reported, indicating that the backend cannot be found.
-
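-As a sketch, the required import is simply a module-level statement in the file that defines the GNNCell subclass (`my_gnn.py` is a hypothetical name):
-
-```python
-# my_gnn.py: the translator looks up the execution backend through global
-# variables, so `import mindspore` must be a top-level statement here.
-import mindspore
-from mindspore_gl.nn import GNNCell
-
-# An `import mindspore` placed only inside a function or method body would
-# leave no global binding for the translator to find, raising the error above.
-```
-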
-
-**Q: What should I do if the error message `TypeInferenceError: Line 6: Built-in agg func "avg" only takes expr of EDGE or SRC type. Got None.` is displayed when one of the graph aggregation APIs 'sum', 'avg', 'max', or 'min' is called?**
-
-A: The aggregation APIs provided by MindSpore Graph Learning operate on graph nodes. During source-to-source translation, the system checks whether the input of the aggregation API is an edge or source-node expression of the graph. If it is neither, an error is reported, indicating that the required input type cannot be found.
-
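-As a sketch of the distinction (the feature names are illustrative, in the vertex-centric style used throughout these tutorials):
-
-```python
-import mindspore
-from mindspore_gl import Graph
-from mindspore_gl.nn import GNNCell
-
-
-class AvgNet(GNNCell):
-    def construct(self, x, g: Graph):
-        g.set_vertex_attr({'x': x})
-        for v in g.dst_vertex:
-            # OK: the argument is built from source vertices, so it has SRC type.
-            v.h = g.avg([u.x for u in v.innbs])
-            # Not OK: a plain Python constant has neither EDGE nor SRC type and
-            # would raise the TypeInferenceError above:
-            # v.h = g.avg(1.0)
-        return [v.h for v in g.dst_vertex]
-```
-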
-
-**Q: What should I do if the error message `RuntimeError: The 'mul' operation does not support the type.` is displayed when I call the graph API 'dot'?**
-
-A: The dot API provided by MindSpore Graph Learning performs a dot product over graph node features; on the backend it is composed of a feature multiplication followed by an aggregation. Because the frontend translation step does not compile the network, it cannot check the input data type, so the input must itself satisfy the type requirements of the backend mul operator. Otherwise, an error message is displayed, indicating that the type is not supported.
-
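-A hedged sketch of the workaround: make sure the features entering dot already have a type the backend mul operator supports. The cast below is an assumption about the input data, and the attribute names are illustrative:
-
-```python
-import mindspore as ms
-from mindspore_gl import Graph
-from mindspore_gl.nn import GNNCell
-
-
-class DotNet(GNNCell):
-    def construct(self, x, g: Graph):
-        # dot decomposes into mul + aggregation on the backend, so cast
-        # integer or bool features to a supported float type beforehand.
-        x = ms.ops.Cast()(x, ms.float32)
-        g.set_vertex_attr({'x': x})
-        for v in g.dst_vertex:
-            v.h = g.sum([g.dot(u.x, v.x) for u in v.innbs])
-        return [v.h for v in g.dst_vertex]
-```
-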
-
-**Q: What should I do if the error message `TypeError: For 'tensor getitem', the types only support 'Slice', 'Ellipsis', 'None', 'Tensor', 'int', 'List', 'Tuple', 'bool', but got String.` is displayed when I call the graph API 'topk_nodes' or 'topk_edges'?**
-
-A: The topk_nodes and topk_edges APIs provided by MindSpore Graph Learning obtain k nodes or edges ranked by a node or edge feature. The backend performs three steps: gathering the node or edge features, sorting, and slicing out k nodes or edges. Because the frontend translation step does not compile the network, it cannot check the input data type, so the sorting dimension `sortby` and the value of `k` must satisfy the type and range requirements of the sort and slice operators. Otherwise, an error is reported, indicating that the type is not supported.
-
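-For instance, a sketch of a well-typed call inside a GNNCell, where `g` is the graph input and `node_feat` an assumed (n_nodes, feat_dim) tensor; check the exact signature against the mindspore_gl documentation:
-
-```python
-# Both k and the sorting dimension sortby must be integers, not strings:
-top_feat = g.topk_nodes(node_feat, 2, 0)    # top 2 nodes, ranked by feature dim 0
-# top_feat = g.topk_nodes(node_feat, "2", "score")   # raises the TypeError above
-```
-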
-
-**Q: What should I do if the error message `TypeError: For 'Cell', the function construct need 5 positional argument, but got 2.'` is displayed when a non-GraphField instance or equivalent tensors are passed to the input graph of the construct function?**
-
-A: The GNNCell class provided by MindSpore Graph Learning is a base class for writing vertex-centric GNN models. Its construct function must take the graph class as the last input parameter. After translation, this graph input becomes four tensor parameters: src_idx, dst_idx, n_nodes, and n_edges. If something other than a GraphField instance (or the four equivalent tensors) is passed, an error message is displayed, indicating that the input parameters are incorrect.
diff --git a/docs/graphlearning/docs/source_en/full_training_of_GCN.md b/docs/graphlearning/docs/source_en/full_training_of_GCN.md
deleted file mode 100644
index 4568252c40b43d919e7f224b17831f0ca6a340ca..0000000000000000000000000000000000000000
--- a/docs/graphlearning/docs/source_en/full_training_of_GCN.md
+++ /dev/null
@@ -1,262 +0,0 @@
-# Entire Graph Training Network
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/graphlearning/docs/source_en/full_training_of_GCN.md)
-
-## Overview
-
-This example shows how to perform semi-supervised classification on the Cora dataset with Graph Convolutional Networks.
-
-Graph Convolutional Networks (GCN) were proposed in 2016 for semi-supervised learning on graph-structured data. The authors presented a scalable approach based on an efficient variant of convolutional neural networks that operates directly on graphs. The model scales linearly in the number of graph edges and learns hidden-layer representations that encode both local graph structure and node features.
-
-The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 10556 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words.
-
-The class of each publication is taken as the label, the word vector of the publication is taken as the GCN node feature, and citations between publications are taken as the edges. The GCN is trained on the Cora graph to predict which class a publication belongs to.
-
-> Download the complete sample code here: [GCN](https://gitee.com/mindspore/graphlearning/tree/master/examples/).
-
-## GCN Principles
-
-Paper: [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907)
-
-## Defining a Network Model
-
-mindspore_gl.nn implements GCNConv, which can be directly imported for use. You can also define your own convolutional layer.
The code for implementing a two-layer GCN network using GCNConv is as follows: - -```python -import mindspore -from mindspore_gl.nn import GNNCell -from mindspore_gl import Graph -from mindspore_gl.nn import GCNConv - - -class GCNNet(GNNCell): - def __init__(self, - data_feat_size: int, - hidden_dim_size: int, - n_classes: int, - dropout: float, - activation): - super().__init__() - self.layer0 = GCNConv(data_feat_size, hidden_dim_size, activation(), dropout) - self.layer1 = GCNConv(hidden_dim_size, n_classes, None, dropout) - - def construct(self, x, in_deg, out_deg, g: Graph): - x = self.layer0(x, in_deg, out_deg, g) - x = self.layer1(x, in_deg, out_deg, g) - return x -``` - -GCNNet is inherited from GNNCell. The last input of the construct function in GNNCell must be a graph or BatchedGraph, that is, the graph structure class supported by MindSpore Graph Learning. In addition, you must import mindspore at the header of the file to identify the execution backend when the code is translated. - -In GCNConv, data_feat_size indicates the feature dimension of the input node, hidden_dim_size indicates the feature dimension of the hidden layer, n_classes indicates the dimension of the output classification, and in_deg and out_deg indicate the indegree and outdegree of the node in the graph data, respectively. -For details about GCN implementation, refer to the interface code of mindspore_gl.nn.GCNConv: . - -## Defining a Loss Function - -Define LossNet, including a network backbone net and a loss function. In this example, mindspore.nn.SoftmaxCrossEntropyWithLogits is used to implement a cross entropy loss. - -```python -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore_gl.nn import GNNCell - - -class LossNet(GNNCell): - """ LossNet definition """ - - def __init__(self, net): - super().__init__() - self.net = net - self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='none') - - def construct(self, x, in_deg, out_deg, train_mask, target, g: Graph): - predict = self.net(x, in_deg, out_deg, g) - target = ops.Squeeze()(target) - loss = self.loss_fn(predict, target) - loss = loss * train_mask - return ms.ops.ReduceSum()(loss) / ms.ops.ReduceSum()(train_mask) -``` - -In the preceding code, net can be transferred to GCNNet by constructing a LossNet instance. predict indicates the predicted value output by the net, and target indicates the actual value. Because the training is based on the entire graph, train_mask is used to obtain a part of the entire graph as the training data. Only this part of nodes are involved in the loss calculation. -LossNet and GCNNet are inherited from GNNCell. - -## Constructing a Dataset - -The mindspore_gl.dataset directory contains some dataset class definitions for reference. You can directly read some common datasets. The following uses the CORA dataset as an example. Enter the data path to construct a data class. - -```python -from mindspore_gl.dataset import CoraV2 - -ds = CoraV2(args.data_path) -``` - -The [Cora](https://data.dgl.ai/dataset/cora_v2.zip) data can be downloaded and decompressed to args.data_path. - -## Network Training and Validation - -### Setting Environment Variables - -The settings of environment variables are the same as those for other MindSpore network training. Especially, if enable_graph_kernel is set to True, the graph kernel build optimization is enabled to accelerate the graph model training. 
- -```python -import mindspore as ms -import os - -if train_args.fuse: - ms.set_context(device_target="GPU", save_graphs=2, save_graphs_path="./computational_graph/", - mode=ms.GRAPH_MODE, enable_graph_kernel=True) - graph_kernel_flags="--enable_expand_ops=Gather --enable_cluster_ops=TensorScatterAdd," - "UnsortedSegmentSum, GatherNd --enable_recompute_fusion=false " - "--enable_parallel_fusion=true " - os.environ['MS_DEV_GRAPH_KERNEL_FLAGS'] = graph_kernel_flags -else: - ms.set_context(device_target="GPU", mode=ms.PYNATIVE_MODE) -``` - -### Defining a Training Network - -Similar to other supervised learning models, in addition to the instantiation of the model body GCNNet and LossNet, the graph neural network training requires the definition of an optimizer. Here, mindspore.nn.Adam is used. -Input the LossNet instance and optimizer to mindspore.nn.TrainOneStepCell to construct a single-step training network train_net. - -```python -import mindspore.nn as nn - -net = GCNNet(data_feat_size=feature_size, - hidden_dim_size=train_args.num_hidden, - n_classes=ds.n_classes, - dropout=train_args.dropout, - activation=ms.nn.ELU) - optimizer = nn.optim.Adam(net.trainable_params(), learning_rate=train_args.lr, weight_decay=train_args.weight_decay) - loss = LossNet(net) - train_net = nn.TrainOneStepCell(loss, optimizer) -``` - -### Network Training and Validation - -Because the entire graph is trained, one training step covers the entire dataset. Each epoch is one training step. Similarly, the verification node is obtained through test_mask. To calculate the verification accuracy, you only need to compare the verification node in the entire graph with the actual value label. -If the predicted value is consistent with the actual value, the verification is correct. The ratio of the number of correct nodes (count) to the total number of verification nodes is the verification accuracy. - -```python -for e in range(train_args.epochs): - beg = time.time() - train_net.set_train() - train_loss = train_net() - end = time.time() - dur = end - beg - if e >= warm_up: - total = total + dur - - test_mask = ds.test_mask - if test_mask is not None: - net.set_train(False) - out = net(ds.x, ds.in_deg, ds.out_deg, ds.g.src_idx, ds.g.dst_idx, ds.g.n_nodes, ds.g.n_edges).asnumpy() - labels = ds.y.asnumpy() - predict = np.argmax(out[test_mask], axis=1) - label = labels[test_mask] - count = np.equal(predict, label) - print('Epoch time:{} ms Train loss {} Test acc:{}'.format(dur * 1000, train_loss, - np.sum(count) / label.shape[0])) -``` - -## Executing Jobs and Viewing Results - -### Running Process - -After running the program, you can view the comparison diagram of all translated functions. By default, the construct function in GNNCell is translated. The following figure shows the GCNConv translation comparison. The left part is the GCNConv source code, and the right part is the translated code. -You can see the code implementation after the graph API is replaced by mindspore_gl. For example, the code implementation after the called graph aggregate function g.sum is replaced by Gather-Scatter. -It can be seen that the node-centric programming paradigm greatly reduces the amount of code implemented by the graph model. 
- -```bash --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- -| def construct(self, x, in_deg, out_deg, g: Graph): 1 || 1 def construct( | -| || self, | -| || x, | -| || in_deg, | -| || out_deg, | -| || src_idx, | -| || dst_idx, | -| || n_nodes, | -| || n_edges, | -| || ver_subgraph_idx=None, | -| || edge_subgraph_idx=None, | -| || graph_mask=None | -| || ): | -| || 2 SCATTER_ADD = ms.ops.TensorScatterAdd() | -| || 3 SCATTER_MAX = ms.ops.TensorScatterMax() | -| || 4 SCATTER_MIN = ms.ops.TensorScatterMin() | -| || 5 GATHER = ms.ops.Gather() | -| || 6 ZEROS = ms.ops.Zeros() | -| || 7 SHAPE = ms.ops.Shape() | -| || 8 RESHAPE = ms.ops.Reshape() | -| || 9 scatter_src_idx = RESHAPE(src_idx, (SHAPE(src_idx)[0], 1)) | -| || 10 scatter_dst_idx = RESHAPE(dst_idx, (SHAPE(dst_idx)[0], 1)) | -| out_deg = ms.ops.clip_by_value(out_deg, self.min_clip, self.max_clip) 2 || 11 out_deg = ms.ops.clip_by_value(out_deg, self.min_clip, self.max_clip) | -| out_deg = ms.ops.Reshape()( 3 || 12 out_deg = ms.ops.Reshape()( | -| ms.ops.Pow()(out_deg, -0.5), || ms.ops.Pow()(out_deg, -0.5), | -| ms.ops.Shape()(out_deg) + (1,) || ms.ops.Shape()(out_deg) + (1,) | -| ) || ) | -| x = self.drop_out(x) 4 || 13 x = self.drop_out(x) | -| x = ms.ops.Squeeze()(x) 5 || 14 x = ms.ops.Squeeze()(x) | -| x = x * out_deg 6 || 15 x = x * out_deg | -| x = self.fc(x) 7 || 16 x = self.fc(x) | -| g.set_vertex_attr({'x': x}) 8 || 17 VERTEX_SHAPE = SHAPE(x)[0] | -| || 18 x, = [x] | -| for v in g.dst_vertex: 9 || | -| v.x = g.sum([u.x for u in v.innbs]) 10 || 19 SCATTER_INPUT_SNAPSHOT2 = GATHER(x, src_idx, 0) | -| || 20 x = SCATTER_ADD( | -| || ZEROS((VERTEX_SHAPE,) + SHAPE(SCATTER_INPUT_SNAPSHOT2)[1:], ms.float32), | -| || scatter_dst_idx, | -| || SCATTER_INPUT_SNAPSHOT2 | -| || ) | -| in_deg = ms.ops.clip_by_value(in_deg, self.min_clip, self.max_clip) 11 || 21 in_deg = ms.ops.clip_by_value(in_deg, self.min_clip, self.max_clip) | -| in_deg = ms.ops.Reshape()(ms.ops.Pow()(in_deg, -0.5), ms.ops.Shape()(in_deg) + (1,)) 12 || 22 in_deg = ms.ops.Reshape()(ms.ops.Pow()(in_deg, -0.5), ms.ops.Shape()(in_deg) + (1,)) | -| x = [v.x for v in g.dst_vertex] * in_deg 13 || 23 x = x * in_deg | -| x = x + self.bias 14 || 24 x = x + self.bias | -| return x 15 || 25 return x | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- - -``` - -### Enabling or Disabling Translation Display - -The translation comparison show is displayed by default setting during code execution. To disable the comparison show is as follows: - -```python -from mindspore_gl.nn import GNNCell -GNNCell.disable_display() -``` - -To change the display width (default: 200), code is as follows: - -```python -from mindspore_gl.nn import GNNCell -GNNCell.enable_display(screen_width=350) -``` - -### Execution Results - -Run the [vc_gcn_datanet.py](https://gitee.com/mindspore/graphlearning/blob/master/examples/vc_gcn_datanet.py) script to start training. - -```bash -cd examples -python vc_gcn_datanet.py --data-path={path} --fuse=True -``` - -`{path}` indicates the dataset storage path. - -The training result (of the last five epochs) is as follows: - -```bash -... 
-Epoch 196, Train loss 0.30630863, Test acc 0.822 -Epoch 197, Train loss 0.30918056, Test acc 0.819 -Epoch 198, Train loss 0.3299482, Test acc 0.819 -Epoch 199, Train loss 0.2945389, Test acc 0.821 -Epoch 200, Train loss 0.27628058, Test acc 0.819 -``` - -Accuracy verified on CORA: 0.82 (thesis: 0.815) - -The preceding is the usage guide of the entire graph training. For more examples, see [examples directory](https://gitee.com/mindspore/graphlearning/tree/master/examples/). diff --git a/docs/graphlearning/docs/source_en/images/gat_example.PNG b/docs/graphlearning/docs/source_en/images/gat_example.PNG deleted file mode 100644 index 9477fb68c19237a73ce4f21052e4beef2e357970..0000000000000000000000000000000000000000 Binary files a/docs/graphlearning/docs/source_en/images/gat_example.PNG and /dev/null differ diff --git a/docs/graphlearning/docs/source_en/images/graphlearning_en.png b/docs/graphlearning/docs/source_en/images/graphlearning_en.png deleted file mode 100644 index 9473d28a237f1b43c9aa33df14afcd8bbe39b4a8..0000000000000000000000000000000000000000 Binary files a/docs/graphlearning/docs/source_en/images/graphlearning_en.png and /dev/null differ diff --git a/docs/graphlearning/docs/source_en/index.rst b/docs/graphlearning/docs/source_en/index.rst deleted file mode 100644 index 3c5becf3f8959160af06d3b3375bd148017e75c6..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_en/index.rst +++ /dev/null @@ -1,93 +0,0 @@ -MindSpore Graph Learning Documents -=================================== - -MindSpore Graph Learning is a graph learning suite, which supports point-centered programming for graph neural networks and efficient training inference. - -Thanks to the MindSpore graph kernel fusion, MindSpore Graph Learning can optimize the build of execution patterns specific to graph models, helping developers shorten the training time. MindSpore Graph Learning also proposes an innovative vertex-centric programming paradigm, which provides native graph neural network expressions and built-in models covering most application scenarios, enabling developers to easily build graph neural networks. - -.. image:: ./images/graphlearning_en.png - :width: 700px - -Code repository address: - -Design Features ----------------- - -1. Vertex-centric programming paradigm - - A graph neural network model transfers and aggregates information on a given graph structure, which cannot be intuitively expressed through entire graph computing. MindSpore Graph Learning provides a vertex-centric programming paradigm that better complies with the graph learning algorithm logic and Python language style. It can directly translate formulas into code, reducing the gap between algorithm design and implementation. - -2. Accelerated graph models - - MindSpore Graph Learning combines the features of MindSpore graph kernel fusion and auto kernel generator (AKG) to automatically identify the specific execution pattern of graph neural network tasks for fusion and kernel-level optimization, covering the fusion of existing operators and new combined operators in the existing framework. The performance is improved by 3 to 4 times compared with that of the existing popular frameworks. - -Training Process ------------------- - -MindSpore Graph Learning provides abundant dataset read, graph operation, and network module APIs. To train graph neural networks, perform the following steps: - -1. Define a network model. 
You can directly call the API provided by mindspore_gl.nn or define your own graph learning model by referring to the implementation of mindspore_gl.nn. - -2. Define a loss function. - -3. Construct a dataset. mindspore_gl.dataset provides the function of reading and constructing some public datasets for research. - -4. Train and validate the network. - -Feature Introduction ---------------------- - -MindSpore Graph Learning provides a node-centric GNN programming paradigm. Its built-in code parsing functions translate node-centric computing expressions into graph data computing operations. To facilitate debugging, a translation comparison between the user input code and the calculation code is printed during the parsing process. - -The following figure shows the implementation of the classic GAT network based on the node-centric programming model. A user defines a function that uses node `v` as the input parameter. In the function, the user obtains the neighboring node list through `v.innbs()`. Traverse each neighboring node `u` to obtain node features, and calculate feature interaction between the neighboring node and the central node to obtain a weight list of neighboring edges. Then, the weights of neighboring edges and neighboring nodes are weighted averaged, and the updated central node features are returned. - -.. image:: ./images/gat_example.PNG - :width: 700px - -Future Roadmap ---------------- - -The initial version of MindSpore Graph Learning includes the point-centric programming paradigm, provides implementation of typical graph models, and provides cases and performance evaluation for single-node training on small datasets. The initial version does not support performance evaluation and distributed training on large datasets, and does not support interconnection with graph databases. These features will be included in later versions of MindSpore Graph Learning. - -Typical MindSpore Graph Learning Application Scenarios -------------------------------------------------------- - -.. toctree:: - :maxdepth: 1 - :caption: Deployment - - mindspore_graphlearning_install - -.. toctree:: - :maxdepth: 1 - :caption: Guide - - full_training_of_GCN - batched_graph_training_GIN - spatio_temporal_graph_training_STGCN - single_host_distributed_Graphsage - -.. toctree:: - :maxdepth: 1 - :caption: API References - - mindspore_gl - mindspore_gl.dataloader - mindspore_gl.dataset - mindspore_gl.graph - mindspore_gl.nn - mindspore_gl.sampling - mindspore_gl.utils - -.. toctree:: - :maxdepth: 1 - :caption: References - - faq - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: RELEASE NOTES - - RELEASE diff --git a/docs/graphlearning/docs/source_en/mindspore_graphlearning_install.md b/docs/graphlearning/docs/source_en/mindspore_graphlearning_install.md deleted file mode 100644 index 2c91d6775d47db26f7299167e61214dda61c3b16..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_en/mindspore_graphlearning_install.md +++ /dev/null @@ -1,57 +0,0 @@ -# Installing Graph Learning - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/graphlearning/docs/source_en/mindspore_graphlearning_install.md)   - -## Installation - -### System Environment Information Confirmation - -- Ensure that the hardware platform is Ascend or GPU under the Linux system. 
-- Refer to [MindSpore Installation Guide](https://www.mindspore.cn/install/en) to complete the installation of MindSpore, which requires at least version 2.0.0.
-- For other dependencies, please refer to [requirements.txt](https://gitee.com/mindspore/graphlearning/blob/master/requirements.txt).
-
-### Installation Methods
-
-You can install MindSpore Graph Learning either by pip or by source code.
-
-#### Installation by pip
-
-- Ascend/CPU
-
-    ```bash
-    pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.0.0rc1/GraphLearning/cpu/{system_structure}/mindspore_gl-0.2-cp37-cp37m-linux_{system_structure}.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple
-    ```
-
-- GPU
-
-    ```bash
-    pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.0.0rc1/GraphLearning/gpu/x86_64/cuda-{cuda_version}/mindspore_gl-0.2-cp37-cp37m-linux_x86_64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple
-    ```
-
-> - When the network is connected, dependencies are automatically downloaded during .whl package installation. For details about other dependencies, see [requirements.txt](https://gitee.com/mindspore/graphlearning/blob/master/requirements.txt). In other cases, you need to install the dependencies manually.
-> - `{system_structure}` denotes the Linux system architecture; the options are `x86_64` and `aarch64`.
-> - `{cuda_version}` denotes the CUDA version; the options are `10.1`, `11.1`, and `11.6`.
-
-#### Installation by Source Code
-
-1. Download the source code from Gitee.
-
-    ```bash
-    git clone https://gitee.com/mindspore/graphlearning.git
-    ```
-
-2. Compile and install in the MindSpore Graph Learning directory.
-
-    ```bash
-    cd graphlearning
-    bash build.sh
-    pip install ./output/mindspore_gl-*.whl
-    ```
-
-### Installation Verification
-
-The installation is successful if no error message such as `No module named 'mindspore_gl'` is displayed when executing the following command:
-
-```bash
-python -c 'import mindspore_gl'
-```
diff --git a/docs/graphlearning/docs/source_en/single_host_distributed_Graphsage.md b/docs/graphlearning/docs/source_en/single_host_distributed_Graphsage.md
deleted file mode 100644
index 8aa9d4dc9fe91697004a45cd86bf1b92081eba6d..0000000000000000000000000000000000000000
--- a/docs/graphlearning/docs/source_en/single_host_distributed_Graphsage.md
+++ /dev/null
@@ -1,270 +0,0 @@
-# Single-host Distributed Training
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/graphlearning/docs/source_en/single_host_distributed_Graphsage.md)
-
-## Overview
-
-This example shows how to perform single-host distributed training of GraphSAGE on large graphs.
-
-GraphSAGE is a general inductive framework that leverages node feature information (e.g., text attributes) to efficiently generate node embeddings for previously unseen data. Instead of training individual embeddings for each node, GraphSAGE learns a function that generates embeddings by sampling and aggregating features from a node's local neighborhood.
-
-In the Reddit dataset, the authors sampled 50 large communities and constructed a post-to-post graph, linking posts if the same user commented on both posts. Each post is labeled as the community to which it belongs. The dataset contains a total of 232965 posts with an average degree of 492.
- -Since the Reddit dataset size is large, to reduce the GraphSAGE training time, in this example, distributed model training is performed on single-host to accelerate the model training. - -> Download the complete sample code here: [GraphSAGE](https://gitee.com/mindspore/graphlearning/tree/master/model_zoo/graphsage). - -## GraphSAGE Principles - -Paper: [Inductive representation learning on large graphs](https://proceedings.neurips.cc/paper/2017/file/5dd9db5e033da9c6fb5ba83c7a7ebea9-Paper.pdf) - -## Setting Running Script - -The invoking method of distributed training depending on the device. - -On the GPU hardware platform, communication in MindSpore distributed parallel training uses NVIDIA’s collective communication library NVIDIA Collective Communication Library (NCCL for short). - -```bash -# GPU -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -export CUDA_NUM=8 -rm -rf device -mkdir device -cp -r src ./device -cp distributed_trainval_reddit.py ./device -cd ./device -echo "start training" -mpirun --allow-run-as-root -n ${CUDA_NUM} python3 ./distributed_trainval_reddit.py --data-path ${DATA_PATH} --epochs 5 > train.log 2>&1 & -``` - -The Huawei Collective Communication Library (HCCL) is used for the communication of MindSpore parallel distributed training and can be found in the Atlas 200/300/500 inference product software package. In addition, mindspore.communication.management encapsulates the collective communication API provided by the HCCL to help users configure distributed information. - -```bash -# Ascend -RANK_TABLE_FILE=$3 - export RANK_TABLE_FILE=${RANK_TABLE_FILE} - for((i=0;i<8;i++)); - do - export RANK_ID=$[i+RANK_START] - export DEVICE_ID=$[i+RANK_START] - echo ${DEVICE_ID} - rm -rf ${execute_path}/device_$RANK_ID - mkdir ${execute_path}/device_$RANK_ID - cd ${execute_path}/device_$RANK_ID || exit - echo "start training" - python3 ${self_path}/distributed_trainval_reddit.py --data-path ${DATA_PATH} --epochs 2 > train$RANK_ID.log 2>&1 & - done -``` - -## Defining a Network Model - -mindspore_gl.nn implements SAGEConv, which can be directly imported for use. You can also define your own convolutional layer. The code for implementing a two-layer GraphSAGE network using SAGEConv is as follows: - -```python -class SAGENet(Cell): - """graphsage net""" - def __init__(self, in_feat_size, hidden_feat_size, appr_feat_size, out_feat_size): - super().__init__() - self.num_layers = 2 - self.layer1 = SAGEConv(in_feat_size, hidden_feat_size, aggregator_type='mean') - self.layer2 = SAGEConv(hidden_feat_size, appr_feat_size, aggregator_type='mean') - self.dense_out = ms.nn.Dense(appr_feat_size, out_feat_size, has_bias=False, - weight_init=XavierUniform(math.sqrt(2))) - self.activation = ms.nn.ReLU() - self.dropout = ms.nn.Dropout(p=0.5) - - def construct(self, node_feat, edges, n_nodes, n_edges): - """graphsage net forward""" - node_feat = self.layer1(node_feat, None, edges[0], edges[1], n_nodes, n_edges) - node_feat = self.activation(node_feat) - node_feat = self.dropout(node_feat) - ret = self.layer2(node_feat, None, edges[0], edges[1], n_nodes, n_edges) - ret = self.dense_out(ret) - return ret -``` - -For details about SAGENet implementation, see the [API](https://gitee.com/mindspore/graphlearning/blob/master/mindspore_gl/nn/conv/sageconv.py) code of mindspore_gl.nn.SAGEConv. 
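-As a usage sketch, the network can then be instantiated from the dataset's dimensions. The hidden sizes below are illustrative values, and the attribute names (node_feat_size, n_classes) follow the dataset usage elsewhere in these tutorials rather than a documented guarantee:
-
-```python
-# graph_dataset is the Reddit dataset object constructed later in this tutorial.
-net = SAGENet(in_feat_size=graph_dataset.node_feat_size,
-              hidden_feat_size=256,
-              appr_feat_size=128,
-              out_feat_size=graph_dataset.n_classes)
-```
-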
-
-## Defining a Loss Function
-
-Because this is a classification task, cross entropy can be used as the loss function; the implementation is similar to that of [GCN](https://www.mindspore.cn/graphlearning/docs/en/master/full_training_of_GCN.html#defining-a-loss-function).
-
-## Constructing a Dataset
-
-The following uses the [Reddit](https://data.dgl.ai/dataset/reddit.zip) dataset as an example. Enter the data path to construct a data class.
-get_group_size is used to obtain the total number of processes for distributed training, and get_rank is used to obtain the ID of the current process. For how to construct the dataloader, refer to [GIN](https://www.mindspore.cn/graphlearning/docs/en/master/batched_graph_training_GIN.html#constructing-a-dataset).
-
-Different from GIN, the sampler in this example is mindspore_gl.dataloader.DistributeRandomBatchSampler. DistributeRandomBatchSampler splits the dataset based on the process ID, ensuring that each process obtains a different part of the dataset batches.
-
-```python
-import mindspore.dataset as ds
-from mindspore_gl.dataset import Reddit
-from mindspore_gl.dataloader import DistributeRandomBatchSampler, RandomBatchSampler
-from mindspore.communication import get_rank, get_group_size
-
-rank_id = get_rank()
-world_size = get_group_size()
-graph_dataset = Reddit(args.data_path)
-train_sampler = DistributeRandomBatchSampler(rank_id, world_size, data_source=graph_dataset.train_nodes,
-                                             batch_size=args.batch_size)
-test_sampler = RandomBatchSampler(data_source=graph_dataset.test_nodes, batch_size=args.batch_size)
-train_dataset = GraphSAGEDataset(graph_dataset, [25, 10], args.batch_size, len(list(train_sampler)), single_size)
-test_dataset = GraphSAGEDataset(graph_dataset, [25, 10], args.batch_size, len(list(test_sampler)), single_size)
-train_dataloader = ds.GeneratorDataset(train_dataset, ['seeds_idx', 'label', 'nid_feat', 'edges'],
-                                       sampler=train_sampler, python_multiprocessing=True)
-test_dataloader = ds.GeneratorDataset(test_dataset, ['seeds_idx', 'label', 'nid_feat', 'edges'],
-                                      sampler=test_sampler, python_multiprocessing=True)
-```
-
-mindspore_gl.sampling.sage_sampler_on_homo provides a k-hop sampling method. The `neighbor_nums` list specifies how many neighbors to sample at each hop, from the central node outward.
-Because the degree of each node differs, the array sizes after k-hop sampling also differ. The sampling results are therefore padded to one of 5 fixed sizes through the mindspore_gl.graph.PadArray2d API.
- -```python -from mindspore_gl.dataloader.dataset import Dataset -from mindspore_gl.sampling.neighbor import sage_sampler_on_homo - -class GraphSAGEDataset(Dataset): - """Do sampling from neighbour nodes""" - def __init__(self, graph_dataset, neighbor_nums, batch_size, length, single_size=False): - self.graph_dataset = graph_dataset - self.graph = graph_dataset[0] - self.neighbor_nums = neighbor_nums - self.x = graph_dataset.node_feat - self.y = graph_dataset.node_label - self.batch_size = batch_size - self.max_sampled_nodes_num = neighbor_nums[0] * neighbor_nums[1] * batch_size - self.single_size = single_size - self.length = length - - def __getitem__(self, batch_nodes): - batch_nodes = np.array(batch_nodes, np.int32) - res = sage_sampler_on_homo(self.graph, batch_nodes, self.neighbor_nums) - label = array_kernel.int_1d_array_slicing(self.y, batch_nodes) - layered_edges_0 = res['layered_edges_0'] - layered_edges_1 = res['layered_edges_1'] - sample_edges = np.concatenate((layered_edges_0, layered_edges_1), axis=1) - sample_edges = sample_edges[[1, 0], :] - num_sample_edges = sample_edges.shape[1] - num_sample_nodes = len(res['all_nodes']) - max_sampled_nodes_num = self.max_sampled_nodes_num - if self.single_size is False: - if num_sample_nodes < floor(0.2*max_sampled_nodes_num): - pad_node_num = floor(0.2*max_sampled_nodes_num) - elif num_sample_nodes < floor(0.4*max_sampled_nodes_num): - pad_node_num = floor(0.4 * max_sampled_nodes_num) - elif num_sample_nodes < floor(0.6*max_sampled_nodes_num): - pad_node_num = floor(0.6 * max_sampled_nodes_num) - elif num_sample_nodes < floor(0.8*max_sampled_nodes_num): - pad_node_num = floor(0.8 * max_sampled_nodes_num) - else: - pad_node_num = max_sampled_nodes_num - - if num_sample_edges < floor(0.2*max_sampled_nodes_num): - pad_edge_num = floor(0.2*max_sampled_nodes_num) - elif num_sample_edges < floor(0.4*max_sampled_nodes_num): - pad_edge_num = floor(0.4 * max_sampled_nodes_num) - elif num_sample_edges < floor(0.6*max_sampled_nodes_num): - pad_edge_num = floor(0.6 * max_sampled_nodes_num) - elif num_sample_edges < floor(0.8*max_sampled_nodes_num): - pad_edge_num = floor(0.8 * max_sampled_nodes_num) - else: - pad_edge_num = max_sampled_nodes_num - - else: - pad_node_num = max_sampled_nodes_num - pad_edge_num = max_sampled_nodes_num - - layered_edges_pad_op = PadArray2d(mode=PadMode.CONST, size=[2, pad_edge_num], - dtype=np.int32, direction=PadDirection.ROW, - fill_value=pad_node_num - 1, - ) - nid_feat_pad_op = PadArray2d(mode=PadMode.CONST, - size=[pad_node_num, self.graph_dataset.node_feat_size], - dtype=self.graph_dataset.node_feat.dtype, - direction=PadDirection.COL, - fill_value=0, - reset_with_fill_value=False, - use_shared_numpy=True - ) - sample_edges = sample_edges[:, :pad_edge_num] - pad_sample_edges = layered_edges_pad_op(sample_edges) - feat = nid_feat_pad_op.lazy([num_sample_nodes, self.graph_dataset.node_feat_size]) - array_kernel.float_2d_gather_with_dst(feat, self.graph_dataset.node_feat, res['all_nodes']) - return res['seeds_idx'], label, feat, pad_sample_edges -``` - -## Network Training and Validation - -### Setting Environment Variables - -During distributed training, data is imported in data parallel mode. At the end of each training step, each process unifies the model parameters. On Ascend it must be ensured that the data shape is the same in each process. 
-
-```python
-import os
-import mindspore as ms
-from mindspore.communication import init
-
-device_target = str(os.getenv('DEVICE_TARGET'))
-if device_target == 'Ascend':
-    device_id = int(os.getenv('DEVICE_ID'))
-    ms.set_context(device_id=device_id)
-    single_size = True
-    init()
-else:
-    init("nccl")
-    single_size = False
-```
-
-The graph operator compilation optimization settings are similar to those of [GCN](https://www.mindspore.cn/graphlearning/docs/en/master/full_training_of_GCN.html#setting-environment-variables).
-
-### Defining a Training Network
-
-Instantiate the model body SAGENet, the LossNet, and the optimizer.
-The implementation is similar to that of [GCN](https://www.mindspore.cn/graphlearning/docs/en/master/full_training_of_GCN.html#defining-a-training-network).
-
-### Network Training and Validation
-
-For the training and validation methods, refer to [GCN](https://www.mindspore.cn/graphlearning/docs/en/master/full_training_of_GCN.html#network-training-and-validation-1).
-
-## Executing Jobs and Viewing Results
-
-### Running Process
-
-After the program is started, the code is translated and training begins.
-
-### Execution Results
-
-Run the [distributed_run.sh](https://gitee.com/mindspore/graphlearning/blob/master/model_zoo/graphsage/distributed_run.sh) script to start training.
-
-- GPU
-
-    ```bash
-    cd model_zoo/graphsage
-    bash distributed_run.sh GPU DATA_PATH
-    ```
-
-    `{DATA_PATH}` indicates the dataset storage path.
-
-- Ascend
-
-    ```bash
-    cd model_zoo/graphsage
-    bash distributed_run.sh Ascend DATA_PATH RANK_START RANK_SIZE RANK_TABLE_FILE
-    ```
-
-    `{DATA_PATH}` indicates the dataset storage path. `{RANK_START}` is the ID of the first Ascend device to be used. `{RANK_SIZE}` is the number of Ascend devices to be used. `{RANK_TABLE_FILE}` is the root path of the 'rank_table_*pcs.json' file.
-
-The training result is as follows:
-
-```bash
-...
-Iteration/Epoch: 30:4 train loss: 0.41629112
-Iteration/Epoch: 30:4 train loss: 0.5337528
-Iteration/Epoch: 30:4 train loss: 0.42849028
-Iteration/Epoch: 30:4 train loss: 0.5358513
-rank_id:3 Epoch/Time: 4:76.17579555511475
-rank_id:1 Epoch/Time: 4:37.79207944869995
-rank_id:2 Epoch/Time: 4:76.04292225837708
-rank_id:0 Epoch/Time: 4:75.64319372177124
-rank_id:2 test accuracy : 0.9276439525462963
-rank_id:0 test accuracy : 0.9305013020833334
-rank_id:3 test accuracy : 0.9290907118055556
-rank_id:1 test accuracy : 0.9279513888888888
-```
-
-Accuracy verified on Reddit: 0.92.
diff --git a/docs/graphlearning/docs/source_en/spatio_temporal_graph_training_STGCN.md b/docs/graphlearning/docs/source_en/spatio_temporal_graph_training_STGCN.md
deleted file mode 100644
index 0385420640229208b37fc2fa6567cf0f762dff28..0000000000000000000000000000000000000000
--- a/docs/graphlearning/docs/source_en/spatio_temporal_graph_training_STGCN.md
+++ /dev/null
@@ -1,165 +0,0 @@
-# Spatio-Temporal Graph Training Network
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/graphlearning/docs/source_en/spatio_temporal_graph_training_STGCN.md)
-
-## Overview
-
-This example shows how to forecast traffic with Spatio-Temporal Graph Convolutional Networks.
-
-Spatio-Temporal Graph Convolutional Networks (STGCN) can tackle the time series prediction problem in the traffic domain. Experiments show that STGCN effectively captures comprehensive spatio-temporal correlations through modeling multi-scale traffic networks.
- -METR-LA is a large-scale data set collected from 1,500 traffic loop detectors in the Los Angeles rural road network. This data set includes speed, road capacity, and occupancy data and covers approximately 3,420 miles. The road network is constructed into a graph and input to the STGCN network. The road network information in the next time phase is predicted based on the historical data. - -The node feature shape of a general graph is `(nodes number, feature dimension)`, but the feature shape of a spatio-temporal graph is usually at least 3-dimensional `(nodes number, feature dimension, time step)`, and the feature fusion processing of neighbor nodes will be more complicated. And due to the convolution in the time dimension, the `time step` will also change. When calculating the loss, it is necessary to calculate the output time length in advance. - -> Download the complete sample code here: [STGCN](https://gitee.com/mindspore/graphlearning/tree/master/model_zoo/stgcn). - -## STGCN Principles - -Paper: [A deep learning framework for traffic forecasting](https://arxiv.org/pdf/1709.04875.pdf) - -## Graph Laplacian Normalization - -The self-loop of the graph is deleted, and the graph is normalized to obtain the new edge index and edge weight. -mindspore_gl.graph implements norm, which can be used for laplacian normalization. The code for normalization of edge index and edge weight is as follows: - -```python -mask = edge_index[0] != edge_index[1] -edge_index = edge_index[:, mask] -edge_attr = edge_attr[mask] - -edge_index = ms.Tensor(edge_index, ms.int32) -edge_attr = ms.Tensor(edge_attr, ms.float32) -edge_index, edge_weight = norm(edge_index, node_num, edge_attr, args.normalization) -``` - -For details about laplacian normalization, see the [API](https://gitee.com/mindspore/graphlearning/blob/master/mindspore_gl/graph/norm.py) code of mindspore_gl.graph.norm. - -## Defining a Network Model - -mindspore_gl.nn implements STConv, which can be directly imported for use. Different from the general graph convolution layer, the input features of STConv are 4-dimensional, that is, `(batch graphs number, time step, nodes number, feature dimension)`. -The `time step` of the output feature needs to be calculated according to the size of the 1D convolution kernel and the times of convolutions. - -The code for implementing a two-layer STGCN network using STConv is as follows: - -```python -class STGcnNet(GNNCell): - """ STGCN Net """ - def __init__(self, - num_nodes: int, - in_channels: int, - hidden_channels_1st: int, - out_channels_1st: int, - hidden_channels_2nd: int, - out_channels_2nd: int, - out_channels: int, - kernel_size: int, - k: int, - bias: bool = True): - super().__init__() - self.layer0 = STConv(num_nodes, in_channels, - hidden_channels_1st, - out_channels_1st, - kernel_size, - k, bias) - self.layer1 = STConv(num_nodes, out_channels_1st, - hidden_channels_2nd, - out_channels_2nd, - kernel_size, - k, bias) - self.relu = ms.nn.ReLU() - self.fc = ms.nn.Dense(out_channels_2nd, out_channels) - - def construct(self, x, edge_weight, g: Graph): - x = self.layer0(x, edge_weight, g) - x = self.layer1(x, edge_weight, g) - x = self.relu(x) - x = self.fc(x) - return x -``` - -For details about STConv implementation, see the [API](https://gitee.com/mindspore/graphlearning/blob/master/mindspore_gl/nn/temporal/stconv.py) code of mindspore_gl.nn.temporal.STConv. - -## Defining a Loss Function - -Since this task is a regression task, the minimum mean square error can be used as the loss function. 
In this example, mindspore.nn.MSELoss is used to implement a mean square error loss. - -```python -class LossNet(GNNCell): - """ LossNet definition """ - def __init__(self, net): - super().__init__() - self.net = net - self.loss_fn = nn.loss.MSELoss() - - def construct(self, feat, edges, target, g: Graph): - """STGCN Net with loss function""" - predict = self.net(feat, edges, g) - predict = ops.Squeeze()(predict) - loss = self.loss_fn(predict, target) - return ms.ops.ReduceMean()(loss) -``` - -## Constructing a Dataset - -Input feature is `(batch graphs number, time step, nodes number, feature dimension)`. The length of the time series changed after time convolution. Therefore, the input and output timestamps must be specified when features and tags are obtained from datasets. Otherwise, the shape of the predicted value is inconsistent with that of the label value. - -For details about the restriction specifications, see the code comments. - -```python -from mindspore_gl.dataset import MetrLa -metr = MetrLa(args.data_path) -# out_timestep setting -# out_timestep = in_timestep - ((kernel_size - 1) * 2 * layer_nums) -# such as: layer_nums = 2, kernel_size = 3, in_timestep = 12, -# out_timestep = 4 -features, labels = metr.get_data(args.in_timestep, args.out_timestep) -``` - -The [MetrLa](https://graphmining.ai/temporal_datasets/METR-LA.zip) data can be downloaded and decompressed to args.data_path. - -## Network Training and Validation - -### Setting Environment Variables - -The method of setting environment variables is similar to that of setting [GCN](https://www.mindspore.cn/graphlearning/docs/en/master/full_training_of_GCN.html#setting-environment-variables). - -### Defining a Training Network - -Instantiation of the model body STGcnNet and LossNet and optimizer. -The implementation method is similar to that of the [GCN](https://www.mindspore.cn/graphlearning/docs/en/master/full_training_of_GCN.html#defining-a-training-network). - -### Network Training and Validation - -The implementation method is similar to that of the [GCN](https://www.mindspore.cn/graphlearning/docs/en/master/full_training_of_GCN.html#network-training-and-validation-1). - -## Executing Jobs and Viewing Results - -### Running Process - -After running the program, translate the code and start training. - -### Execution Results - -Run the [trainval_metr.py](https://gitee.com/mindspore/graphlearning/blob/master/model_zoo/stgcn/trainval_metr.py) script to start training. - -```bash -cd model_zoo/stgcn -python trainval_metr.py --data-path={path} --fuse=True -``` - -`{path}` indicates the dataset storage path. - -The training result is as follows: - -```bash -... -Iteration/Epoch: 600:199 loss: 0.21488506 -Iteration/Epoch: 700:199 loss: 0.21441595 -Iteration/Epoch: 800:199 loss: 0.21243602 -Time 13.162885904312134 Epoch loss 0.21053028 -eval MSE: 0.2060675 -``` - -MSE on MetrLa: 0.206 diff --git a/docs/graphlearning/docs/source_zh_cn/_templates/classtemplate.rst b/docs/graphlearning/docs/source_zh_cn/_templates/classtemplate.rst deleted file mode 100644 index 37a8e95499c8343ad3f8e02d5c9095215fd9010a..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_zh_cn/_templates/classtemplate.rst +++ /dev/null @@ -1,27 +0,0 @@ -.. role:: hidden - :class: hidden-section - -.. currentmodule:: {{ module }} - -{% if objname in [] %} -{{ fullname | underline }} - -.. autofunction:: {{ fullname }} - -{% elif objname[0].istitle() %} -{{ fullname | underline }} - -.. 
autoclass:: {{ name }} - :exclude-members: construct - :members: - -{% else %} -{{ fullname | underline }} - -.. autofunction:: {{ fullname }} - -{% endif %} - -.. - autogenerated from _templates/classtemplate.rst - note it does not have :inherited-members: diff --git a/docs/graphlearning/docs/source_zh_cn/batched_graph_training_GIN.md b/docs/graphlearning/docs/source_zh_cn/batched_graph_training_GIN.md deleted file mode 100644 index ef83cbfdc13d63388aa1d6176fa821cd961664c4..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_zh_cn/batched_graph_training_GIN.md +++ /dev/null @@ -1,276 +0,0 @@ -# 批次图训练网络 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/graphlearning/docs/source_zh_cn/batched_graph_training_GIN.md) -   - -## 概述 - -在本例中将展示如何基于图同构网络的进行社会关系网络分类。 - -GIN的灵感来自GNN和Weisfeiler-Lehman (WL)图同构测试。WL测试是一个强大的测试,可以区分广泛的图类。如果GNN的聚合方案具有高度的表达能力,并且可以建模内射函数,GNN可以具有与WL测试一样大的鉴别力。 - -IMDB-BINARY是一个电影协作数据集,由1000名在IMDB中扮演电影角色的演员的角色网络组成。在每张图中,节点代表演员,如果他们出演过同一部电影,在节点直接建立一条边。这些图都来源于动作或浪漫电影。 -分批次从IMDB-BINARY数据集中取出图数据,每张图都是由演员构成的电影,利用GIN对图进行分类,预测电影属于什么风格。 - -批次图模式中每次能够对多张图同时进行训练,并且每张图的节点数/边数都完全不同。mindspore_gl提供了构建虚拟图的方法将对批次内图整合成一张整图,并对整图数据进行统一,以降低内存消耗及加速计算。 - -> 下载完整的样例[GIN](https://gitee.com/mindspore/graphlearning/tree/master/model_zoo/gin)代码。 - -## GIN原理 - -论文链接:[How Powerful are Graph Neural Networks?](https://arxiv.org/pdf/1810.00826.pdf) - -## 定义网络结构 - -GINConv将图`g`解析为`BatchedGraph`,与`Graph`相比`BatchedGraph`能够支持更多图操作。输入的数据为整图,但是每张子图进行节点特征更新时,还是能根据自身的节点找到对应的邻居节点,而不会连接到其他子图的节点。 - -mindspore_gl.nn提供了GINConv的API可以直接调用。使用GINConv,再配合批次归一化、池化等操作实现一个多层的GinNet网络代码如下: - -```python -class GinNet(GNNCell): - """GIN net""" - def __init__(self, - num_layers, - num_mlp_layers, - input_dim, - hidden_dim, - output_dim, - final_dropout=0.1, - learn_eps=False, - graph_pooling_type='sum', - neighbor_pooling_type='sum' - ): - super().__init__() - self.final_dropout = final_dropout - self.num_layers = num_layers - self.graph_pooling_type = graph_pooling_type - self.neighbor_pooling_type = neighbor_pooling_type - self.learn_eps = learn_eps - - self.mlps = nn.CellList() - self.convs = nn.CellList() - self.batch_norms = nn.CellList() - - if self.graph_pooling_type not in ('sum', 'avg'): - raise SyntaxError("graph pooling type not supported.") - for layer in range(num_layers - 1): - if layer == 0: - self.mlps.append(MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim)) - else: - self.mlps.append(MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim)) - self.convs.append(GINConv(ApplyNodeFunc(self.mlps[layer]), learn_eps=self.learn_eps, - aggregation_type=self.neighbor_pooling_type)) - self.batch_norms.append(nn.BatchNorm1d(hidden_dim)) - - self.linears_prediction = nn.CellList() - for layer in range(num_layers): - if layer == 0: - self.linears_prediction.append(nn.Dense(input_dim, output_dim)) - else: - self.linears_prediction.append(nn.Dense(hidden_dim, output_dim)) - - def construct(self, x, edge_weight, g: BatchedGraph): - """construct function""" - hidden_rep = [x] - h = x - for layer in range(self.num_layers - 1): - h = self.convs[layer](h, edge_weight, g) - h = self.batch_norms[layer](h) - h = nn.ReLU()(h) - hidden_rep.append(h) - - score_over_layer = 0 - for layer, h in enumerate(hidden_rep): - if self.graph_pooling_type == 'sum': - pooled_h = g.sum_nodes(h) - else: - pooled_h = g.avg_nodes(h) - score_over_layer = score_over_layer + nn.Dropout(p=1.0 - 
self.final_dropout)( - self.linears_prediction[layer](pooled_h)) - return score_over_layer -``` - -GINConv执行的更多细节可以看mindspore_gl.nn.GINConv的[API](https://gitee.com/mindspore/graphlearning/blob/master/mindspore_gl/nn/conv/ginconv.py)代码。 - -## 构造数据集 - -从mindspore_gl.dataset调用了IMDB-BINARY的数据集,调用方法可以参考[GCN](https://www.mindspore.cn/graphlearning/docs/zh-CN/master/full_training_of_GCN.html#%E6%9E%84%E9%80%A0%E6%95%B0%E6%8D%AE%E9%9B%86)。然后利用mindspore_gl.dataloader.RandomBatchSampler定义了一个采样器,来生成采样索引。 -MultiHomoGraphDataset根据采样索引从数据集里获取数据,将返回数据打包成batch,做出数据集的生成器。构建生成器后,调用mindspore.dataset.GeneratorDataset的API,完成数据加载器构建。 - -```python -dataset = IMDBBinary(arguments.data_path) -train_batch_sampler = RandomBatchSampler(dataset.train_graphs, batch_size=arguments.batch_size) -train_multi_graph_dataset = MultiHomoGraphDataset(dataset, arguments.batch_size, len(list(train_batch_sampler))) -test_batch_sampler = RandomBatchSampler(dataset.val_graphs, batch_size=arguments.batch_size) -test_multi_graph_dataset = MultiHomoGraphDataset(dataset, arguments.batch_size, len(list(test_batch_sampler))) - -train_dataloader = ds.GeneratorDataset(train_multi_graph_dataset, ['row', 'col', 'node_count', 'edge_count', - 'node_map_idx', 'edge_map_idx', 'graph_mask', - 'batched_label', 'batched_node_feat', - 'batched_edge_feat'], - sampler=train_batch_sampler) - -test_dataloader = ds.GeneratorDataset(test_multi_graph_dataset, ['row', 'col', 'node_count', 'edge_count', - 'node_map_idx', 'edge_map_idx', 'graph_mask', - 'batched_label', 'batched_node_feat', - 'batched_edge_feat'], - sampler=test_batch_sampler) -``` - -利用mindspore_gl.graph.BatchHomoGraph将多张子图合并成一张整图。在模型训练时,batch内所有图将以一张整图的形式进行计算。 - -为了减少计算图的生成,加快计算速度,生成器在返回数据时,将每个batch中的数据统一到相同的尺寸。 - -假设节点数`node_size`与边数`edge_size`,并满足batch内所有图数据的节点数之和与边数之和都要都小于等于`node_size * batch`和`edge_size * batch`。 -在batch内新建张虚拟图,使得batch内图节点数和、边数和等于`node_size * batch`和`edge_size * batch`。在计算loss时,这张图将不参与计算。 - -调用mindspore_gl.graph.PadArray2d定义节点和边特征填充的操作,将虚拟图上的节点特征和边特征都设置为0。 -调用mindspore_gl.graph.PadHomoGraph定义对图结构上的节点和边进行填充的操作,使得batch内节点数等于`node_size * batch`,边数等于`edge_size * batch`。 - -```python -class MultiHomoGraphDataset(Dataset): - """MultiHomoGraph Dataset""" - def __init__(self, dataset, batch_size, length, mode=PadMode.CONST, node_size=50, edge_size=350): - self._dataset = dataset - self._batch_size = batch_size - self._length = length - self.batch_fn = BatchHomoGraph() - self.batched_edge_feat = None - node_size *= batch_size - edge_size *= batch_size - if mode == PadMode.CONST: - self.node_feat_pad_op = PadArray2d(dtype=np.float32, mode=PadMode.CONST, direction=PadDirection.COL, - size=(node_size, dataset.node_feat_size), fill_value=0) - self.edge_feat_pad_op = PadArray2d(dtype=np.float32, mode=PadMode.CONST, direction=PadDirection.COL, - size=(edge_size, dataset.edge_feat_size), fill_value=0) - self.graph_pad_op = PadHomoGraph(n_edge=edge_size, n_node=node_size, mode=PadMode.CONST) - else: - self.node_feat_pad_op = PadArray2d(dtype=np.float32, mode=PadMode.AUTO, direction=PadDirection.COL, - fill_value=0) - self.edge_feat_pad_op = PadArray2d(dtype=np.float32, mode=PadMode.AUTO, direction=PadDirection.COL, - fill_value=0) - self.graph_pad_op = PadHomoGraph(mode=PadMode.AUTO) - - # For Padding - self.train_mask = np.array([True] * (self._batch_size + 1)) - self.train_mask[-1] = False - - def __getitem__(self, batch_graph_idx): - graph_list = [] - feature_list = [] - for idx in range(batch_graph_idx.shape[0]): - graph_list.append(self._dataset[batch_graph_idx[idx]]) - 
feature_list.append(self._dataset.graph_node_feat(batch_graph_idx[idx])) - - # Batch Graph - batch_graph = self.batch_fn(graph_list) - - # Pad Graph - batch_graph = self.graph_pad_op(batch_graph) - - # Batch Node Feat - batched_node_feat = np.concatenate(feature_list) - - # Pad NodeFeat - batched_node_feat = self.node_feat_pad_op(batched_node_feat) - batched_label = self._dataset.graph_label[batch_graph_idx] - - # Pad Label - batched_label = np.append(batched_label, batched_label[-1] * 0) - - # Get Edge Feat - if self.batched_edge_feat is None or self.batched_edge_feat.shape[0] < batch_graph.edge_count: - del self.batched_edge_feat - self.batched_edge_feat = np.ones([batch_graph.edge_count, 1], dtype=np.float32) - - # Trigger Node_Map_Idx/Edge_Map_Idx Computation, Because It Is Lazily Computed - _ = batch_graph.batch_meta.node_map_idx - _ = batch_graph.batch_meta.edge_map_idx - - np_graph_mask = [1] * (self._batch_size + 1) - np_graph_mask[-1] = 0 - constant_graph_mask = ms.Tensor(np_graph_mask, dtype=ms.int32) - batchedgraphfiled = self.get_batched_graph_field(batch_graph, constant_graph_mask) - row, col, node_count, edge_count, node_map_idx, edge_map_idx, graph_mask = batchedgraphfiled.get_batched_graph() - return row, col, node_count, edge_count, node_map_idx, edge_map_idx, graph_mask, batched_label,\ - batched_node_feat, self.batched_edge_feat[:batch_graph.edge_count, :] -``` - -## 定义loss函数 - -由于本次任务为分类任务,可以采用交叉熵来作为损失函数,实现方法与[GCN](https://www.mindspore.cn/graphlearning/docs/zh-CN/master/full_training_of_GCN.html#%E5%AE%9A%E4%B9%89loss%E5%87%BD%E6%95%B0)类似。 - -与GCN不同的是,本次教程为图分类,因此在解析批次图时,调用的为mindspore_gl.BatchedGraph接口。 - -在`g.graph_mask`中最后一位为虚拟图的mask,等于0,因此在计算loss时,最后1个值也为0。 - -```python -class LossNet(GNNCell): - """ LossNet definition """ - def __init__(self, net): - super().__init__() - self.net = net - self.loss_fn = ms.nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='none') - - def construct(self, node_feat, edge_weight, target, g: BatchedGraph): - predict = self.net(node_feat, edge_weight, g) - target = ops.Squeeze()(target) - loss = self.loss_fn(predict, target) - loss = ops.ReduceSum()(loss * g.graph_mask) - return loss -``` - -## 网络训练和验证 - -### 设置环境变量 - -环境变量设置方法可以[GCN](https://www.mindspore.cn/graphlearning/docs/zh-CN/master/full_training_of_GCN.html#%E8%AE%BE%E7%BD%AE%E7%8E%AF%E5%A2%83%E5%8F%98%E9%87%8F)。 - -### 定义训练网络 - -实例化模型主体GinNet以及LossNet和优化器。 -将LossNet实例和optimizer传入mindspore.nn.TrainOneStepCell构建一个单步训练网络train_net。 -实现方法与GCN类似,可以参考[GCN](https://www.mindspore.cn/graphlearning/docs/zh-CN/master/full_training_of_GCN.html#%E5%AE%9A%E4%B9%89%E8%AE%AD%E7%BB%83%E7%BD%91%E7%BB%9C)。 - -### 网络训练及验证 - -由于是批次图训练,构图时调用的API为mindspore_gl.BatchedGraphField,与mindspore_gl.GraphField不同的是,增加了`node_map_idx`、`edge_map_idx`、`graph_mask`三个参数。 -其中在`graph_mask`为batch中每个图的掩码信息,由于最后1张图为虚构图,因此在`graph_mask`数组中,最后1位为0,其余为1。 - -```python -from mindspore_gl import BatchedGraph, BatchedGraphField - -for data in train_dataloader: - row, col, node_count, edge_count, node_map_idx, edge_map_idx, graph_mask, label, node_feat, edge_feat = data - batch_homo = BatchedGraphField(row, col, node_count, edge_count, node_map_idx, edge_map_idx, graph_mask) - output = net(node_feat, edge_feat, *batch_homo.get_batched_graph()).asnumpy() -``` - -## 执行并查看结果 - -### 运行过程 - -运行程序后,进行代码翻译并开始训练。 - -### 执行结果 - -执行脚本[trainval_imdb_binary.py](https://gitee.com/mindspore/graphlearning/blob/master/model_zoo/gin/trainval_imdb_binary.py)启动训练。 - -```bash -cd model_zoo/gin -python trainval_imdb_binary.py 
--data_path={path} -``` - -其中`{path}`为数据集存放路径。 - -可以看到训练的结果如下: - -```bash -... -Epoch 52, Time 3.547 s, Train loss 0.49981827, Train acc 0.74219, Test acc 0.594 -Epoch 53, Time 3.599 s, Train loss 0.5046462, Train acc 0.74219, Test acc 0.656 -Epoch 54, Time 3.505 s, Train loss 0.49653444, Train acc 0.74777, Test acc 0.766 -Epoch 55, Time 3.468 s, Train loss 0.49411067, Train acc 0.74219, Test acc 0.750 -``` - -在IMDBBinary最好的验证精度为:0.766 diff --git a/docs/graphlearning/docs/source_zh_cn/conf.py b/docs/graphlearning/docs/source_zh_cn/conf.py deleted file mode 100644 index 20c383ca9c80ed7b66c20678507b91e4b997fa70..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_zh_cn/conf.py +++ /dev/null @@ -1,268 +0,0 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import glob -import os -import shutil -import IPython -import re -import sys -from sphinx.ext import autodoc as sphinx_autodoc -import sphinx.ext.autosummary.generate as g -sys.path.append(os.path.abspath('../_ext')) - -from sphinx import directives -with open('../_ext/overwriteobjectiondirective.txt', 'r', encoding="utf8") as f: - exec(f.read(), directives.__dict__) - -from sphinx.ext import viewcode -with open('../_ext/overwriteviewcode.txt', 'r', encoding="utf8") as f: - exec(f.read(), viewcode.__dict__) - -# -- Project information ----------------------------------------------------- - -project = 'MindSpore' -copyright = 'MindSpore' -author = 'MindSpore' - -# The full version, including alpha/beta/rc tags -release = 'master' - - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -myst_enable_extensions = ["dollarmath", "amsmath"] - - -myst_heading_anchors = 5 -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'myst_parser', - 'sphinx.ext.mathjax', - 'IPython.sphinxext.ipython_console_highlighting' -] - -source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', -} - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. 
-mathjax_path = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/mathjax/MathJax-3.2.2/es5/tex-mml-chtml.js' - -mathjax_options = { - 'async':'async' -} - -smartquotes_action = 'De' - -exclude_patterns = [] - -pygments_style = 'sphinx' - -autodoc_inherit_docstrings = False - -autosummary_generate = True - -autosummary_generate_overwrite = False - -html_static_path = ['_static'] - -# -- Options for HTML output ------------------------------------------------- - -# Reconstruction of sphinx auto generated document translation. - -language = 'zh_CN' -locale_dirs = ['../../../../resource/locale/'] -gettext_compact = False - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'sphinx_rtd_theme' - -html_search_language = 'zh' - -html_search_options = {'dict': '../../../resource/jieba.txt'} - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = { - 'python': ('https://docs.python.org/', '../../../../resource/python_objects.inv'), - 'numpy': ('https://docs.scipy.org/doc/numpy/', '../../../../resource/numpy_objects.inv'), -} - -# Modify regex for sphinx.ext.autosummary.generate.find_autosummary_in_lines. -gfile_abs_path = os.path.abspath(g.__file__) -autosummary_re_line_old = r"autosummary_re = re.compile(r'^(\s*)\.\.\s+autosummary::\s*')" -autosummary_re_line_new = r"autosummary_re = re.compile(r'^(\s*)\.\.\s+(ms[a-z]*)?autosummary::\s*')" -with open(gfile_abs_path, "r+", encoding="utf8") as f: - data = f.read() - data = data.replace(autosummary_re_line_old, autosummary_re_line_new) - exec(data, g.__dict__) - -# Modify default signatures for autodoc. -autodoc_source_path = os.path.abspath(sphinx_autodoc.__file__) -autodoc_source_re = re.compile(r'stringify_signature\(.*?\)') -get_param_func_str = r"""\ -import re -import inspect as inspect_ - -def get_param_func(func): - try: - source_code = inspect_.getsource(func) - if func.__doc__: - source_code = source_code.replace(func.__doc__, '') - all_params_str = re.findall(r"def [\w_\d\-]+\(([\S\s]*?)(\):|\) ->.*?:)", source_code) - all_params = re.sub("(self|cls)(,|, )?", '', all_params_str[0][0].replace("\n", "").replace("'", "\"")) - return all_params - except: - return '' - -def get_obj(obj): - if isinstance(obj, type): - return obj.__init__ - - return obj -""" - -with open(autodoc_source_path, "r+", encoding="utf8") as f: - code_str = f.read() - code_str = autodoc_source_re.sub('"(" + get_param_func(get_obj(self.object)) + ")"', code_str, count=0) - exec(get_param_func_str, sphinx_autodoc.__dict__) - exec(code_str, sphinx_autodoc.__dict__) - -with open("../_ext/customdocumenter.txt", "r", encoding="utf8") as f: - code_str = f.read() - exec(code_str, sphinx_autodoc.__dict__) - -# Copy source files of chinese python api from mindscience repository. 
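# The source directory below is taken from the GL_PATH environment variable;
# every file or folder under docs/api_python is copied into the current doc tree.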
-from sphinx.util import logging -logger = logging.getLogger(__name__) - -copy_path = 'docs/api_python' -src_dir = os.path.join(os.getenv("GL_PATH"), copy_path) - -copy_list = [] - -present_path = os.path.dirname(__file__) - -for i in os.listdir(src_dir): - if os.path.isfile(os.path.join(src_dir,i)): - if os.path.exists('./'+i): - os.remove('./'+i) - shutil.copy(os.path.join(src_dir,i),'./'+i) - copy_list.append(os.path.join(present_path,i)) - else: - if os.path.exists('./'+i): - shutil.rmtree('./'+i) - shutil.copytree(os.path.join(src_dir,i),'./'+i) - copy_list.append(os.path.join(present_path,i)) - -# add view -import json - -if os.path.exists('../../../../tools/generate_html/version.json'): - with open('../../../../tools/generate_html/version.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) -elif os.path.exists('../../../../tools/generate_html/daily_dev.json'): - with open('../../../../tools/generate_html/daily_dev.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) -elif os.path.exists('../../../../tools/generate_html/daily.json'): - with open('../../../../tools/generate_html/daily.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) - -if os.getenv("GL_PATH").split('/')[-1]: - copy_repo = os.getenv("GL_PATH").split('/')[-1] -else: - copy_repo = os.getenv("GL_PATH").split('/')[-2] - -branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if version_inf[i]['name'] == copy_repo][0] -docs_branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if version_inf[i]['name'] == 'tutorials'][0] - -re_view = f"\n.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/{docs_branch}/" + \ - f"resource/_static/logo_source.svg\n :target: https://gitee.com/mindspore/{copy_repo}/blob/{branch}/" - -for cur, _, files in os.walk(present_path): - for i in files: - flag_copy = 0 - if i.endswith('.rst'): - for j in copy_list: - if j in cur: - flag_copy = 1 - break - if os.path.join(cur, i) in copy_list or flag_copy: - try: - with open(os.path.join(cur, i), 'r+', encoding='utf-8') as f: - content = f.read() - new_content = content - if '.. include::' in content and '.. 
automodule::' in content: - continue - if 'autosummary::' not in content and "\n=====" in content: - re_view_ = re_view + copy_path + cur.split(present_path)[-1] + '/' + i + \ - '\n :alt: 查看源文件\n\n' - new_content = re.sub('([=]{5,})\n', r'\1\n' + re_view_, content, 1) - if new_content != content: - f.seek(0) - f.truncate() - f.write(new_content) - except Exception: - print(f'打开{i}文件失败') - -import mindspore_gl - -sys.path.append(os.path.abspath('../../../../resource/sphinx_ext')) -# import anchor_mod -import nbsphinx_mod - -sys.path.append(os.path.abspath('../../../../resource/search')) -import search_code - -sys.path.append(os.path.abspath('../../../../resource/custom_directives')) -from custom_directives import IncludeCodeDirective -from myautosummary import MsPlatformAutoSummary, MsCnPlatformAutoSummary, MsCnAutoSummary - -rst_files = set([i.replace('.rst', '') for i in glob.glob('api_python/**/*.rst', recursive=True)]) - -def setup(app): - app.add_directive('msplatformautosummary', MsPlatformAutoSummary) - app.add_directive('mscnplatformautosummary', MsCnPlatformAutoSummary) - app.add_directive('mscnautosummary', MsCnAutoSummary) - app.add_directive('includecode', IncludeCodeDirective) - app.add_config_value('rst_files', set(), False) - -src_release = os.path.join(os.getenv("GL_PATH"), 'RELEASE_CN.md') -des_release = "./RELEASE.md" -with open(src_release, "r", encoding="utf-8") as f: - data = f.read() -if len(re.findall("\n## (.*?)\n",data)) > 1: - content = re.findall("(## [\s\S\n]*?)\n## ", data) -else: - content = re.findall("(## [\s\S\n]*)", data) -#result = content[0].replace('# MindSpore', '#', 1) -with open(des_release, "w", encoding="utf-8") as p: - p.write("# Release Notes"+"\n\n") - p.write(content[0]) \ No newline at end of file diff --git a/docs/graphlearning/docs/source_zh_cn/faq.md b/docs/graphlearning/docs/source_zh_cn/faq.md deleted file mode 100644 index 5d26754931e7b6880d1f2578f4c18500ca33c54c..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_zh_cn/faq.md +++ /dev/null @@ -1,37 +0,0 @@ -# FAQ - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/graphlearning/docs/source_zh_cn/faq.md) - -**Q:命令行执行GNNCell报错`OSError: could not get source code`怎么办?** - -A:MindSpore Graph Learning源到源翻译解析以点为中心的编程代码,中间调用了inspect来获取翻译module的源码,需要将模型定义代码放到Python文件中,否则会报错找不到源码。 - -
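下面给出一个最小示意(其中模型名`MyNet`与聚合逻辑均为假设,仅演示组织方式):将GNNCell子类的定义保存在独立的Python文件(如`model.py`)中再执行,而不要在交互式解释器里直接定义,这样inspect才能获取到源码。

```python
# model.py —— 模型定义需保存为 Python 文件,翻译时 inspect 才能取到源码
import mindspore
from mindspore_gl import Graph
from mindspore_gl.nn import GNNCell


class MyNet(GNNCell):
    """假设的最小模型,仅作示意"""
    def construct(self, x, g: Graph):
        g.set_vertex_attr({'x': x})
        for v in g.dst_vertex:
            v.x = g.sum([u.x for u in v.innbs])
        return [v.x for v in g.dst_vertex]
```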
**Q:执行GNNCell报错`AttributeError: None of backend from {mindspore} is identified. Backend must be imported as a global variable.`怎么办?**

A:MindSpore Graph Learning通过源到源翻译解析以点为中心的编程代码,在GNNCell定义的文件中根据全局变量获取网络执行后端,需要在GNNCell定义文件的头部import mindspore,否则会报错找不到后端。
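一个简要示意(文件名为假设):执行后端的识别依赖模块级的全局`import mindspore`,只在函数内部导入或漏掉导入都会触发该错误。

```python
# model.py 文件头部:以全局变量形式导入 mindspore,源到源翻译才能识别后端
import mindspore  # 缺少这一行(或仅在函数内导入)会报找不到后端
from mindspore_gl.nn import GNNCell
```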
**Q:调用图聚合接口'sum、avg、max、min'等时报错`TypeInferenceError: Line 6: Built-in agg func "avg" only takes expr of EDGE or SRC type. Got None.`怎么办?**

A:MindSpore Graph Learning前端提供的聚合接口均为针对图节点的操作,源到源翻译过程中会检查聚合接口的输入是否为图中的边或节点,否则会报错找不到合乎规则的输入类型。
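以avg为例的简要示意(变量名为假设):聚合接口的输入应当是图上的边或源节点表达式,例如由`v.innbs`得到的邻居特征列表,而不是普通张量。

```python
g.set_vertex_attr({'x': x})
for v in g.dst_vertex:
    # 正确:输入为源节点(邻居)表达式
    v.h = g.avg([u.x for u in v.innbs])
    # 错误:输入为普通张量,类型推断结果为 None
    # v.h = g.avg(x)
```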
**Q:调用图接口'dot'时报错`RuntimeError: The 'mul' operation does not support the type.`怎么办?**

A:MindSpore Graph Learning前端提供的dot接口为针对图节点的点乘操作,后端包含特征乘和聚合加两步。前端翻译过程不涉及编译,无法判断输入数据类型,因此输入类型必须符合后端mul算子的类型要求,否则会报错类型不支持。
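一个简要示意(dot的具体调用形式为假设,重点在类型处理):在调用dot之前,先将参与点乘的特征统一转换成后端mul算子支持的类型,例如float32。

```python
import mindspore as ms

x = ms.ops.Cast()(x, ms.float32)  # 先统一为 mul 算子支持的浮点类型
g.set_vertex_attr({'x': x})
for v in g.dst_vertex:
    # 假设的调用形式:中心节点与每个邻居的特征做点乘
    scores = [g.dot(v.x, u.x) for u in v.innbs]
```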
**Q:调用图接口'topk_nodes'、'topk_edges'时报错`TypeError: For 'tensor getitem', the types only support 'Slice', 'Ellipsis', 'None', 'Tensor', 'int', 'List', 'Tuple', 'bool', but got String.`怎么办?**

A:MindSpore Graph Learning前端提供的topk_nodes接口为针对图节点/边特征排序后取k个节点/边的操作,后端包含获取节点/边特征、排序sort和slice取k个三步。前端翻译过程不涉及编译,无法判断输入数据类型,因此输入类型必须符合sort和slice算子对排序维度sortby和取值个数k的类型要求,否则会报错类型不支持。
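一个简要示意(接口签名为假设):该报错的直接原因是把字符串传给了需要整数的参数,排序维度sortby与取值个数k都应为int。

```python
# 假设接口形如 g.topk_nodes(node_feat, k, sortby)
top_feat = g.topk_nodes(node_feat, 5, 0)  # 正确:k 与 sortby 均为 int
# top_feat = g.topk_nodes(node_feat, '5', 'feat')  # 错误:传入 String 触发 TypeError
```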
    - -**Q:construct的输入graph传入非GraphField实例或等价tensor时`TypeError: For 'Cell', the function construct need 5 positional argument, but got 2.'`怎么办?** - -A:MindSpore Graph Learning前端提供的GNNCell为写以点为中心编程GNN模型的基类,必须包含Graph类为最后一个输入参数,翻译后对应的输入为4个Tensor参数,分别为src_idx, dst_idx, n_nodes, n_edges, 如果仅传入非GraphField实例或等价的4个tensor,就会报参数输入不对的错误。 \ No newline at end of file diff --git a/docs/graphlearning/docs/source_zh_cn/full_training_of_GCN.md b/docs/graphlearning/docs/source_zh_cn/full_training_of_GCN.md deleted file mode 100644 index 89a01626aeaed3f398c40ee235db1bbeb288b294..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_zh_cn/full_training_of_GCN.md +++ /dev/null @@ -1,263 +0,0 @@ -# 整图训练网络 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/graphlearning/docs/source_zh_cn/full_training_of_GCN.md) -   - -## 概述 - -在本例中将展示如何在Cora数据集上进行图卷积网络的半监督分类。 - -图卷积网络(GCN)于2016年提出,旨在对图结构数据进行半监督学习。提出了一种基于卷积神经网络的有效变体的可扩展方法,该方法直接在图上操作。该模型在图边的数量上线性缩放,并学习编码本地图结构和节点特征的隐藏层表示。 - -Cora数据集包括2708份科学出版物,分为七类之一。引文网络由10556个链接组成。数据集中的每个发布都由0/1值单词向量描述,指示词典中相应单词的不存在/存在。该词典由1433个独特的单词组成。 - -将Cora的文献的分类作为标签,文献的单词向量作为GCN的节点特征,文献的引用作为边,构图后利用GCN进行训练,判断文献应该属于哪个类。 - -> 下载完整的样例[GCN](https://gitee.com/mindspore/graphlearning/tree/master/examples/)代码。 - -## GCN原理 - -论文链接:[Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907) - -## 定义网络结构 - -mindspore_gl.nn实现了GCNConv,可以直接导入使用,用户也可以自己定义卷积层。使用GCNConv实现一个两层的GCN网络代码如下: - -```python -import mindspore -from mindspore_gl.nn import GNNCell -from mindspore_gl import Graph -from mindspore_gl.nn import GCNConv - -class GCNNet(GNNCell): - def __init__(self, - data_feat_size: int, - hidden_dim_size: int, - n_classes: int, - dropout: float, - activation): - super().__init__() - self.layer0 = GCNConv(data_feat_size, hidden_dim_size, activation(), dropout) - self.layer1 = GCNConv(hidden_dim_size, n_classes, None, dropout) - - def construct(self, x, in_deg, out_deg, g: Graph): - x = self.layer0(x, in_deg, out_deg, g) - x = self.layer1(x, in_deg, out_deg, g) - return x -``` - -其中定义的GCNNet继承于GNNCell。GNNCell中construct函数的最后一项输入必须为Graph或者BatchedGraph,也就是MindSpore Graph Learning内置支持的图结构类。此外必须在文件的头部导入 mindspore便于代码翻译时识别执行后端。 - -GCNConv的参数data_feat_size为输入节点特征维度,hidden_dim_size为隐层特征维度,n_classes为输出分类的维度,in_deg和out_deg分别为图数据中节点的入度和出度。 - -具体GCN的实现可以参考mindspore_gl.nn.GCNConv的接口代码:。 - -## 定义loss函数 - -接下来定义LossNet,包含了网络主干net和loss function两部分,这里利用mindspore.nn.SoftmaxCrossEntropyWithLogits实现交叉熵loss。 - -```python -import mindspore as ms -import mindspore.nn as nn -import mindspore.ops as ops -from mindspore_gl.nn import GNNCell - - -class LossNet(GNNCell): - """ LossNet definition """ - - def __init__(self, net): - super().__init__() - self.net = net - self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='none') - - def construct(self, x, in_deg, out_deg, train_mask, target, g: Graph): - predict = self.net(x, in_deg, out_deg, g) - target = ops.Squeeze()(target) - loss = self.loss_fn(predict, target) - loss = loss * train_mask - return ms.ops.ReduceSum()(loss) / ms.ops.ReduceSum()(train_mask) -``` - -其中net可以通过构建一个LossNet的实例传入GCNNet。predict为net输出的预测值,target为预测真实值,由于是整图训练,通过train_mask从整图中获取一部分作为训练数据,仅这部分节点参与loss计算。 - -LossNet和GCNNet一样继承自GNNCell。 - -## 构造数据集 - -在mindspore_gl.dataset目录下提供了一些dataset类定义的参考。可以直接读入一些研究常用数据集,这里用cora数据集为例,输入数据路径data_path即可构建数据类。 - -```python -from mindspore_gl.dataset import CoraV2 - 
-ds = CoraV2(args.data_path) -``` - -其中[Cora](https://data.dgl.ai/dataset/cora_v2.zip)数据下载后,解压路径即为args.data_path。 - -## 网络训练和验证 - -### 设置环境变量 - -环境变量的设置同MindSpore其他网络训练,特别的是设置enable_graph_kernel=True可以启动图算编译优化,加速图模型的训练。 - -```python -import mindspore as ms -import os - -if train_args.fuse: - ms.set_context(device_target="GPU", save_graphs=2, save_graphs_path="./computational_graph/", - mode=ms.GRAPH_MODE, enable_graph_kernel=True) - graph_kernel_flags="--enable_expand_ops=Gather --enable_cluster_ops=TensorScatterAdd," - "UnsortedSegmentSum, GatherNd --enable_recompute_fusion=false " - "--enable_parallel_fusion=true " - os.environ['MS_DEV_GRAPH_KERNEL_FLAGS'] = graph_kernel_flags -else: - ms.set_context(device_target="GPU", mode=ms.PYNATIVE_MODE) -``` - -### 定义训练网络 - -图神经网络的训练如同其他监督学习模型,除了实例化模型主体GCNNet以及LossNet,还需定义优化器,这里用的mindspore.nn.Adam。 - -将LossNet实例和optimizer传入mindspore.nn.TrainOneStepCell构建一个单步训练网络train_net。 - -```python -import mindspore.nn as nn - -net = GCNNet(data_feat_size=feature_size, - hidden_dim_size=train_args.num_hidden, - n_classes=ds.n_classes, - dropout=train_args.dropout, - activation=ms.nn.ELU) - optimizer = nn.optim.Adam(net.trainable_params(), learning_rate=train_args.lr, weight_decay=train_args.weight_decay) - loss = LossNet(net) - train_net = nn.TrainOneStepCell(loss, optimizer) -``` - -### 网络训练及验证 - -由于是整图训练,一步训练就覆盖了整个数据集,每个epoch即为一步训练,同样验证节点通过test_mask获取,验证准确率的计算只需取出整图中的验证节点与真实值label进行比较计算:预测值与真实值一致即为正确,正确节点数count与验证节点总数的比值即为验证准确率。 - -```python -for e in range(train_args.epochs): - beg = time.time() - train_net.set_train() - train_loss = train_net() - end = time.time() - dur = end - beg - if e >= warm_up: - total = total + dur - - test_mask = ds.test_mask - if test_mask is not None: - net.set_train(False) - out = net(ds.x, ds.in_deg, ds.out_deg, ds.g.src_idx, ds.g.dst_idx, ds.g.n_nodes, ds.g.n_edges).asnumpy() - labels = ds.y.asnumpy() - predict = np.argmax(out[test_mask], axis=1) - label = labels[test_mask] - count = np.equal(predict, label) - print('Epoch time:{} ms Train loss {} Test acc:{}'.format(dur * 1000, train_loss, - np.sum(count) / label.shape[0])) -``` - -## 执行并查看结果 - -### 运行过程 - -运行程序后,首先可以看到所有被翻译后的函数的对比图(默认GNNCell中的construct函数会被翻译)。此处展示出GCNConv的翻译对比图,左边为GCNConv的源代码;右边为翻译后的代码。 - -可以看到graph的API被mindspore_gl替换后的代码实现。比如调用的graph aggregate函数g.sum将被替换为Gather-Scatter的实现。可以看出以节点为中心的编程范式大大降低了图模型实现的代码量。 - -```bash --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- -| def construct(self, x, in_deg, out_deg, g: Graph): 1 || 1 def construct( | -| || self, | -| || x, | -| || in_deg, | -| || out_deg, | -| || src_idx, | -| || dst_idx, | -| || n_nodes, | -| || n_edges, | -| || ver_subgraph_idx=None, | -| || edge_subgraph_idx=None, | -| || graph_mask=None | -| || ): | -| || 2 SCATTER_ADD = ms.ops.TensorScatterAdd() | -| || 3 SCATTER_MAX = ms.ops.TensorScatterMax() | -| || 4 SCATTER_MIN = ms.ops.TensorScatterMin() | -| || 5 GATHER = ms.ops.Gather() | -| || 6 ZEROS = ms.ops.Zeros() | -| || 7 SHAPE = ms.ops.Shape() | -| || 8 RESHAPE = ms.ops.Reshape() | -| || 9 scatter_src_idx = RESHAPE(src_idx, (SHAPE(src_idx)[0], 1)) | -| || 10 scatter_dst_idx = RESHAPE(dst_idx, (SHAPE(dst_idx)[0], 1)) | -| out_deg = ms.ops.clip_by_value(out_deg, self.min_clip, self.max_clip) 2 || 11 out_deg = ms.ops.clip_by_value(out_deg, self.min_clip, self.max_clip) | -| out_deg = ms.ops.Reshape()( 3 || 12 out_deg = ms.ops.Reshape()( | -| 
ms.ops.Pow()(out_deg, -0.5), || ms.ops.Pow()(out_deg, -0.5), | -| ms.ops.Shape()(out_deg) + (1,) || ms.ops.Shape()(out_deg) + (1,) | -| ) || ) | -| x = self.drop_out(x) 4 || 13 x = self.drop_out(x) | -| x = ms.ops.Squeeze()(x) 5 || 14 x = ms.ops.Squeeze()(x) | -| x = x * out_deg 6 || 15 x = x * out_deg | -| x = self.fc(x) 7 || 16 x = self.fc(x) | -| g.set_vertex_attr({'x': x}) 8 || 17 VERTEX_SHAPE = SHAPE(x)[0] | -| || 18 x, = [x] | -| for v in g.dst_vertex: 9 || | -| v.x = g.sum([u.x for u in v.innbs]) 10 || 19 SCATTER_INPUT_SNAPSHOT2 = GATHER(x, src_idx, 0) | -| || 20 x = SCATTER_ADD( | -| || ZEROS((VERTEX_SHAPE,) + SHAPE(SCATTER_INPUT_SNAPSHOT2)[1:], ms.float32), | -| || scatter_dst_idx, | -| || SCATTER_INPUT_SNAPSHOT2 | -| || ) | -| in_deg = ms.ops.clip_by_value(in_deg, self.min_clip, self.max_clip) 11 || 21 in_deg = ms.ops.clip_by_value(in_deg, self.min_clip, self.max_clip) | -| in_deg = ms.ops.Reshape()(ms.ops.Pow()(in_deg, -0.5), ms.ops.Shape()(in_deg) + (1,)) 12 || 22 in_deg = ms.ops.Reshape()(ms.ops.Pow()(in_deg, -0.5), ms.ops.Shape()(in_deg) + (1,)) | -| x = [v.x for v in g.dst_vertex] * in_deg 13 || 23 x = x * in_deg | -| x = x + self.bias 14 || 24 x = x + self.bias | -| return x 15 || 25 return x | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- - -``` - -### 开启/关闭翻译界面 - -代码执行时默认显示翻译对比图,如果需要关闭对比视图,可以进行如下操作: - -```python -from mindspore_gl.nn import GNNCell -GNNCell.disable_display() -``` - -如果需要修改对比视图展示宽度时(默认为200),可以进行如下操作: - -```python -from mindspore_gl.nn import GNNCell -GNNCell.enable_display(screen_width=350) -``` - -### 执行结果 - -执行脚本[vc_gcn_datanet.py](https://gitee.com/mindspore/graphlearning/blob/master/examples/vc_gcn_datanet.py)启动训练。 - -```bash -cd examples -python vc_gcn_datanet.py --data-path={path} --fuse=True -``` - -其中`{path}`为数据集存放路径。 - -可以看到训练的结果(截取最后五个epoch)如下: - -```bash -... 
-Epoch 196, Train loss 0.30630863, Test acc 0.822 -Epoch 197, Train loss 0.30918056, Test acc 0.819 -Epoch 198, Train loss 0.3299482, Test acc 0.819 -Epoch 199, Train loss 0.2945389, Test acc 0.821 -Epoch 200, Train loss 0.27628058, Test acc 0.819 -``` - -在cora上验证精度:0.82 (论文:0.815) - -以上就是整图训练的使用指南。更多样例可参考[examples directory](https://gitee.com/mindspore/graphlearning/tree/master/examples/)。 diff --git a/docs/graphlearning/docs/source_zh_cn/images/gat_example.PNG b/docs/graphlearning/docs/source_zh_cn/images/gat_example.PNG deleted file mode 100644 index 9477fb68c19237a73ce4f21052e4beef2e357970..0000000000000000000000000000000000000000 Binary files a/docs/graphlearning/docs/source_zh_cn/images/gat_example.PNG and /dev/null differ diff --git a/docs/graphlearning/docs/source_zh_cn/images/graphlearning_cn.png b/docs/graphlearning/docs/source_zh_cn/images/graphlearning_cn.png deleted file mode 100644 index 2bdec9d8671dd81ea14e1004499ade2ad9f69f52..0000000000000000000000000000000000000000 Binary files a/docs/graphlearning/docs/source_zh_cn/images/graphlearning_cn.png and /dev/null differ diff --git a/docs/graphlearning/docs/source_zh_cn/index.rst b/docs/graphlearning/docs/source_zh_cn/index.rst deleted file mode 100644 index e71103b0a0d1c0945aca4cebc2061851ed61988d..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_zh_cn/index.rst +++ /dev/null @@ -1,93 +0,0 @@ -MindSpore Graph Learning文档 -============================== - -MindSpore Graph Learning是一款图学习套件,支持以点为中心编程实现图神经网络和高效的训练推理。 - -得益于MindSpore的图算融合能力,MindSpore Graph Learning能够针对图模型特有的执行模式进行编译优化,帮助开发者缩短训练时间。MindSpore Graph Learning还创新性地提出以点为中心编程范式,提供更原生的图神经网络表达方式,并内置覆盖大部分应用场景的模型,使开发者能够轻松搭建图神经网络。 - -.. image:: ./images/graphlearning_cn.png - :width: 700px - -代码仓地址: - -设计特点 ---------- - -1. 以点为中心的编程范式 - - 图神经网络模型需要在给定的图结构上做信息的传递和聚合,整图计算无法直观表达这些操作。MindSpore Graph Learning提供以点为中心的编程范式,更符合图学习算法逻辑和Python语言风格,可以直接进行公式到代码的翻译,减少算法设计和实现间的差距。 - -2. 高效加速图模型 - - MindSpore Graph Learning结合MindSpore的图算融合和自动算子编译技术(AKG)特性,自动识别图神经网络任务中特有的执行pattern进行融合和kernel level优化,能够覆盖现有框架中已有的算子和新组合算子的融合优化,获得相比现有流行框架平均3到4倍的性能提升。 - -训练流程 ---------- - -MindSpore Graph Learning为用户提供了丰富的数据读入、图操作和网络结构模块接口,用户使用MindSpore Graph Learning实现训练图神经网络只需要以下几步: - -1. 定义网络结构,用户可以直接调用mindspore_gl.nn提供的接口,也可以参考这里的实现自定义图学习模块。 - -2. 定义loss函数。 - -3. 构造数据集,mindspore_gl.dataset提供了一些研究用的公开数据集的读入和构造。 - -4. 网络训练和验证。 - -特性介绍 ---------- - -MindSpore Graph Learning提供了以点为中心的GNN网络编程范式,内置将以点为中心的计算表达翻译为图数据的计算操作的代码解析函数,为了方便用户调试解析过程将打印出用户输入代码与计算代码的翻译对比图。 - -如下图基于以点为中心编程模型实现经典GAT网络。用户定义一个函数以节点`v`作为入参,在函数内用户通过`v.innbs()`获取邻居节点列表,遍历每个邻居节点`u`,获取节点特征,计算邻居节点与中心节点的特征交互得到邻边上的权重列表,然后将邻边权重与邻居节点进行加权平均,返回更新的中心节点特征。 - -.. image:: ./images/gat_example.PNG - :width: 700px - -未来路标 ---------- - -MindSpore Graph Learning初始版本包含以点为中心的编程范式,并内置提供了典型图模型的实现,以及在小数据集上单机训练的案例和性能评测。初始版本暂时不包含大数据集上的性能评测和分布式训练,也不支持对接图数据库。MindSpore Graph Learning后续版本将包含这些内容,敬请期待。 - -使用MindSpore Graph Learning的典型场景 ---------------------------------------- - -.. toctree:: - :maxdepth: 1 - :caption: 安装部署 - - mindspore_graphlearning_install - -.. toctree:: - :maxdepth: 1 - :caption: 使用指南 - - full_training_of_GCN - batched_graph_training_GIN - spatio_temporal_graph_training_STGCN - single_host_distributed_Graphsage - -.. toctree:: - :maxdepth: 1 - :caption: API参考 - - mindspore_gl - mindspore_gl.dataloader - mindspore_gl.dataset - mindspore_gl.graph - mindspore_gl.nn - mindspore_gl.sampling - mindspore_gl.utils - -.. toctree:: - :maxdepth: 1 - :caption: 参考文档 - - faq - -.. 
toctree:: - :glob: - :maxdepth: 1 - :caption: RELEASE NOTES - - RELEASE diff --git a/docs/graphlearning/docs/source_zh_cn/mindspore_graphlearning_install.md b/docs/graphlearning/docs/source_zh_cn/mindspore_graphlearning_install.md deleted file mode 100644 index aeebc542753bdd7f3fdc7aebe33cb193011b4afd..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_zh_cn/mindspore_graphlearning_install.md +++ /dev/null @@ -1,57 +0,0 @@ -# 安装 Graph Learning - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/graphlearning/docs/source_zh_cn/mindspore_graphlearning_install.md)   - -## 安装指南 - -### 确认系统环境信息 - -- 硬件平台确认为Linux系统下的Ascend或GPU。 -- 参考[MindSpore安装指南](https://www.mindspore.cn/install),完成MindSpore的安装,要求至少2.0.0版本。 -- 其余依赖请参见[requirements.txt](https://gitee.com/mindspore/graphlearning/blob/master/requirements.txt)。 - -### 安装方式 - -可以采用pip安装或者源码编译安装两种方式。 - -#### pip安装 - -- Ascend/CPU - - ```bash - pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.0.0rc1/GraphLearning/cpu/{system_structure}/mindspore_gl-0.2-cp37-cp37m-linux_{system_structure}.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple - ``` - -- GPU - - ```bash - pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.0.0rc1/GraphLearning/gpu/x86_64/cuda-{cuda_verison}/mindspore_gl-0.2-cp37-cp37m-linux_x86_64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple - ``` - -> - 在联网状态下,安装whl包时会自动下载MindSpore Graph Learning安装包的依赖项(依赖项详情参见[requirements.txt](https://gitee.com/mindspore/graphlearning/blob/master/requirements.txt)),其余情况需自行安装。 -> - `{system_structure}`表示为Linux系统架构,可选项为`x86_64`和`arrch64`。 -> - `{cuda_verison}`表示为CUDA版本,可选项为`10.1`、`11.1`和`11.6`。 - -#### 源码安装 - -1. 从代码仓下载源码 - - ```bash - git clone https://gitee.com/mindspore/graphlearning.git - ``` - -2. 
编译安装MindSpore Graph Learning - - ```bash - cd graphlearning - bash build.sh - pip install ./output/mindspore_gl-*.whl - ``` - -### 验证是否成功安装 - -执行如下命令,如果没有报错`No module named 'mindspore_gl'`,则说明安装成功。 - -```bash -python -c 'import mindspore_gl' -``` diff --git a/docs/graphlearning/docs/source_zh_cn/single_host_distributed_Graphsage.md b/docs/graphlearning/docs/source_zh_cn/single_host_distributed_Graphsage.md deleted file mode 100644 index f6bf7068cc3778b66d261b4c8cc6a0f80e132512..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_zh_cn/single_host_distributed_Graphsage.md +++ /dev/null @@ -1,271 +0,0 @@ -# 单机多卡分布式训练 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/graphlearning/docs/source_zh_cn/single_host_distributed_Graphsage.md) -   - -## 概述 - -在本例中将展示如何利用Graphsage在大尺寸图上进行单机多卡训练。 - -GraphSAGE是一个通用的归纳框架,它利用节点特征信息(例如,文本属性)为以前看不见的数据有效地生成节点嵌入。GraphSAGE不是为每个节点训练单个嵌入,而是学习一个函数,该函数通过从节点的本地邻居中采样和聚合特征来生成嵌入。 - -在Reddit数据集中,作者对50个大型社区进行了抽样调查,并构建了一个帖子到帖子的图,如果同一用户对这两个帖子都发表了评论,则连接帖子。每个帖子的标签为所属的社区。该数据集总共包含232965个帖子,平均度为492。 - -由于Reddit数据集较大,为了减少GraphSAGE训练时间,本例中在单机上执行分布式模型训练,以加快模型训练。 - -> 下载完整的样例[GraphSAGE](https://gitee.com/mindspore/graphlearning/tree/master/model_zoo/graphsage)代码。 - -## GraphSAGE原理 - -论文链接:[Inductive representation learning on large graphs](https://proceedings.neurips.cc/paper/2017/file/5dd9db5e033da9c6fb5ba83c7a7ebea9-Paper.pdf) - -## 设置运行脚本 - -在不同设备上,分布式训练的方式也不相同。 - -在GPU硬件平台上,MindSpore分布式并行训练中的通信使用的是英伟达集合通信库NVIDIA Collective Communication Library(简称为NCCL)。 - -```bash -# GPU -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -export CUDA_NUM=8 -rm -rf device -mkdir device -cp -r src ./device -cp distributed_trainval_reddit.py ./device -cd ./device -echo "start training" -mpirun --allow-run-as-root -n ${CUDA_NUM} python3 ./distributed_trainval_reddit.py --data-path ${DATA_PATH} --epochs 5 > train.log 2>&1 & -``` - -MindSpore分布式并行训练的通信使用了华为集合通信库Huawei Collective Communication Library(以下简称HCCL),可以在Atlas 200/300/500推理产品配套的软件包中找到。同时mindspore.communication.management中封装了HCCL提供的集合通信接口,方便用户配置分布式信息。 - -```bash -# Ascend -RANK_TABLE_FILE=$3 - export RANK_TABLE_FILE=${RANK_TABLE_FILE} - for((i=0;i<8;i++)); - do - export RANK_ID=$[i+RANK_START] - export DEVICE_ID=$[i+RANK_START] - echo ${DEVICE_ID} - rm -rf ${execute_path}/device_$RANK_ID - mkdir ${execute_path}/device_$RANK_ID - cd ${execute_path}/device_$RANK_ID || exit - echo "start training" - python3 ${self_path}/distributed_trainval_reddit.py --data-path ${DATA_PATH} --epochs 2 > train$RANK_ID.log 2>&1 & - done -``` - -## 定义网络结构 - -mindspore_gl.nn提供了SAGEConv的API可以直接调用。使用SAGEConv实现一个两层的GraphSAGE网络代码如下: - -```python -class SAGENet(Cell): - """graphsage net""" - def __init__(self, in_feat_size, hidden_feat_size, appr_feat_size, out_feat_size): - super().__init__() - self.num_layers = 2 - self.layer1 = SAGEConv(in_feat_size, hidden_feat_size, aggregator_type='mean') - self.layer2 = SAGEConv(hidden_feat_size, appr_feat_size, aggregator_type='mean') - self.dense_out = ms.nn.Dense(appr_feat_size, out_feat_size, has_bias=False, - weight_init=XavierUniform(math.sqrt(2))) - self.activation = ms.nn.ReLU() - self.dropout = ms.nn.Dropout(p=0.5) - - def construct(self, node_feat, edges, n_nodes, n_edges): - """graphsage net forward""" - node_feat = self.layer1(node_feat, None, edges[0], edges[1], n_nodes, n_edges) - node_feat = self.activation(node_feat) - node_feat = self.dropout(node_feat) - ret = 
self.layer2(node_feat, None, edges[0], edges[1], n_nodes, n_edges) - ret = self.dense_out(ret) - return ret -``` - -SAGEConv执行的更多细节可以看mindspore_gl.nn.SAGEConv的[API](https://gitee.com/mindspore/graphlearning/blob/master/mindspore_gl/nn/conv/sageconv.py)代码。 - -## 定义loss函数 - -由于本次任务为分类任务,可以采用交叉熵来作为损失函数,实现方法与[GCN](https://www.mindspore.cn/graphlearning/docs/zh-CN/master/full_training_of_GCN.html#%E5%AE%9A%E4%B9%89loss%E5%87%BD%E6%95%B0)类似。 - -## 构造数据集 - -下面以[Reddit](https://data.dgl.ai/dataset/reddit.zip)数据集为例。输入数据路径,构造数据类。 - -get_group_size用于获取分布式训练的进程总数,get_rank用于获取当前进程的ID。数据加载器的构建方法可以参考[GIN](https://www.mindspore.cn/graphlearning/docs/zh-CN/master/batched_graph_training_GIN.html#%E6%9E%84%E9%80%A0%E6%95%B0%E6%8D%AE%E9%9B%86)。 - -与GIN不同的时,在本例中采样器调用的是mindspore_gl.dataloader.DistributeRandomBatchSampler。DistributeRandomBatchSampler可以根据进程ID拆分数据集索引,确保每个进程获取的数据集批次的不同部分。 - -```python -from mindspore_gl.dataset import Reddit -from mindspore.communication import get_rank, get_group_size - -rank_id = get_rank() -world_size = get_group_size() -graph_dataset = Reddit(args.data_path) -train_sampler = DistributeRandomBatchSampler(rank_id, world_size, data_source=graph_dataset.train_nodes, - batch_size=args.batch_size) -test_sampler = RandomBatchSampler(data_source=graph_dataset.test_nodes, batch_size=args.batch_size) -train_dataset = GraphSAGEDataset(graph_dataset, [25, 10], args.batch_size, len(list(train_sampler)), single_size) -test_dataset = GraphSAGEDataset(graph_dataset, [25, 10], args.batch_size, len(list(test_sampler)), single_size) -train_dataloader = ds.GeneratorDataset(train_dataset, ['seeds_idx', 'label', 'nid_feat', 'edges'], - sampler=train_sampler, python_multiprocessing=True) -test_dataloader = ds.GeneratorDataset(test_dataset, ['seeds_idx', 'label', 'nid_feat', 'edges'], - sampler=test_sampler, python_multiprocessing=True) -``` - -mindspore_gl.sampling.sage_sampler_on_homo提供了k-hop的采样方法。在`self.neighbor_nums`为list的形式,设定了每次从中心节点往外的采样点个数。 -由于每个点的度数不一样,经过k-hop采样后的数组的尺寸也不一样。通过接口mindspore_gl.graph.PadArray2d将采样得到的结果离散化成5个固定的值。 - -```python -from mindspore_gl.dataloader.dataset import Dataset -from mindspore_gl.sampling.neighbor import sage_sampler_on_homo - -class GraphSAGEDataset(Dataset): - """Do sampling from neighbour nodes""" - def __init__(self, graph_dataset, neighbor_nums, batch_size, length, single_size=False): - self.graph_dataset = graph_dataset - self.graph = graph_dataset[0] - self.neighbor_nums = neighbor_nums - self.x = graph_dataset.node_feat - self.y = graph_dataset.node_label - self.batch_size = batch_size - self.max_sampled_nodes_num = neighbor_nums[0] * neighbor_nums[1] * batch_size - self.single_size = single_size - self.length = length - - def __getitem__(self, batch_nodes): - batch_nodes = np.array(batch_nodes, np.int32) - res = sage_sampler_on_homo(self.graph, batch_nodes, self.neighbor_nums) - label = array_kernel.int_1d_array_slicing(self.y, batch_nodes) - layered_edges_0 = res['layered_edges_0'] - layered_edges_1 = res['layered_edges_1'] - sample_edges = np.concatenate((layered_edges_0, layered_edges_1), axis=1) - sample_edges = sample_edges[[1, 0], :] - num_sample_edges = sample_edges.shape[1] - num_sample_nodes = len(res['all_nodes']) - max_sampled_nodes_num = self.max_sampled_nodes_num - if self.single_size is False: - if num_sample_nodes < floor(0.2*max_sampled_nodes_num): - pad_node_num = floor(0.2*max_sampled_nodes_num) - elif num_sample_nodes < floor(0.4*max_sampled_nodes_num): - pad_node_num = floor(0.4 * max_sampled_nodes_num) - elif num_sample_nodes < 
floor(0.6*max_sampled_nodes_num): - pad_node_num = floor(0.6 * max_sampled_nodes_num) - elif num_sample_nodes < floor(0.8*max_sampled_nodes_num): - pad_node_num = floor(0.8 * max_sampled_nodes_num) - else: - pad_node_num = max_sampled_nodes_num - - if num_sample_edges < floor(0.2*max_sampled_nodes_num): - pad_edge_num = floor(0.2*max_sampled_nodes_num) - elif num_sample_edges < floor(0.4*max_sampled_nodes_num): - pad_edge_num = floor(0.4 * max_sampled_nodes_num) - elif num_sample_edges < floor(0.6*max_sampled_nodes_num): - pad_edge_num = floor(0.6 * max_sampled_nodes_num) - elif num_sample_edges < floor(0.8*max_sampled_nodes_num): - pad_edge_num = floor(0.8 * max_sampled_nodes_num) - else: - pad_edge_num = max_sampled_nodes_num - - else: - pad_node_num = max_sampled_nodes_num - pad_edge_num = max_sampled_nodes_num - - layered_edges_pad_op = PadArray2d(mode=PadMode.CONST, size=[2, pad_edge_num], - dtype=np.int32, direction=PadDirection.ROW, - fill_value=pad_node_num - 1, - ) - nid_feat_pad_op = PadArray2d(mode=PadMode.CONST, - size=[pad_node_num, self.graph_dataset.node_feat_size], - dtype=self.graph_dataset.node_feat.dtype, - direction=PadDirection.COL, - fill_value=0, - reset_with_fill_value=False, - use_shared_numpy=True - ) - sample_edges = sample_edges[:, :pad_edge_num] - pad_sample_edges = layered_edges_pad_op(sample_edges) - feat = nid_feat_pad_op.lazy([num_sample_nodes, self.graph_dataset.node_feat_size]) - array_kernel.float_2d_gather_with_dst(feat, self.graph_dataset.node_feat, res['all_nodes']) - return res['seeds_idx'], label, feat, pad_sample_edges -``` - -## 网络训练和验证 - -### 设置环境变量 - -分布式训练时,采用数据并行方式导入数据。在每个训练步骤结束时,各个进程会统一模型参数,在Ascend上,必须确保每个进程中的数据shape相同。 - -```python -device_target = str(os.getenv('DEVICE_TARGET')) -if device_target == 'Ascend': - device_id = int(os.getenv('DEVICE_ID')) - ms.set_context(device_id=device_id) - single_size = True - init() -else: - init("nccl") - single_size = False -``` - -图算编译优化设置可以参考[GCN](https://www.mindspore.cn/graphlearning/docs/zh-CN/master/full_training_of_GCN.html#%E8%AE%BE%E7%BD%AE%E7%8E%AF%E5%A2%83%E5%8F%98%E9%87%8F)。 - -### 定义训练网络 - -实例化模型主体以及LossNet和优化器。 -实现方法与GCN类似,可以参考[GCN](https://www.mindspore.cn/graphlearning/docs/zh-CN/master/full_training_of_GCN.html#%E5%AE%9A%E4%B9%89%E8%AE%AD%E7%BB%83%E7%BD%91%E7%BB%9C)。 - -### 网络训练及验证 - -训练与验证方法可以参考[GCN](https://www.mindspore.cn/graphlearning/docs/zh-CN/master/full_training_of_GCN.html#%E7%BD%91%E7%BB%9C%E8%AE%AD%E7%BB%83%E5%8F%8A%E9%AA%8C%E8%AF%81)。 - -## 执行并查看结果 - -### 运行过程 - -运行程序后,翻译代码并开始训练。 - -### 执行结果 - -执行脚本[distributed_run.sh](https://gitee.com/mindspore/graphlearning/blob/master/model_zoo/graphsage/distributed_run.sh)启动训练。 - -- GPU - - ```bash - cd model_zoo/graphsage - bash distributed_run.sh GPU DATA_PATH - ``` - - `{DATA_PATH}`为数据集存放路径。 - -- Ascend - - ```bash - cd model_zoo/graphsage - bash bash distributed_run.sh Ascend DATA_PATH RANK_START RANK_SIZE RANK_TABLE_FILE - ``` - - `{DATA_PATH}`为数据集存放路径。`{ANK_START}`为使用的Ascend卡的第一个ID。`{RANK_SIZE}`为使用的卡的张数。`{RANK_TABLE_FILE}`为'rank_table_*pcs.json'文件的根路径. - -可以看到训练的结果如下: - -```bash -... 
-Iteration/Epoch: 30:4 train loss: 0.41629112 -Iteration/Epoch: 30:4 train loss: 0.5337528 -Iteration/Epoch: 30:4 train loss: 0.42849028 -Iteration/Epoch: 30:4 train loss: 0.5358513 -rank_id:3 Epoch/Time: 4:76.17579555511475 -rank_id:1 Epoch/Time: 4:37.79207944869995 -rank_id:2 Epoch/Time: 4:76.04292225837708 -rank_id:0 Epoch/Time: 4:75.64319372177124 -rank_id:2 test accuracy : 0.9276439525462963 -rank_id:0 test accuracy : 0.9305013020833334 -rank_id:3 test accuracy : 0.9290907118055556 -rank_id:1 test accuracy : 0.9279513888888888 -``` - -在Reddit数据上的验证精度为0.92。 diff --git a/docs/graphlearning/docs/source_zh_cn/spatio_temporal_graph_training_STGCN.md b/docs/graphlearning/docs/source_zh_cn/spatio_temporal_graph_training_STGCN.md deleted file mode 100644 index d0d819ab5f1a527b67be6f0f83a55590e5db8b2a..0000000000000000000000000000000000000000 --- a/docs/graphlearning/docs/source_zh_cn/spatio_temporal_graph_training_STGCN.md +++ /dev/null @@ -1,164 +0,0 @@ -# 时空图训练网络 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/graphlearning/docs/source_zh_cn/spatio_temporal_graph_training_STGCN.md) -   - -## 概述 - -在本例中将展示如何利用时空图卷积网络进行交通预测。 - -时空图卷积网络(STGCN)可以解决交通域的时间序列预测问题。实验表明,STGCN通过建模多尺度交通网络,有效地捕获了综合的时空相关性。 - -METR-LA是一个大规模数据集,从洛杉矶乡村公路网的1500个交通环路探测器收集。此数据集包括速度、道路容量和占用数据,覆盖约3,420英里。将路网构建成图,输入到STGCN网络中,根据历史数据来预测下个时间段的路网信息。 - -一般图的节点特征形状为`(节点数量, 特征维度)`,时空图中输入的特征形状通常至少为三维`(节点数量, 特征维度, 时间步)`,邻居节点的特征融合处理会更加复杂。并且由于时间维度上进行卷积,`时间步`也会发生变化,计算loss时,需要提前计算好输出时间长度。 - -> 下载完整的样例[STGCN](https://gitee.com/mindspore/graphlearning/tree/master/model_zoo/stgcn)代码。 - -## STGCN原理 - -论文链接: [A deep learning framework for traffic forecasting](https://arxiv.org/pdf/1709.04875.pdf) - -## 图拉普拉斯归一化 - -将图的自环删除,对图进行归一化,得到新的边索引与边权重。 -mindspore_gl.graph提供norm的API可以被用于拉普拉斯归一化。边缘索引和边缘权重归一化的代码如下所示: - -```python -mask = edge_index[0] != edge_index[1] -edge_index = edge_index[:, mask] -edge_attr = edge_attr[mask] - -edge_index = ms.Tensor(edge_index, ms.int32) -edge_attr = ms.Tensor(edge_attr, ms.float32) -edge_index, edge_weight = norm(edge_index, node_num, edge_attr, args.normalization) -``` - -关于拉普拉斯归一化的更多细节,可以看mindspore_gl.graph.norm的[API](https://gitee.com/mindspore/graphlearning/blob/master/mindspore_gl/graph/norm.py). 
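为了更直观地理解归一化的含义,下面用纯numpy给出一个简要示意(假设`args.normalization`取对称归一化;数值与变量均为演示,并非mindspore_gl的实现):归一化后每条边`(i, j)`的权重为`deg[i]**(-0.5) * w[i, j] * deg[j]**(-0.5)`。

```python
import numpy as np

# 对称归一化的 numpy 示意:w_hat_ij = deg_i^-0.5 * w_ij * deg_j^-0.5
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])    # 自环已删除
edge_attr = np.array([1., 1., 2., 2.])

deg = np.zeros(3)
np.add.at(deg, edge_index[0], edge_attr)  # 按边权累加得到加权度
d_inv_sqrt = deg ** -0.5

norm_weight = d_inv_sqrt[edge_index[0]] * edge_attr * d_inv_sqrt[edge_index[1]]
print(norm_weight)  # 归一化后的边权,对应 norm 返回的 edge_weight
```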
- -## 定义网络结构 - -mindspore_gl.nn提供了STConv的API可以直接调用。与一般的图卷积层不同,STConv的输入特征为四维,即`(批次内图数量, 时间步, 节点数量, 特征维度)`。输出特征的`时间步`需要根据1D卷积核尺寸、卷积次数来计算。 - -使用STConv实现一个两层的STGCN网络代码如下: - -```python -class STGcnNet(GNNCell): - """ STGCN Net """ - def __init__(self, - num_nodes: int, - in_channels: int, - hidden_channels_1st: int, - out_channels_1st: int, - hidden_channels_2nd: int, - out_channels_2nd: int, - out_channels: int, - kernel_size: int, - k: int, - bias: bool = True): - super().__init__() - self.layer0 = STConv(num_nodes, in_channels, - hidden_channels_1st, - out_channels_1st, - kernel_size, - k, bias) - self.layer1 = STConv(num_nodes, out_channels_1st, - hidden_channels_2nd, - out_channels_2nd, - kernel_size, - k, bias) - self.relu = ms.nn.ReLU() - self.fc = ms.nn.Dense(out_channels_2nd, out_channels) - - def construct(self, x, edge_weight, g: Graph): - x = self.layer0(x, edge_weight, g) - x = self.layer1(x, edge_weight, g) - x = self.relu(x) - x = self.fc(x) - return x -``` - -STConv执行的更多细节可以看mindspore_gl.nn.temporal.STConv的[API](https://gitee.com/mindspore/graphlearning/blob/master/mindspore_gl/nn/temporal/stconv.py)代码。 - -## 定义loss函数 - -由于本次任务为回归任务,可以采用最小均方差来作为损失函数。这里调用mindspore.nn.MSELoss实现最小均方差loss。 - -```python -class LossNet(GNNCell): - """ LossNet definition """ - def __init__(self, net): - super().__init__() - self.net = net - self.loss_fn = nn.loss.MSELoss() - - def construct(self, feat, edges, target, g: Graph): - """STGCN Net with loss function""" - predict = self.net(feat, edges, g) - predict = ops.Squeeze()(predict) - loss = self.loss_fn(predict, target) - return ms.ops.ReduceMean()(loss) -``` - -## 构造数据集 - -输入特征为`(批次内图数量, 时间步, 节点数量, 特征维度)`。在时序卷积上时间序列的长度将会发生变化。因此,从数据集获取特征和标签时,输入和输出时间步有相应规范,否则会出现预测值与标签值形状不一致。 - -限制规范可以参考代码注释。 - -```python -from mindspore_gl.dataset import MetrLa -metr = MetrLa(args.data_path) -# out_timestep setting -# out_timestep = in_timestep - ((kernel_size - 1) * 2 * layer_nums) -# such as: layer_nums = 2, kernel_size = 3, in_timestep = 12, -# out_timestep = 4 -features, labels = metr.get_data(args.in_timestep, args.out_timestep) -``` - -其中[MetrLa](https://graphmining.ai/temporal_datasets/METR-LA.zip)数据下载后,解压路径即为args.data_path。 - -## 网络训练和验证 - -### 设置环境变量 - -环境变量设置方法可以参考[GCN](https://www.mindspore.cn/graphlearning/docs/zh-CN/master/full_training_of_GCN.html#%E8%AE%BE%E7%BD%AE%E7%8E%AF%E5%A2%83%E5%8F%98%E9%87%8F)。 - -### 定义训练网络 - -实例化模型主体以及LossNet和优化器。 -实现方法可以参考[GCN](https://www.mindspore.cn/graphlearning/docs/zh-CN/master/full_training_of_GCN.html#%E5%AE%9A%E4%B9%89%E8%AE%AD%E7%BB%83%E7%BD%91%E7%BB%9C)。 - -### 网络训练及验证 - -实现方法可以参考[GCN](https://www.mindspore.cn/graphlearning/docs/zh-CN/master/full_training_of_GCN.html#%E7%BD%91%E7%BB%9C%E8%AE%AD%E7%BB%83%E5%8F%8A%E9%AA%8C%E8%AF%81)。 - -## 执行并查看结果 - -### 运行过程 - -运行程序后,翻译代码并开始训练。 - -### 执行结果 - -执行脚本[trainval_metr.py](https://gitee.com/mindspore/graphlearning/blob/master/model_zoo/stgcn/trainval_metr.py)启动训练。 - -```bash -cd model_zoo/stgcn -python trainval_metr.py --data-path={path} --fuse=True -``` - -其中`{path}`为数据集存放路径。 - -可以看到训练的结果如下: - -```bash -... 
-Iteration/Epoch: 600:199 loss: 0.21488506 -Iteration/Epoch: 700:199 loss: 0.21441595 -Iteration/Epoch: 800:199 loss: 0.21243602 -Time 13.162885904312134 Epoch loss 0.21053028 -eval MSE: 0.2060675 -``` - -MetrLa的MSE: 0.206 diff --git a/docs/hub/docs/Makefile b/docs/hub/docs/Makefile deleted file mode 100644 index 1eff8952707bdfa503c8d60c1e9a903053170ba2..0000000000000000000000000000000000000000 --- a/docs/hub/docs/Makefile +++ /dev/null @@ -1,20 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line, and also -# from the environment for the first two. -SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build -SOURCEDIR = source_zh_cn -BUILDDIR = build_zh_cn - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/hub/docs/_ext/overwriteautosummary_generate.txt b/docs/hub/docs/_ext/overwriteautosummary_generate.txt deleted file mode 100644 index 4b0a1b1dd2b410ecab971b13da9993c90d65ef0d..0000000000000000000000000000000000000000 --- a/docs/hub/docs/_ext/overwriteautosummary_generate.txt +++ /dev/null @@ -1,707 +0,0 @@ -""" - sphinx.ext.autosummary.generate - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - Usable as a library or script to generate automatic RST source files for - items referred to in autosummary:: directives. - - Each generated RST file contains a single auto*:: directive which - extracts the docstring of the referred item. - - Example Makefile rule:: - - generate: - sphinx-autogen -o source/generated source/*.rst - - :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. - :license: BSD, see LICENSE for details. 
-""" - -import argparse -import importlib -import inspect -import locale -import os -import pkgutil -import pydoc -import re -import sys -import warnings -from gettext import NullTranslations -from os import path -from typing import Any, Dict, List, NamedTuple, Sequence, Set, Tuple, Type, Union - -from jinja2 import TemplateNotFound -from jinja2.sandbox import SandboxedEnvironment - -import sphinx.locale -from sphinx import __display_version__, package_dir -from sphinx.application import Sphinx -from sphinx.builders import Builder -from sphinx.config import Config -from sphinx.deprecation import RemovedInSphinx50Warning -from sphinx.ext.autodoc import Documenter -from sphinx.ext.autodoc.importer import import_module -from sphinx.ext.autosummary import (ImportExceptionGroup, get_documenter, import_by_name, - import_ivar_by_name) -from sphinx.locale import __ -from sphinx.pycode import ModuleAnalyzer, PycodeError -from sphinx.registry import SphinxComponentRegistry -from sphinx.util import logging, rst, split_full_qualified_name, get_full_modname -from sphinx.util.inspect import getall, safe_getattr -from sphinx.util.osutil import ensuredir -from sphinx.util.template import SphinxTemplateLoader - -logger = logging.getLogger(__name__) - - -class DummyApplication: - """Dummy Application class for sphinx-autogen command.""" - - def __init__(self, translator: NullTranslations) -> None: - self.config = Config() - self.registry = SphinxComponentRegistry() - self.messagelog: List[str] = [] - self.srcdir = "/" - self.translator = translator - self.verbosity = 0 - self._warncount = 0 - self.warningiserror = False - - self.config.add('autosummary_context', {}, True, None) - self.config.add('autosummary_filename_map', {}, True, None) - self.config.add('autosummary_ignore_module_all', True, 'env', bool) - self.config.add('docs_branch', '', True, None) - self.config.add('branch', '', True, None) - self.config.add('cst_module_name', '', True, None) - self.config.add('copy_repo', '', True, None) - self.config.add('giturl', '', True, None) - self.config.add('repo_whl', '', True, None) - self.config.init_values() - - def emit_firstresult(self, *args: Any) -> None: - pass - - -class AutosummaryEntry(NamedTuple): - name: str - path: str - template: str - recursive: bool - - -def setup_documenters(app: Any) -> None: - from sphinx.ext.autodoc import (AttributeDocumenter, ClassDocumenter, DataDocumenter, - DecoratorDocumenter, ExceptionDocumenter, - FunctionDocumenter, MethodDocumenter, ModuleDocumenter, - NewTypeAttributeDocumenter, NewTypeDataDocumenter, - PropertyDocumenter) - documenters: List[Type[Documenter]] = [ - ModuleDocumenter, ClassDocumenter, ExceptionDocumenter, DataDocumenter, - FunctionDocumenter, MethodDocumenter, NewTypeAttributeDocumenter, - NewTypeDataDocumenter, AttributeDocumenter, DecoratorDocumenter, PropertyDocumenter, - ] - for documenter in documenters: - app.registry.add_documenter(documenter.objtype, documenter) - - -def _simple_info(msg: str) -> None: - warnings.warn('_simple_info() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - print(msg) - - -def _simple_warn(msg: str) -> None: - warnings.warn('_simple_warn() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - print('WARNING: ' + msg, file=sys.stderr) - - -def _underline(title: str, line: str = '=') -> str: - if '\n' in title: - raise ValueError('Can only underline single lines') - return title + '\n' + line * len(title) - - -class AutosummaryRenderer: - """A helper class for rendering.""" - - def 
__init__(self, app: Union[Builder, Sphinx], template_dir: str = None) -> None: - if isinstance(app, Builder): - warnings.warn('The first argument for AutosummaryRenderer has been ' - 'changed to Sphinx object', - RemovedInSphinx50Warning, stacklevel=2) - if template_dir: - warnings.warn('template_dir argument for AutosummaryRenderer is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - - system_templates_path = [os.path.join(package_dir, 'ext', 'autosummary', 'templates')] - loader = SphinxTemplateLoader(app.srcdir, app.config.templates_path, - system_templates_path) - - self.env = SandboxedEnvironment(loader=loader) - self.env.filters['escape'] = rst.escape - self.env.filters['e'] = rst.escape - self.env.filters['underline'] = _underline - - if isinstance(app, (Sphinx, DummyApplication)): - if app.translator: - self.env.add_extension("jinja2.ext.i18n") - self.env.install_gettext_translations(app.translator) - elif isinstance(app, Builder): - if app.app.translator: - self.env.add_extension("jinja2.ext.i18n") - self.env.install_gettext_translations(app.app.translator) - - def exists(self, template_name: str) -> bool: - """Check if template file exists.""" - warnings.warn('AutosummaryRenderer.exists() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - try: - self.env.get_template(template_name) - return True - except TemplateNotFound: - return False - - def render(self, template_name: str, context: Dict) -> str: - """Render a template file.""" - try: - template = self.env.get_template(template_name) - except TemplateNotFound: - try: - # objtype is given as template_name - template = self.env.get_template('autosummary/%s.rst' % template_name) - except TemplateNotFound: - # fallback to base.rst - template = self.env.get_template('autosummary/base.rst') - - return template.render(context) - - -# -- Generating output --------------------------------------------------------- - - -class ModuleScanner: - def __init__(self, app: Any, obj: Any) -> None: - self.app = app - self.object = obj - - def get_object_type(self, name: str, value: Any) -> str: - return get_documenter(self.app, value, self.object).objtype - - def is_skipped(self, name: str, value: Any, objtype: str) -> bool: - try: - return self.app.emit_firstresult('autodoc-skip-member', objtype, - name, value, False, {}) - except Exception as exc: - logger.warning(__('autosummary: failed to determine %r to be documented, ' - 'the following exception was raised:\n%s'), - name, exc, type='autosummary') - return False - - def scan(self, imported_members: bool) -> List[str]: - members = [] - for name in members_of(self.object, self.app.config): - try: - value = safe_getattr(self.object, name) - except AttributeError: - value = None - - objtype = self.get_object_type(name, value) - if self.is_skipped(name, value, objtype): - continue - - try: - if inspect.ismodule(value): - imported = True - elif safe_getattr(value, '__module__') != self.object.__name__: - imported = True - else: - imported = False - except AttributeError: - imported = False - - respect_module_all = not self.app.config.autosummary_ignore_module_all - if imported_members: - # list all members up - members.append(name) - elif imported is False: - # list not-imported members - members.append(name) - elif '__all__' in dir(self.object) and respect_module_all: - # list members that have __all__ set - members.append(name) - - return members - - -def members_of(obj: Any, conf: Config) -> Sequence[str]: - """Get the members of ``obj``, possibly ignoring the ``__all__`` 
module attribute - - Follows the ``conf.autosummary_ignore_module_all`` setting.""" - - if conf.autosummary_ignore_module_all: - return dir(obj) - else: - return getall(obj) or dir(obj) - - -def generate_autosummary_content(name: str, obj: Any, parent: Any, - template: AutosummaryRenderer, template_name: str, - imported_members: bool, app: Any, - recursive: bool, context: Dict, - modname: str = None, qualname: str = None) -> str: - doc = get_documenter(app, obj, parent) - - def skip_member(obj: Any, name: str, objtype: str) -> bool: - try: - return app.emit_firstresult('autodoc-skip-member', objtype, name, - obj, False, {}) - except Exception as exc: - logger.warning(__('autosummary: failed to determine %r to be documented, ' - 'the following exception was raised:\n%s'), - name, exc, type='autosummary') - return False - - def get_class_members(obj: Any) -> Dict[str, Any]: - members = sphinx.ext.autodoc.get_class_members(obj, [qualname], safe_getattr) - return {name: member.object for name, member in members.items()} - - def get_module_members(obj: Any) -> Dict[str, Any]: - members = {} - for name in members_of(obj, app.config): - try: - members[name] = safe_getattr(obj, name) - except AttributeError: - continue - return members - - def get_all_members(obj: Any) -> Dict[str, Any]: - if doc.objtype == "module": - return get_module_members(obj) - elif doc.objtype == "class": - return get_class_members(obj) - return {} - - def get_members(obj: Any, types: Set[str], include_public: List[str] = [], - imported: bool = True) -> Tuple[List[str], List[str]]: - items: List[str] = [] - public: List[str] = [] - - all_members = get_all_members(obj) - for name, value in all_members.items(): - documenter = get_documenter(app, value, obj) - if documenter.objtype in types: - # skip imported members if expected - if imported or getattr(value, '__module__', None) == obj.__name__: - skipped = skip_member(value, name, documenter.objtype) - if skipped is True: - pass - elif skipped is False: - # show the member forcedly - items.append(name) - public.append(name) - else: - items.append(name) - if name in include_public or not name.startswith('_'): - # considers member as public - public.append(name) - return public, items - - def get_module_attrs(members: Any) -> Tuple[List[str], List[str]]: - """Find module attributes with docstrings.""" - attrs, public = [], [] - try: - analyzer = ModuleAnalyzer.for_module(name) - attr_docs = analyzer.find_attr_docs() - for namespace, attr_name in attr_docs: - if namespace == '' and attr_name in members: - attrs.append(attr_name) - if not attr_name.startswith('_'): - public.append(attr_name) - except PycodeError: - pass # give up if ModuleAnalyzer fails to parse code - return public, attrs - - def get_modules(obj: Any) -> Tuple[List[str], List[str]]: - items: List[str] = [] - for _, modname, _ispkg in pkgutil.iter_modules(obj.__path__): - fullname = name + '.' 
+ modname - try: - module = import_module(fullname) - if module and hasattr(module, '__sphinx_mock__'): - continue - except ImportError: - pass - - items.append(fullname) - public = [x for x in items if not x.split('.')[-1].startswith('_')] - return public, items - - ns: Dict[str, Any] = {} - ns.update(context) - - if doc.objtype == 'module': - scanner = ModuleScanner(app, obj) - ns['members'] = scanner.scan(imported_members) - ns['functions'], ns['all_functions'] = \ - get_members(obj, {'function'}, imported=imported_members) - ns['classes'], ns['all_classes'] = \ - get_members(obj, {'class'}, imported=imported_members) - ns['exceptions'], ns['all_exceptions'] = \ - get_members(obj, {'exception'}, imported=imported_members) - ns['attributes'], ns['all_attributes'] = \ - get_module_attrs(ns['members']) - ispackage = hasattr(obj, '__path__') - if ispackage and recursive: - ns['modules'], ns['all_modules'] = get_modules(obj) - elif doc.objtype == 'class': - ns['members'] = dir(obj) - ns['inherited_members'] = \ - set(dir(obj)) - set(obj.__dict__.keys()) - ns['methods'], ns['all_methods'] = \ - get_members(obj, {'method'}, ['__init__']) - ns['attributes'], ns['all_attributes'] = \ - get_members(obj, {'attribute', 'property'}) - - if modname is None or qualname is None: - modname, qualname = split_full_qualified_name(name) - - if doc.objtype in ('method', 'attribute', 'property'): - ns['class'] = qualname.rsplit(".", 1)[0] - - if doc.objtype in ('class',): - shortname = qualname - else: - shortname = qualname.rsplit(".", 1)[-1] - - ns['fullname'] = name - ns['module'] = modname - ns['objname'] = qualname - ns['name'] = shortname - - ns['objtype'] = doc.objtype - ns['underline'] = len(name) * '=' - - if template_name: - return template.render(template_name, ns) - else: - return template.render(doc.objtype, ns) - - -def generate_autosummary_docs(sources: List[str], output_dir: str = None, - suffix: str = '.rst', base_path: str = None, - builder: Builder = None, template_dir: str = None, - imported_members: bool = False, app: Any = None, - overwrite: bool = True, encoding: str = 'utf-8') -> None: - - if builder: - warnings.warn('builder argument for generate_autosummary_docs() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - - if template_dir: - warnings.warn('template_dir argument for generate_autosummary_docs() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - - showed_sources = list(sorted(sources)) - if len(showed_sources) > 20: - showed_sources = showed_sources[:10] + ['...'] + showed_sources[-10:] - logger.info(__('[autosummary] generating autosummary for: %s') % - ', '.join(showed_sources)) - - if output_dir: - logger.info(__('[autosummary] writing to %s') % output_dir) - - if base_path is not None: - sources = [os.path.join(base_path, filename) for filename in sources] - - template = AutosummaryRenderer(app) - - # read - items = find_autosummary_in_files(sources) - - # keep track of new files - new_files = [] - - if app: - filename_map = app.config.autosummary_filename_map - else: - filename_map = {} - - # write - for entry in sorted(set(items), key=str): - if entry.path is None: - # The corresponding autosummary:: directive did not have - # a :toctree: option - continue - - path = output_dir or os.path.abspath(entry.path) - ensuredir(path) - - try: - name, obj, parent, modname = import_by_name(entry.name, grouped_exception=True) - qualname = name.replace(modname + ".", "") - except ImportExceptionGroup as exc: - try: - # try to import as an instance attribute - 
name, obj, parent, modname = import_ivar_by_name(entry.name) - qualname = name.replace(modname + ".", "") - except ImportError as exc2: - if exc2.__cause__: - exceptions: List[BaseException] = exc.exceptions + [exc2.__cause__] - else: - exceptions = exc.exceptions + [exc2] - - errors = list(set("* %s: %s" % (type(e).__name__, e) for e in exceptions)) - logger.warning(__('[autosummary] failed to import %s.\nPossible hints:\n%s'), - entry.name, '\n'.join(errors)) - continue - - context: Dict[str, Any] = {} - if app: - context.update(app.config.autosummary_context) - - content = generate_autosummary_content(name, obj, parent, template, entry.template, - imported_members, app, entry.recursive, context, - modname, qualname) - try: - py_source_rel = get_full_modname(modname, qualname).replace('.', '/') + '.py' - except: - logger.warning(name) - py_source_rel = '' - - re_view = f"\n.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/{app.config.docs_branch}/" + \ - f"resource/_static/logo_source_en.svg\n :target: " + app.config.giturl + \ - f"{app.config.copy_repo}/blob/{app.config.branch}/" + app.config.repo_whl + \ - py_source_rel.split(app.config.cst_module_name)[-1] + '\n :alt: View Source On Gitee\n\n' - - if re_view not in content and py_source_rel: - content = re.sub('([=]{5,})\n', r'\1\n' + re_view, content, 1) - filename = os.path.join(path, filename_map.get(name, name) + suffix) - if os.path.isfile(filename): - with open(filename, encoding=encoding) as f: - old_content = f.read() - - if content == old_content: - continue - elif overwrite: # content has changed - with open(filename, 'w', encoding=encoding) as f: - f.write(content) - new_files.append(filename) - else: - with open(filename, 'w', encoding=encoding) as f: - f.write(content) - new_files.append(filename) - - # descend recursively to new files - if new_files: - generate_autosummary_docs(new_files, output_dir=output_dir, - suffix=suffix, base_path=base_path, - builder=builder, template_dir=template_dir, - imported_members=imported_members, app=app, - overwrite=overwrite) - - -# -- Finding documented entries in files --------------------------------------- - -def find_autosummary_in_files(filenames: List[str]) -> List[AutosummaryEntry]: - """Find out what items are documented in source/*.rst. - - See `find_autosummary_in_lines`. - """ - documented: List[AutosummaryEntry] = [] - for filename in filenames: - with open(filename, encoding='utf-8', errors='ignore') as f: - lines = f.read().splitlines() - documented.extend(find_autosummary_in_lines(lines, filename=filename)) - return documented - - -def find_autosummary_in_docstring(name: str, module: str = None, filename: str = None - ) -> List[AutosummaryEntry]: - """Find out what items are documented in the given object's docstring. - - See `find_autosummary_in_lines`. 
- """ - if module: - warnings.warn('module argument for find_autosummary_in_docstring() is deprecated.', - RemovedInSphinx50Warning, stacklevel=2) - - try: - real_name, obj, parent, modname = import_by_name(name, grouped_exception=True) - lines = pydoc.getdoc(obj).splitlines() - return find_autosummary_in_lines(lines, module=name, filename=filename) - except AttributeError: - pass - except ImportExceptionGroup as exc: - errors = list(set("* %s: %s" % (type(e).__name__, e) for e in exc.exceptions)) - print('Failed to import %s.\nPossible hints:\n%s' % (name, '\n'.join(errors))) - except SystemExit: - print("Failed to import '%s'; the module executes module level " - "statement and it might call sys.exit()." % name) - return [] - - -def find_autosummary_in_lines(lines: List[str], module: str = None, filename: str = None - ) -> List[AutosummaryEntry]: - """Find out what items appear in autosummary:: directives in the - given lines. - - Returns a list of (name, toctree, template) where *name* is a name - of an object and *toctree* the :toctree: path of the corresponding - autosummary directive (relative to the root of the file name), and - *template* the value of the :template: option. *toctree* and - *template* ``None`` if the directive does not have the - corresponding options set. - """ - autosummary_re = re.compile(r'^(\s*)\.\.\s+(ms[a-z]*)?autosummary::\s*') - automodule_re = re.compile( - r'^\s*\.\.\s+automodule::\s*([A-Za-z0-9_.]+)\s*$') - module_re = re.compile( - r'^\s*\.\.\s+(current)?module::\s*([a-zA-Z0-9_.]+)\s*$') - autosummary_item_re = re.compile(r'^\s+(~?[_a-zA-Z][a-zA-Z0-9_.]*)\s*.*?') - recursive_arg_re = re.compile(r'^\s+:recursive:\s*$') - toctree_arg_re = re.compile(r'^\s+:toctree:\s*(.*?)\s*$') - template_arg_re = re.compile(r'^\s+:template:\s*(.*?)\s*$') - - documented: List[AutosummaryEntry] = [] - - recursive = False - toctree: str = None - template = None - current_module = module - in_autosummary = False - base_indent = "" - - for line in lines: - if in_autosummary: - m = recursive_arg_re.match(line) - if m: - recursive = True - continue - - m = toctree_arg_re.match(line) - if m: - toctree = m.group(1) - if filename: - toctree = os.path.join(os.path.dirname(filename), - toctree) - continue - - m = template_arg_re.match(line) - if m: - template = m.group(1).strip() - continue - - if line.strip().startswith(':'): - continue # skip options - - m = autosummary_item_re.match(line) - if m: - name = m.group(1).strip() - if name.startswith('~'): - name = name[1:] - if current_module and \ - not name.startswith(current_module + '.'): - name = "%s.%s" % (current_module, name) - documented.append(AutosummaryEntry(name, toctree, template, recursive)) - continue - - if not line.strip() or line.startswith(base_indent + " "): - continue - - in_autosummary = False - - m = autosummary_re.match(line) - if m: - in_autosummary = True - base_indent = m.group(1) - recursive = False - toctree = None - template = None - continue - - m = automodule_re.search(line) - if m: - current_module = m.group(1).strip() - # recurse into the automodule docstring - documented.extend(find_autosummary_in_docstring( - current_module, filename=filename)) - continue - - m = module_re.match(line) - if m: - current_module = m.group(2) - continue - - return documented - - -def get_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - usage='%(prog)s [OPTIONS] ...', - epilog=__('For more information, visit .'), - description=__(""" -Generate ReStructuredText using autosummary directives. 
- -sphinx-autogen is a frontend to sphinx.ext.autosummary.generate. It generates -the reStructuredText files from the autosummary directives contained in the -given input files. - -The format of the autosummary directive is documented in the -``sphinx.ext.autosummary`` Python module and can be read using:: - - pydoc sphinx.ext.autosummary -""")) - - parser.add_argument('--version', action='version', dest='show_version', - version='%%(prog)s %s' % __display_version__) - - parser.add_argument('source_file', nargs='+', - help=__('source files to generate rST files for')) - - parser.add_argument('-o', '--output-dir', action='store', - dest='output_dir', - help=__('directory to place all output in')) - parser.add_argument('-s', '--suffix', action='store', dest='suffix', - default='rst', - help=__('default suffix for files (default: ' - '%(default)s)')) - parser.add_argument('-t', '--templates', action='store', dest='templates', - default=None, - help=__('custom template directory (default: ' - '%(default)s)')) - parser.add_argument('-i', '--imported-members', action='store_true', - dest='imported_members', default=False, - help=__('document imported members (default: ' - '%(default)s)')) - parser.add_argument('-a', '--respect-module-all', action='store_true', - dest='respect_module_all', default=False, - help=__('document exactly the members in module __all__ attribute. ' - '(default: %(default)s)')) - - return parser - - -def main(argv: List[str] = sys.argv[1:]) -> None: - sphinx.locale.setlocale(locale.LC_ALL, '') - sphinx.locale.init_console(os.path.join(package_dir, 'locale'), 'sphinx') - translator, _ = sphinx.locale.init([], None) - - app = DummyApplication(translator) - logging.setup(app, sys.stdout, sys.stderr) # type: ignore - setup_documenters(app) - args = get_parser().parse_args(argv) - - if args.templates: - app.config.templates_path.append(path.abspath(args.templates)) - app.config.autosummary_ignore_module_all = not args.respect_module_all # type: ignore - - generate_autosummary_docs(args.source_file, args.output_dir, - '.' + args.suffix, - imported_members=args.imported_members, - app=app) - - -if __name__ == '__main__': - main() diff --git a/docs/hub/docs/_ext/overwriteobjectiondirective.txt b/docs/hub/docs/_ext/overwriteobjectiondirective.txt deleted file mode 100644 index 8a58bf71191f77ca22097ea9de244c9df5c3d4fb..0000000000000000000000000000000000000000 --- a/docs/hub/docs/_ext/overwriteobjectiondirective.txt +++ /dev/null @@ -1,368 +0,0 @@ -""" - sphinx.directives - ~~~~~~~~~~~~~~~~~ - - Handlers for additional ReST directives. - - :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. - :license: BSD, see LICENSE for details. 
-""" - -import re -import inspect -import importlib -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Tuple, TypeVar, cast - -from docutils import nodes -from docutils.nodes import Node -from docutils.parsers.rst import directives, roles - -from sphinx import addnodes -from sphinx.addnodes import desc_signature -from sphinx.deprecation import RemovedInSphinx50Warning, deprecated_alias -from sphinx.util import docutils, logging -from sphinx.util.docfields import DocFieldTransformer, Field, TypedField -from sphinx.util.docutils import SphinxDirective -from sphinx.util.typing import OptionSpec - -if TYPE_CHECKING: - from sphinx.application import Sphinx - - -# RE to strip backslash escapes -nl_escape_re = re.compile(r'\\\n') -strip_backslash_re = re.compile(r'\\(.)') - -T = TypeVar('T') -logger = logging.getLogger(__name__) - -def optional_int(argument: str) -> int: - """ - Check for an integer argument or None value; raise ``ValueError`` if not. - """ - if argument is None: - return None - else: - value = int(argument) - if value < 0: - raise ValueError('negative value; must be positive or zero') - return value - -def get_api(fullname): - try: - module_name, api_name= ".".join(fullname.split('.')[:-1]), fullname.split('.')[-1] - module_import = importlib.import_module(module_name) - except ModuleNotFoundError: - module_name, api_name = ".".join(fullname.split('.')[:-2]), ".".join(fullname.split('.')[-2:]) - module_import = importlib.import_module(module_name) - api = eval(f"module_import.{api_name}") - return api - -def get_example(name: str): - try: - api_doc = inspect.getdoc(get_api(name)) - example_str = re.findall(r'Examples:\n([\w\W]*?)(\n\n|$)', api_doc) - if not example_str: - return [] - example_str = re.sub(r'\n\s+', r'\n', example_str[0][0]) - example_str = example_str.strip() - example_list = example_str.split('\n') - return ["", "**样例:**", ""] + example_list + [""] - except: - return [] - -def get_platforms(name: str): - try: - api_doc = inspect.getdoc(get_api(name)) - example_str = re.findall(r'Supported Platforms:\n\s+(.*?)\n\n', api_doc) - if not example_str: - example_str_leak = re.findall(r'Supported Platforms:\n\s+(.*)', api_doc) - if example_str_leak: - example_str = example_str_leak[0].strip() - example_list = example_str.split('\n') - example_list = [' ' + example_list[0]] - return ["", "支持平台:"] + example_list + [""] - return [] - example_str = example_str[0].strip() - example_list = example_str.split('\n') - example_list = [' ' + example_list[0]] - return ["", "支持平台:"] + example_list + [""] - except: - return [] - -class ObjectDescription(SphinxDirective, Generic[T]): - """ - Directive to describe a class, function or similar object. Not used - directly, but subclassed (in domain-specific directives) to add custom - behavior. - """ - - has_content = True - required_arguments = 1 - optional_arguments = 0 - final_argument_whitespace = True - option_spec: OptionSpec = { - 'noindex': directives.flag, - } # type: Dict[str, DirectiveOption] - - # types of doc fields that this directive handles, see sphinx.util.docfields - doc_field_types: List[Field] = [] - domain: str = None - objtype: str = None - indexnode: addnodes.index = None - - # Warning: this might be removed in future version. Don't touch this from extensions. 
- _doc_field_type_map = {} # type: Dict[str, Tuple[Field, bool]] - - def get_field_type_map(self) -> Dict[str, Tuple[Field, bool]]: - if self._doc_field_type_map == {}: - self._doc_field_type_map = {} - for field in self.doc_field_types: - for name in field.names: - self._doc_field_type_map[name] = (field, False) - - if field.is_typed: - typed_field = cast(TypedField, field) - for name in typed_field.typenames: - self._doc_field_type_map[name] = (field, True) - - return self._doc_field_type_map - - def get_signatures(self) -> List[str]: - """ - Retrieve the signatures to document from the directive arguments. By - default, signatures are given as arguments, one per line. - - Backslash-escaping of newlines is supported. - """ - lines = nl_escape_re.sub('', self.arguments[0]).split('\n') - if self.config.strip_signature_backslash: - # remove backslashes to support (dummy) escapes; helps Vim highlighting - return [strip_backslash_re.sub(r'\1', line.strip()) for line in lines] - else: - return [line.strip() for line in lines] - - def handle_signature(self, sig: str, signode: desc_signature) -> Any: - """ - Parse the signature *sig* into individual nodes and append them to - *signode*. If ValueError is raised, parsing is aborted and the whole - *sig* is put into a single desc_name node. - - The return value should be a value that identifies the object. It is - passed to :meth:`add_target_and_index()` unchanged, and otherwise only - used to skip duplicates. - """ - raise ValueError - - def add_target_and_index(self, name: Any, sig: str, signode: desc_signature) -> None: - """ - Add cross-reference IDs and entries to self.indexnode, if applicable. - - *name* is whatever :meth:`handle_signature()` returned. - """ - return # do nothing by default - - def before_content(self) -> None: - """ - Called before parsing content. Used to set information about the current - directive context on the build environment. - """ - pass - - def transform_content(self, contentnode: addnodes.desc_content) -> None: - """ - Called after creating the content through nested parsing, - but before the ``object-description-transform`` event is emitted, - and before the info-fields are transformed. - Can be used to manipulate the content. - """ - pass - - def after_content(self) -> None: - """ - Called after parsing content. Used to reset information about the - current directive context on the build environment. - """ - pass - - def check_class_end(self, content): - for i in content: - if not i.startswith('.. include::') and i != "\n" and i != "": - return False - return True - - def extend_items(self, rst_file, start_num, num): - ls = [] - for i in range(1, num+1): - ls.append((rst_file, start_num+i)) - return ls - - def run(self) -> List[Node]: - """ - Main directive entry function, called by docutils upon encountering the - directive. - - This directive is meant to be quite easily subclassable, so it delegates - to several additional methods. 
What it does: - - * find out if called as a domain-specific directive, set self.domain - * create a `desc` node to fit all description inside - * parse standard options, currently `noindex` - * create an index node if needed as self.indexnode - * parse all given signatures (as returned by self.get_signatures()) - using self.handle_signature(), which should either return a name - or raise ValueError - * add index entries using self.add_target_and_index() - * parse the content and handle doc fields in it - """ - if ':' in self.name: - self.domain, self.objtype = self.name.split(':', 1) - else: - self.domain, self.objtype = '', self.name - self.indexnode = addnodes.index(entries=[]) - - node = addnodes.desc() - node.document = self.state.document - node['domain'] = self.domain - # 'desctype' is a backwards compatible attribute - node['objtype'] = node['desctype'] = self.objtype - node['noindex'] = noindex = ('noindex' in self.options) - if self.domain: - node['classes'].append(self.domain) - node['classes'].append(node['objtype']) - - self.names: List[T] = [] - signatures = self.get_signatures() - for sig in signatures: - # add a signature node for each signature in the current unit - # and add a reference target for it - signode = addnodes.desc_signature(sig, '') - self.set_source_info(signode) - node.append(signode) - try: - # name can also be a tuple, e.g. (classname, objname); - # this is strictly domain-specific (i.e. no assumptions may - # be made in this base class) - name = self.handle_signature(sig, signode) - except ValueError: - # signature parsing failed - signode.clear() - signode += addnodes.desc_name(sig, sig) - continue # we don't want an index entry here - if name not in self.names: - self.names.append(name) - if not noindex: - # only add target and index entry if this is the first - # description of the object with this name in this desc block - self.add_target_and_index(name, sig, signode) - - contentnode = addnodes.desc_content() - node.append(contentnode) - if self.names: - # needed for association of version{added,changed} directives - self.env.temp_data['object'] = self.names[0] - self.before_content() - try: - example = get_example(self.names[0][0]) - platforms = get_platforms(self.names[0][0]) - except Exception as e: - example = '' - platforms = '' - logger.warning(f'Error API names in {self.arguments[0]}.') - logger.warning(f'{e}') - extra = platforms + example - if extra: - if self.objtype == "method": - self.content.data.extend(extra) - else: - index_num = 0 - for num, i in enumerate(self.content.data): - if i.startswith('.. py:method::') or self.check_class_end(self.content.data[num:]): - index_num = num - break - if index_num: - count = len(self.content.data) - for i in extra: - self.content.data.insert(index_num-count, i) - else: - self.content.data.extend(extra) - try: - self.content.items.extend(self.extend_items(self.content.items[0][0], self.content.items[-1][1], len(extra))) - except Exception as e: - logger.warning(f'{e}') - self.state.nested_parse(self.content, self.content_offset, contentnode) - self.transform_content(contentnode) - self.env.app.emit('object-description-transform', - self.domain, self.objtype, contentnode) - DocFieldTransformer(self).transform_all(contentnode) - self.env.temp_data['object'] = None - self.after_content() - return [self.indexnode, node] - - -class DefaultRole(SphinxDirective): - """ - Set the default interpreted text role. Overridden from docutils. 
- """ - - optional_arguments = 1 - final_argument_whitespace = False - - def run(self) -> List[Node]: - if not self.arguments: - docutils.unregister_role('') - return [] - role_name = self.arguments[0] - role, messages = roles.role(role_name, self.state_machine.language, - self.lineno, self.state.reporter) - if role: - docutils.register_role('', role) - self.env.temp_data['default_role'] = role_name - else: - literal_block = nodes.literal_block(self.block_text, self.block_text) - reporter = self.state.reporter - error = reporter.error('Unknown interpreted text role "%s".' % role_name, - literal_block, line=self.lineno) - messages += [error] - - return cast(List[nodes.Node], messages) - - -class DefaultDomain(SphinxDirective): - """ - Directive to (re-)set the default domain for this source file. - """ - - has_content = False - required_arguments = 1 - optional_arguments = 0 - final_argument_whitespace = False - option_spec = {} # type: Dict - - def run(self) -> List[Node]: - domain_name = self.arguments[0].lower() - # if domain_name not in env.domains: - # # try searching by label - # for domain in env.domains.values(): - # if domain.label.lower() == domain_name: - # domain_name = domain.name - # break - self.env.temp_data['default_domain'] = self.env.domains.get(domain_name) - return [] - -def setup(app: "Sphinx") -> Dict[str, Any]: - app.add_config_value("strip_signature_backslash", False, 'env') - directives.register_directive('default-role', DefaultRole) - directives.register_directive('default-domain', DefaultDomain) - directives.register_directive('describe', ObjectDescription) - # new, more consistent, name - directives.register_directive('object', ObjectDescription) - - app.add_event('object-description-transform') - - return { - 'version': 'builtin', - 'parallel_read_safe': True, - 'parallel_write_safe': True, - } - diff --git a/docs/hub/docs/_ext/overwriteviewcode.txt b/docs/hub/docs/_ext/overwriteviewcode.txt deleted file mode 100644 index 172780ec56b3ed90e7b0add617257a618cf38ee0..0000000000000000000000000000000000000000 --- a/docs/hub/docs/_ext/overwriteviewcode.txt +++ /dev/null @@ -1,378 +0,0 @@ -""" - sphinx.ext.viewcode - ~~~~~~~~~~~~~~~~~~~ - - Add links to module code in Python object descriptions. - - :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. - :license: BSD, see LICENSE for details. -""" - -import posixpath -import traceback -import warnings -from os import path -from typing import Any, Dict, Generator, Iterable, Optional, Set, Tuple, cast - -from docutils import nodes -from docutils.nodes import Element, Node - -import sphinx -from sphinx import addnodes -from sphinx.application import Sphinx -from sphinx.builders import Builder -from sphinx.builders.html import StandaloneHTMLBuilder -from sphinx.deprecation import RemovedInSphinx50Warning -from sphinx.environment import BuildEnvironment -from sphinx.locale import _, __ -from sphinx.pycode import ModuleAnalyzer -from sphinx.transforms.post_transforms import SphinxPostTransform -from sphinx.util import get_full_modname, logging, status_iterator -from sphinx.util.nodes import make_refnode - - -logger = logging.getLogger(__name__) - - -OUTPUT_DIRNAME = '_modules' - - -class viewcode_anchor(Element): - """Node for viewcode anchors. - - This node will be processed in the resolving phase. - For viewcode supported builders, they will be all converted to the anchors. - For not supported builders, they will be removed. 
- """ - - -def _get_full_modname(app: Sphinx, modname: str, attribute: str) -> Optional[str]: - try: - return get_full_modname(modname, attribute) - except AttributeError: - # sphinx.ext.viewcode can't follow class instance attribute - # then AttributeError logging output only verbose mode. - logger.verbose('Didn\'t find %s in %s', attribute, modname) - return None - except Exception as e: - # sphinx.ext.viewcode follow python domain directives. - # because of that, if there are no real modules exists that specified - # by py:function or other directives, viewcode emits a lot of warnings. - # It should be displayed only verbose mode. - logger.verbose(traceback.format_exc().rstrip()) - logger.verbose('viewcode can\'t import %s, failed with error "%s"', modname, e) - return None - - -def is_supported_builder(builder: Builder) -> bool: - if builder.format != 'html': - return False - elif builder.name == 'singlehtml': - return False - elif builder.name.startswith('epub') and not builder.config.viewcode_enable_epub: - return False - else: - return True - - -def doctree_read(app: Sphinx, doctree: Node) -> None: - env = app.builder.env - if not hasattr(env, '_viewcode_modules'): - env._viewcode_modules = {} # type: ignore - - def has_tag(modname: str, fullname: str, docname: str, refname: str) -> bool: - entry = env._viewcode_modules.get(modname, None) # type: ignore - if entry is False: - return False - - code_tags = app.emit_firstresult('viewcode-find-source', modname) - if code_tags is None: - try: - analyzer = ModuleAnalyzer.for_module(modname) - analyzer.find_tags() - except Exception: - env._viewcode_modules[modname] = False # type: ignore - return False - - code = analyzer.code - tags = analyzer.tags - else: - code, tags = code_tags - - if entry is None or entry[0] != code: - entry = code, tags, {}, refname - env._viewcode_modules[modname] = entry # type: ignore - _, tags, used, _ = entry - if fullname in tags: - used[fullname] = docname - return True - - return False - - for objnode in list(doctree.findall(addnodes.desc)): - if objnode.get('domain') != 'py': - continue - names: Set[str] = set() - for signode in objnode: - if not isinstance(signode, addnodes.desc_signature): - continue - modname = signode.get('module') - fullname = signode.get('fullname') - try: - if fullname and modname==None: - if fullname.split('.')[-1].lower() == fullname.split('.')[-1] and fullname.split('.')[-2].lower() != fullname.split('.')[-2]: - modname = '.'.join(fullname.split('.')[:-2]) - fullname = '.'.join(fullname.split('.')[-2:]) - else: - modname = '.'.join(fullname.split('.')[:-1]) - fullname = fullname.split('.')[-1] - fullname_new = fullname - except Exception: - logger.warning(f'error_modename:{modname}') - logger.warning(f'error_fullname:{fullname}') - refname = modname - if env.config.viewcode_follow_imported_members: - new_modname = app.emit_firstresult( - 'viewcode-follow-imported', modname, fullname, - ) - if not new_modname: - new_modname = _get_full_modname(app, modname, fullname) - modname = new_modname - # logger.warning(f'new_modename:{modname}') - if not modname: - continue - # fullname = signode.get('fullname') - # if fullname and modname==None: - fullname = fullname_new - if not has_tag(modname, fullname, env.docname, refname): - continue - if fullname in names: - # only one link per name, please - continue - names.add(fullname) - pagename = posixpath.join(OUTPUT_DIRNAME, modname.replace('.', '/')) - signode += viewcode_anchor(reftarget=pagename, refid=fullname, refdoc=env.docname) - - 
-def env_merge_info(app: Sphinx, env: BuildEnvironment, docnames: Iterable[str], - other: BuildEnvironment) -> None: - if not hasattr(other, '_viewcode_modules'): - return - # create a _viewcode_modules dict on the main environment - if not hasattr(env, '_viewcode_modules'): - env._viewcode_modules = {} # type: ignore - # now merge in the information from the subprocess - for modname, entry in other._viewcode_modules.items(): # type: ignore - if modname not in env._viewcode_modules: # type: ignore - env._viewcode_modules[modname] = entry # type: ignore - else: - if env._viewcode_modules[modname]: # type: ignore - used = env._viewcode_modules[modname][2] # type: ignore - for fullname, docname in entry[2].items(): - if fullname not in used: - used[fullname] = docname - - -def env_purge_doc(app: Sphinx, env: BuildEnvironment, docname: str) -> None: - modules = getattr(env, '_viewcode_modules', {}) - - for modname, entry in list(modules.items()): - if entry is False: - continue - - code, tags, used, refname = entry - for fullname in list(used): - if used[fullname] == docname: - used.pop(fullname) - - if len(used) == 0: - modules.pop(modname) - - -class ViewcodeAnchorTransform(SphinxPostTransform): - """Convert or remove viewcode_anchor nodes depends on builder.""" - default_priority = 100 - - def run(self, **kwargs: Any) -> None: - if is_supported_builder(self.app.builder): - self.convert_viewcode_anchors() - else: - self.remove_viewcode_anchors() - - def convert_viewcode_anchors(self) -> None: - for node in self.document.findall(viewcode_anchor): - anchor = nodes.inline('', _('[源代码]'), classes=['viewcode-link']) - refnode = make_refnode(self.app.builder, node['refdoc'], node['reftarget'], - node['refid'], anchor) - node.replace_self(refnode) - - def remove_viewcode_anchors(self) -> None: - for node in list(self.document.findall(viewcode_anchor)): - node.parent.remove(node) - - -def missing_reference(app: Sphinx, env: BuildEnvironment, node: Element, contnode: Node - ) -> Optional[Node]: - # resolve our "viewcode" reference nodes -- they need special treatment - if node['reftype'] == 'viewcode': - warnings.warn('viewcode extension is no longer use pending_xref node. ' - 'Please update your extension.', RemovedInSphinx50Warning) - return make_refnode(app.builder, node['refdoc'], node['reftarget'], - node['refid'], contnode) - - return None - - -def get_module_filename(app: Sphinx, modname: str) -> Optional[str]: - """Get module filename for *modname*.""" - source_info = app.emit_firstresult('viewcode-find-source', modname) - if source_info: - return None - else: - try: - filename, source = ModuleAnalyzer.get_module_source(modname) - return filename - except Exception: - return None - - -def should_generate_module_page(app: Sphinx, modname: str) -> bool: - """Check generation of module page is needed.""" - module_filename = get_module_filename(app, modname) - if module_filename is None: - # Always (re-)generate module page when module filename is not found. - return True - - builder = cast(StandaloneHTMLBuilder, app.builder) - basename = modname.replace('.', '/') + builder.out_suffix - page_filename = path.join(app.outdir, '_modules/', basename) - - try: - if path.getmtime(module_filename) <= path.getmtime(page_filename): - # generation is not needed if the HTML page is newer than module file. 
- return False - except IOError: - pass - - return True - - -def collect_pages(app: Sphinx) -> Generator[Tuple[str, Dict[str, Any], str], None, None]: - env = app.builder.env - if not hasattr(env, '_viewcode_modules'): - return - if not is_supported_builder(app.builder): - return - highlighter = app.builder.highlighter # type: ignore - urito = app.builder.get_relative_uri - - modnames = set(env._viewcode_modules) # type: ignore - - for modname, entry in status_iterator( - sorted(env._viewcode_modules.items()), # type: ignore - __('highlighting module code... '), "blue", - len(env._viewcode_modules), # type: ignore - app.verbosity, lambda x: x[0]): - if not entry: - continue - if not should_generate_module_page(app, modname): - continue - - code, tags, used, refname = entry - # construct a page name for the highlighted source - pagename = posixpath.join(OUTPUT_DIRNAME, modname.replace('.', '/')) - # highlight the source using the builder's highlighter - if env.config.highlight_language in ('python3', 'default', 'none'): - lexer = env.config.highlight_language - else: - lexer = 'python' - highlighted = highlighter.highlight_block(code, lexer, linenos=False) - # split the code into lines - lines = highlighted.splitlines() - # split off wrap markup from the first line of the actual code - before, after = lines[0].split('
<pre>')
-        lines[0:1] = [before + '<pre>', after]
-        # nothing to do for the last line; it always starts with </pre> anyway
-        # now that we have code lines (starting at index 1), insert anchors for
-        # the collected tags (HACK: this only works if the tag boundaries are
-        # properly nested!)
-        maxindex = len(lines) - 1
-        for name, docname in used.items():
-            type, start, end = tags[name]
-            backlink = urito(pagename, docname) + '#' + refname + '.' + name
-            lines[start] = (
-                '<div class="viewcode-block" id="%s"><a class="viewcode-back" '
-                'href="%s">%s</a>' % (name, backlink, _('[文档]')) +
-                lines[start])
-            lines[min(end, maxindex)] += '</div>'
-        # try to find parents (for submodules)
-        parents = []
-        parent = modname
-        while '.' in parent:
-            parent = parent.rsplit('.', 1)[0]
-            if parent in modnames:
-                parents.append({
-                    'link': urito(pagename,
-                                  posixpath.join(OUTPUT_DIRNAME, parent.replace('.', '/'))),
-                    'title': parent})
-        parents.append({'link': urito(pagename, posixpath.join(OUTPUT_DIRNAME, 'index')),
-                        'title': _('Module code')})
-        parents.reverse()
-        # putting it all together
-        context = {
-            'parents': parents,
-            'title': modname,
-            'body': (_('<h1>Source code for %s</h1>') % modname +
-                     '\n'.join(lines)),
-        }
-        yield (pagename, context, 'page.html')
-
-    if not modnames:
-        return
-
-    html = ['\n']
-    # the stack logic is needed for using nested lists for submodules
-    stack = ['']
-    for modname in sorted(modnames):
-        if modname.startswith(stack[-1]):
-            stack.append(modname + '.')
-            html.append('<ul>')
-        else:
-            stack.pop()
-            while not modname.startswith(stack[-1]):
-                stack.pop()
-                html.append('</ul>')
-            stack.append(modname + '.')
-        html.append('<li><a href="%s">%s</a></li>\n' % (
-            urito(posixpath.join(OUTPUT_DIRNAME, 'index'),
-                  posixpath.join(OUTPUT_DIRNAME, modname.replace('.', '/'))),
-            modname))
-    html.append('</ul>' * (len(stack) - 1))
-    context = {
-        'title': _('Overview: module code'),
-        'body': (_('<h1>All modules for which code is available</h1>
    ') + - ''.join(html)), - } - - yield (posixpath.join(OUTPUT_DIRNAME, 'index'), context, 'page.html') - - -def setup(app: Sphinx) -> Dict[str, Any]: - app.add_config_value('viewcode_import', None, False) - app.add_config_value('viewcode_enable_epub', False, False) - app.add_config_value('viewcode_follow_imported_members', True, False) - app.connect('doctree-read', doctree_read) - app.connect('env-merge-info', env_merge_info) - app.connect('env-purge-doc', env_purge_doc) - app.connect('html-collect-pages', collect_pages) - app.connect('missing-reference', missing_reference) - # app.add_config_value('viewcode_include_modules', [], 'env') - # app.add_config_value('viewcode_exclude_modules', [], 'env') - app.add_event('viewcode-find-source') - app.add_event('viewcode-follow-imported') - app.add_post_transform(ViewcodeAnchorTransform) - return { - 'version': sphinx.__display_version__, - 'env_version': 1, - 'parallel_read_safe': True - } diff --git a/docs/hub/docs/requirements.txt b/docs/hub/docs/requirements.txt deleted file mode 100644 index a1b6a69f6dbd9c6f78710f56889e14f0e85b27f4..0000000000000000000000000000000000000000 --- a/docs/hub/docs/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -sphinx == 4.4.0 -docutils == 0.17.1 -myst-parser == 0.18.1 -sphinx_rtd_theme == 1.0.0 -numpy -IPython -jieba diff --git a/docs/hub/docs/source_en/conf.py b/docs/hub/docs/source_en/conf.py deleted file mode 100644 index b2c26a05bba70436e84f063d295b6bfc7b0f1f79..0000000000000000000000000000000000000000 --- a/docs/hub/docs/source_en/conf.py +++ /dev/null @@ -1,193 +0,0 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import os -import shutil -import IPython -import re -import sys -import sphinx.ext.autosummary.generate as g -from sphinx.ext import autodoc as sphinx_autodoc - -import mindspore_hub - -# -- Project information ----------------------------------------------------- - -project = 'MindSpore' -copyright = 'MindSpore' -author = 'MindSpore' - -# The full version, including alpha/beta/rc tags -release = 'master' - - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -myst_enable_extensions = ["dollarmath", "amsmath"] - - -myst_heading_anchors = 5 -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'myst_parser', - 'sphinx.ext.mathjax', - 'IPython.sphinxext.ipython_console_highlighting' -] - -source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', -} - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. 
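-
-# Editorial note (not part of the original conf.py): mathjax_path below swaps
-# Sphinx's default MathJax CDN for a copy hosted on the MindSpore OBS mirror,
-# and mathjax_options adds the listed attributes (here 'async') to the
-# <script> tag emitted by sphinx.ext.mathjax.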
-mathjax_path = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/mathjax/MathJax-3.2.2/es5/tex-mml-chtml.js' - -mathjax_options = { - 'async':'async' -} - -smartquotes_action = 'De' - -exclude_patterns = [] - -pygments_style = 'sphinx' - -autodoc_inherit_docstrings = False - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'sphinx_rtd_theme' - -import sphinx_rtd_theme -layout_target = os.path.join(os.path.dirname(sphinx_rtd_theme.__file__), 'layout.html') -layout_src = '../../../../resource/_static/layout.html' -if os.path.exists(layout_target): - os.remove(layout_target) -shutil.copy(layout_src, layout_target) - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = { - 'python': ('https://docs.python.org/', '../../../../resource/python_objects.inv'), - 'numpy': ('https://docs.scipy.org/doc/numpy/', '../../../../resource/numpy_objects.inv'), -} - -# overwriteautosummary_generate add view source for api and more autosummary class availably. -with open('../_ext/overwriteautosummary_generate.txt', 'r', encoding="utf8") as f: - exec(f.read(), g.__dict__) - -# Modify default signatures for autodoc. -autodoc_source_path = os.path.abspath(sphinx_autodoc.__file__) -autodoc_source_re = re.compile(r'stringify_signature\(.*?\)') -get_param_func_str = r"""\ -import re -import inspect as inspect_ - -def get_param_func(func): - try: - source_code = inspect_.getsource(func) - if func.__doc__: - source_code = source_code.replace(func.__doc__, '') - all_params_str = re.findall(r"def [\w_\d\-]+\(([\S\s]*?)(\):|\) ->.*?:)", source_code) - all_params = re.sub("(self|cls)(,|, )?", '', all_params_str[0][0].replace("\n", "").replace("'", "\"")) - return all_params - except: - return '' - -def get_obj(obj): - if isinstance(obj, type): - return obj.__init__ - - return obj -""" - -with open(autodoc_source_path, "r+", encoding="utf8") as f: - code_str = f.read() - code_str = autodoc_source_re.sub('"(" + get_param_func(get_obj(self.object)) + ")"', code_str, count=0) - exec(get_param_func_str, sphinx_autodoc.__dict__) - exec(code_str, sphinx_autodoc.__dict__) - -# Copy source files of chinese python api from mindpandas repository. 
-from sphinx.util import logging -logger = logging.getLogger(__name__) - -src_dir_en = os.path.join(os.getenv("HB_PATH"), 'docs/api_python_en') -present_path = os.path.dirname(__file__) - -for i in os.listdir(src_dir_en): - if os.path.isfile(os.path.join(src_dir_en,i)): - if os.path.exists('./'+i): - os.remove('./'+i) - shutil.copy(os.path.join(src_dir_en,i),'./'+i) - else: - if os.path.exists('./'+i): - shutil.rmtree('./'+i) - shutil.copytree(os.path.join(src_dir_en,i),'./'+i) - -# get params for add view source -import json - -if os.path.exists('../../../../tools/generate_html/version.json'): - with open('../../../../tools/generate_html/version.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) -elif os.path.exists('../../../../tools/generate_html/daily_dev.json'): - with open('../../../../tools/generate_html/daily_dev.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) -elif os.path.exists('../../../../tools/generate_html/daily.json'): - with open('../../../../tools/generate_html/daily.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) - -if os.getenv("HB_PATH").split('/')[-1]: - copy_repo = os.getenv("HB_PATH").split('/')[-1] -else: - copy_repo = os.getenv("HB_PATH").split('/')[-2] - -branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if version_inf[i]['name'] == copy_repo][0] -docs_branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if version_inf[i]['name'] == 'tutorials'][0] -cst_module_name = 'mindspore_hub' -repo_whl = 'mindspore_hub' -giturl = 'https://gitee.com/mindspore/' - -sys.path.append(os.path.abspath('../../../../resource/sphinx_ext')) -# import anchor_mod -import nbsphinx_mod - -sys.path.append(os.path.abspath('../../../../resource/search')) -import search_code - -sys.path.append(os.path.abspath('../../../../resource/custom_directives')) -from custom_directives import IncludeCodeDirective - -def setup(app): - app.add_directive('includecode', IncludeCodeDirective) - app.add_config_value('docs_branch', '', True) - app.add_config_value('branch', '', True) - app.add_config_value('cst_module_name', '', True) - app.add_config_value('copy_repo', '', True) - app.add_config_value('giturl', '', True) - app.add_config_value('repo_whl', '', True) diff --git a/docs/hub/docs/source_en/hub_installation.md b/docs/hub/docs/source_en/hub_installation.md deleted file mode 100644 index 34c2c8777dc8b6abceb76b17b52aee805eb52199..0000000000000000000000000000000000000000 --- a/docs/hub/docs/source_en/hub_installation.md +++ /dev/null @@ -1,116 +0,0 @@ -# MindSpore Hub Installation - -- [MindSpore Hub Installation](#mindspore-hub-installation) - - [System Environment Information Confirmation](#system-environment-information-confirmation) - - [Installation Methods](#installation-methods) - - [Installation by pip](#installation-by-pip) - - [Installation by Source Code](#installation-by-source-code) - - [Installation Verification](#installation-verification) - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/hub/docs/source_en/hub_installation.md) - -## System Environment Information Confirmation - -- The hardware platform supports Ascend, GPU and CPU. -- Confirm that [Python](https://www.python.org/ftp/python/3.7.5/Python-3.7.5.tgz) 3.7.5 is installed. -- The versions of MindSpore Hub and MindSpore must be consistent. 
-- MindSpore Hub supports only Linux distros with x86 architecture 64-bit or ARM architecture 64-bit.
-- When the network is connected, dependency items in the `setup.py` file are automatically downloaded during .whl package installation. In other cases, you need to install the dependency items manually.
-
-## Installation Methods
-
-You can install MindSpore Hub either by pip or by source code.
-
-### Installation by pip
-
-Install MindSpore Hub using the `pip` command. `hub` depends on the MindSpore version used in the current environment.
-
-Download and install the MindSpore Hub whl package from the [Release List](https://www.mindspore.cn/versions/en).
-
-```shell
-pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/{version}/Hub/any/mindspore_hub-{version}-py3-none-any.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple
-```
-
-> - `{version}` denotes the version of MindSpore Hub. For example, when you are downloading MindSpore Hub 1.3.0, `{version}` should be 1.3.0.
-
-### Installation by Source Code
-
-1. Download the source code from Gitee.
-
-    ```bash
-    git clone https://gitee.com/mindspore/hub.git
-    ```
-
-2. Compile and install in the MindSpore Hub directory.
-
-    ```bash
-    cd hub
-    python setup.py install
-    ```
-
-## Installation Verification
-
-Run the following commands in a network-enabled environment to verify the installation.
-
-```python
-import mindspore_hub as mshub
-
-model = mshub.load("mindspore/1.6/lenet_mnist", num_classes=10)
-```
-
-If the following information is displayed, the installation is successful:
-
-```text
-Downloading data from url https://gitee.com/mindspore/hub/raw/master/mshub_res/assets/mindspore/1.6/lenet_mnist.md
-
-Download finished!
-File size = 0.00 Mb
-Checking /home/ma-user/.mscache/mindspore/1.6/lenet_mnist.md...Passed!
-```
-
-## FAQ
-
-**Q: What to do when `SSL: CERTIFICATE_VERIFY_FAILED` occurs?**
-
-A: Depending on your network environment, for example, if you use a proxy to connect to the Internet, SSL verification may fail in Python because of an incorrect certificate configuration. In this case, you can use either of the following methods to solve the problem:
-
-- Configure the SSL certificate **(recommended)**.
-- Add the following code before importing mindspore_hub (the quickest workaround).
-
-  ```python
-  import ssl
-  ssl._create_default_https_context = ssl._create_unverified_context
-
-  import mindspore_hub as mshub
-  model = mshub.load("mindspore/1.6/lenet_mnist", num_classes=10)
-  ```
-
-**Q: What to do when `No module named src.*` occurs?**
-
-A: This can happen when you use mindspore_hub.load to load different models in the same process, because each model's file path must be inserted into sys.path. Test results show that Python only looks for src.* in the first inserted path, and deleting the first inserted path does not help. To solve the problem, copy all model files to the working directory.
The code is as follows: - -```python -# mindspore_hub_install_path/load.py -def _copy_all_file_to_target_path(path, target_path): - if not os.path.exists(target_path): - os.makedirs(target_path) - path = os.path.realpath(path) - target_path = os.path.realpath(target_path) - for p in os.listdir(path): - copy_path = os.path.join(path, p) - target_dir = os.path.join(target_path, p) - _delete_if_exist(target_dir) - if os.path.isdir(copy_path): - _copy_all_file_to_target_path(copy_path, target_dir) - else: - shutil.copy(copy_path, target_dir) - -def _get_network_from_cache(name, path, *args, **kwargs): - _copy_all_file_to_target_path(path, os.getcwd()) - config_path = os.path.join(os.getcwd(), HUB_CONFIG_FILE) - if not os.path.exists(config_path): - raise ValueError('{} not exists.'.format(config_path)) - ...... -``` - -**Note**: Some files of the previous model may be replaced when the next model is loaded. However, necessary model files must exist during model training. Therefore, you must finish training the previous model before the next model loads. diff --git a/docs/hub/docs/source_en/index.rst b/docs/hub/docs/source_en/index.rst deleted file mode 100644 index ea3aa09c9f94f225652cced4cb2f2b2f4eae9b67..0000000000000000000000000000000000000000 --- a/docs/hub/docs/source_en/index.rst +++ /dev/null @@ -1,69 +0,0 @@ -MindSpore Hub Documents -========================= - -MindSpore Hub is a pre-trained model application tool of the MindSpore ecosystem. It provides the following functions: - -- Plug-and-play model loading -- Easy-to-use transfer learning - -.. code-block:: - - import mindspore - import mindspore_hub as mshub - from mindspore import set_context, GRAPH_MODE - - set_context(mode=GRAPH_MODE, - device_target="Ascend", - device_id=0) - - model = "mindspore/1.6/googlenet_cifar10" - - # Initialize the number of classes based on the pre-trained model. - network = mshub.load(model, num_classes=10) - network.set_train(False) - - # ... - -Code repository address: - -Typical Application Scenarios --------------------------------------------- - -1. `Inference Validation `_ - - With only one line of code, use mindspore_hub.load to load the pre-trained model. - -2. `Transfer Learning `_ - - After loading models using mindspore_hub.load, add an extra argument to load the feature extraction of the neural network. This makes it easier to add new layers for transfer learning. - -3. `Model Releasing `_ - - Release the trained model to MindSpore Hub according to the specified procedure for download and use. - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: Installation - - hub_installation - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: Guide - - loading_model_from_hub - publish_model - -.. toctree:: - :maxdepth: 1 - :caption: API References - - hub - -.. 
toctree::
-   :maxdepth: 1
-   :caption: Models
-
-   MindSpore Hub↗
diff --git a/docs/hub/docs/source_en/loading_model_from_hub.md b/docs/hub/docs/source_en/loading_model_from_hub.md
deleted file mode 100644
index 02372d7a6a0a274f300e607a11fc3527bd64968c..0000000000000000000000000000000000000000
--- a/docs/hub/docs/source_en/loading_model_from_hub.md
+++ /dev/null
@@ -1,212 +0,0 @@
-# Loading the Model from Hub
-
-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/hub/docs/source_en/loading_model_from_hub.md)
-
-## Overview
-
-For individual developers, training a good model from scratch requires a large amount of well-labeled data, sufficient computational resources, and a lot of training and debugging time. This makes model training very resource-intensive and raises the barrier to entry for AI development. To solve these problems, MindSpore Hub provides many fully trained model weight files, which enable developers to quickly train a good model with only a small amount of data and a short training time.
-
-This document demonstrates how to use the models provided by MindSpore Hub for both inference validation and transfer learning, and shows how to quickly complete training with a small amount of data to obtain a good model.
-
-## For Inference Validation
-
-The `mindspore_hub.load` API is used to load a pre-trained model in a single line of code. The main process of model loading is as follows:
-
-1. Search for the model of interest on the [MindSpore Hub Website](https://www.mindspore.cn/resources/hub).
-
-   For example, if you aim to perform image classification on the CIFAR-10 dataset using GoogleNet, search the [MindSpore Hub Website](https://www.mindspore.cn/resources/hub) with the keyword `GoogleNet`. All related models will be returned. Once you enter a related model page, you can find its `Usage`. **Note**: if the model page does not provide a `Usage`, the model does not currently support loading with MindSpore Hub.
-
-2. Load the model according to the `Usage`, as shown in the example below:
-
-   ```python
-   import mindspore_hub as mshub
-   import mindspore
-   from mindspore import Tensor, nn, Model, set_context, GRAPH_MODE
-   from mindspore import dtype as mstype
-   import mindspore.dataset.vision as vision
-
-   set_context(mode=GRAPH_MODE,
-               device_target="Ascend",
-               device_id=0)
-
-   model = "mindspore/1.6/googlenet_cifar10"
-
-   # Initialize the number of classes based on the pre-trained model.
-   network = mshub.load(model, num_classes=10)
-   network.set_train(False)
-
-   # ...
-
-   ```
-
-3. After loading the model, you can use MindSpore for inference. You can refer to the [Multi-Platform Inference Overview](https://www.mindspore.cn/tutorials/en/master/model_infer/ms_infer/llm_inference_overview.html).
-
-## For Transfer Training
-
-When loading a model with the `mindspore_hub.load` API, we can add an extra argument to load only the feature extraction part of the model, which makes it easy to add new layers for transfer learning. This feature is shown on the related model page when an extra argument (e.g., `include_top`) has been integrated into the model construction by the model developer. The value of `include_top` is True or False, indicating whether to keep the top layer in the fully-connected network.
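-
-As a minimal sketch of the flag (the second call mirrors the `Usage` of the MobileNetV2 example below; the `num_classes=1000` in the first call is an assumption for the full ImageNet head):
-
-```python
-import mindspore_hub as mshub
-
-model = "mindspore/1.6/mobilenetv2_imagenet2012"
-
-# include_top=True (commonly the default): the classifier head is kept and
-# the network outputs ImageNet classification scores.
-classifier = mshub.load(model, num_classes=1000)
-
-# include_top=False: only the feature extractor is returned, so a new
-# task-specific head can be attached for transfer learning.
-backbone = mshub.load(model, num_classes=500, include_top=False, activation="Sigmoid")
-```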
- -We use [MobileNetV2](https://gitee.com/mindspore/models/tree/master/research/cv/centerface) as an example to illustrate how to load a model trained on the ImageNet dataset and then perform transfer learning (re-training) on a specific sub-task dataset. The main steps are listed below: - -1. Search the model of interest on [MindSpore Hub Website](https://www.mindspore.cn/resources/hub/) and find the corresponding `Usage`. - -2. Load the model from MindSpore Hub using the `Usage`. Note that the parameter `include_top` is provided by the model developer. - - ```python - import os - import mindspore_hub as mshub - import mindspore - from mindspore import Tensor, nn, set_context, GRAPH_MODE, train - from mindspore.nn import Momentum - from mindspore import save_checkpoint, load_checkpoint,load_param_into_net - from mindspore import ops - import mindspore.dataset as ds - import mindspore.dataset.transforms as transforms - import mindspore.dataset.vision as vision - from mindspore import dtype as mstype - from mindspore import Model - set_context(mode=GRAPH_MODE, device_target="Ascend", device_id=0) - - model = "mindspore/1.6/mobilenetv2_imagenet2012" - network = mshub.load(model, num_classes=500, include_top=False, activation="Sigmoid") - network.set_train(False) - ``` - -3. Add a new classification layer into current model architecture. - - ```python - class ReduceMeanFlatten(nn.Cell): - def __init__(self): - super(ReduceMeanFlatten, self).__init__() - self.mean = ops.ReduceMean(keep_dims=True) - self.flatten = nn.Flatten() - - def construct(self, x): - x = self.mean(x, (2, 3)) - x = self.flatten(x) - return x - - # Check MindSpore Hub website to conclude that the last output shape is 1280. - last_channel = 1280 - - # The number of classes in target task is 10. - num_classes = 10 - - reducemean_flatten = ReduceMeanFlatten() - - classification_layer = nn.Dense(last_channel, num_classes) - classification_layer.set_train(True) - - train_network = nn.SequentialCell([network, reducemean_flatten, classification_layer]) - ``` - -4. Define `dataset_loader`. - - As shown below, the new dataset used for fine-tuning is the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html). It is noted here we need to download the `binary version` dataset. After downloading and decompression, the following code can be used for data loading and processing. It is noted the `dataset_path` is the path to the dataset and should be given by the user. - - ```python - def create_cifar10dataset(dataset_path, batch_size, usage='train', shuffle=True): - data_set = ds.Cifar10Dataset(dataset_dir=dataset_path, usage=usage, shuffle=shuffle) - - # define map operations - trans = [ - vision.Resize((256, 256)), - vision.RandomHorizontalFlip(prob=0.5), - vision.Rescale(1.0 / 255.0, 0.0), - vision.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), - vision.HWC2CHW() - ] - - type_cast_op = transforms.TypeCast(mstype.int32) - - data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8) - data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8) - - # apply batch operations - data_set = data_set.batch(batch_size, drop_remainder=True) - return data_set - - # Create Dataset - dataset_path = "/path_to_dataset/cifar-10-batches-bin" - dataset = create_cifar10dataset(dataset_path, batch_size=32, usage='train', shuffle=True) - ``` - -5. Define `loss`, `optimizer` and `learning rate`. 
- - ```python - def generate_steps_lr(lr_init, steps_per_epoch, total_epochs): - total_steps = total_epochs * steps_per_epoch - decay_epoch_index = [0.3*total_steps, 0.6*total_steps, 0.8*total_steps] - lr_each_step = [] - for i in range(total_steps): - if i < decay_epoch_index[0]: - lr = lr_init - elif i < decay_epoch_index[1]: - lr = lr_init * 0.1 - elif i < decay_epoch_index[2]: - lr = lr_init * 0.01 - else: - lr = lr_init * 0.001 - lr_each_step.append(lr) - return lr_each_step - - # Set epoch size - epoch_size = 60 - - # Wrap the backbone network with loss. - loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") - loss_net = nn.WithLossCell(train_network, loss_fn) - steps_per_epoch = dataset.get_dataset_size() - lr = generate_steps_lr(lr_init=0.01, steps_per_epoch=steps_per_epoch, total_epochs=epoch_size) - - # Create an optimizer. - optim = Momentum(filter(lambda x: x.requires_grad, classification_layer.get_parameters()), Tensor(lr, mindspore.float32), 0.9, 4e-5) - train_net = nn.TrainOneStepCell(loss_net, optim) - ``` - -6. Start fine-tuning. - - ```python - for epoch in range(epoch_size): - for i, items in enumerate(dataset): - data, label = items - data = mindspore.Tensor(data) - label = mindspore.Tensor(label) - - loss = train_net(data, label) - print(f"epoch: {epoch}/{epoch_size}, loss: {loss}") - # Save the ckpt file for each epoch. - if not os.path.exists('ckpt'): - os.mkdir('ckpt') - ckpt_path = f"./ckpt/cifar10_finetune_epoch{epoch}.ckpt" - save_checkpoint(train_network, ckpt_path) - ``` - -7. Eval on test set. - - ```python - model = "mindspore/1.6/mobilenetv2_imagenet2012" - - network = mshub.load(model, num_classes=500, pretrained=True, include_top=False, activation="Sigmoid") - network.set_train(False) - reducemean_flatten = ReduceMeanFlatten() - classification_layer = nn.Dense(last_channel, num_classes) - classification_layer.set_train(False) - softmax = nn.Softmax() - network = nn.SequentialCell([network, reducemean_flatten, classification_layer, softmax]) - - # Load a pre-trained ckpt file. - ckpt_path = "./ckpt/cifar10_finetune_epoch59.ckpt" - trained_ckpt = load_checkpoint(ckpt_path) - load_param_into_net(classification_layer, trained_ckpt) - - loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") - - # Define loss and create model. - eval_dataset = create_cifar10dataset(dataset_path, batch_size=32, do_train=False) - eval_metrics = {'Loss': train.Loss(), - 'Top1-Acc': train.Top1CategoricalAccuracy(), - 'Top5-Acc': train.Top5CategoricalAccuracy()} - model = Model(network, loss_fn=loss, optimizer=None, metrics=eval_metrics) - metrics = model.eval(eval_dataset) - print("metric: ", metrics) - ``` diff --git a/docs/hub/docs/source_en/publish_model.md b/docs/hub/docs/source_en/publish_model.md deleted file mode 100644 index c21df51361421c16a53f56e0a2af2e2c73c52b68..0000000000000000000000000000000000000000 --- a/docs/hub/docs/source_en/publish_model.md +++ /dev/null @@ -1,74 +0,0 @@ -# Publishing Models Using MindSpore Hub - -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/hub/docs/source_en/publish_model.md) - -## Overview - -[MindSpore Hub](https://www.mindspore.cn/resources/hub/) is a platform for storing pre-trained models provided by MindSpore or third-party developers. 
It provides application developers with simple model loading and fine-tuning APIs, which enable users to perform inference or fine-tuning based on the pre-trained models and deploy them in their own applications. Users can also submit their own pre-trained models to MindSpore Hub by following the steps below, so that other users can download and use the published models.
-
-This tutorial uses GoogleNet as an example to show model developers who are interested in publishing models to MindSpore Hub how to submit them.
-
-## How to Publish Models
-
-You can publish models to MindSpore Hub by creating a PR in the [hub](https://gitee.com/mindspore/hub) repo. Here we use GoogleNet as an example to list the steps of model submission to MindSpore Hub.
-
-1. Host your pre-trained model in a storage location that we are able to access.
-
-2. Add a model generation Python file named `mindspore_hub_conf.py` to your own repo using this [template](https://gitee.com/mindspore/models/blob/master/research/cv/SE_ResNeXt50/mindspore_hub_conf.py). The location of the `mindspore_hub_conf.py` file is shown below:
-
-    ```text
-    googlenet
-    ├── src
-    │   ├── googlenet.py
-    ├── script
-    │   ├── run_train.sh
-    ├── train.py
-    ├── test.py
-    ├── mindspore_hub_conf.py
-    ```
-
-3. Create a `{model_name}_{dataset}.md` file in `hub/mshub_res/assets/mindspore/1.6` using this [template](https://gitee.com/mindspore/hub/blob/master/mshub_res/assets/mindspore/1.6/googlenet_cifar10.md#). Here `1.6` indicates the MindSpore version. The structure of the `hub/mshub_res` folder is as follows:
-
-    ```text
-    hub
-    ├── mshub_res
-    │   ├── assets
-    │   │   ├── mindspore
-    │   │       ├── 1.6
-    │   │           ├── googlenet_cifar10.md
-    │   ├── tools
-    │       ├── get_sha256.py
-    │       ├── load_markdown.py
-    │       └── md_validator.py
-    ```
-
-    Note that it is required to fill in the `{model_name}_{dataset}.md` template by providing the `file-format`, `asset-link` and `asset-sha256` fields below, which refer to the model file format, the model storage location from step 1, and the model hash value, respectively.
-
-    ```text
-    file-format: ckpt
-    asset-link: https://download.mindspore.cn/models/r1.6/googlenet_ascend_v160_cifar10_official_cv_acc92.53.ckpt
-    asset-sha256: b2f7fe14782a3ab88ad3534ed5f419b4bbc3b477706258bd6ed8f90f529775e7
-    ```
-
-    MindSpore Hub supports multiple model file formats, including:
-
-    - [MindSpore CKPT](https://www.mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model)
-    - [MindIR](https://www.mindspore.cn/tutorials/en/master/beginner/save_load.html)
-    - [AIR](https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.export.html)
-    - [ONNX](https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.export.html)
-
-    For each pre-trained model, run the following command to obtain the hash value required by the `asset-sha256` field of the `.md` file. Here the pre-trained model `googlenet.ckpt` is downloaded from the storage location in step 1 and saved in the `tools` folder. The output hash value is `b2f7fe14782a3ab88ad3534ed5f419b4bbc3b477706258bd6ed8f90f529775e7`.
-
-    ```bash
-    cd /hub/mshub_res/tools
-    python get_sha256.py --file ../googlenet.ckpt
-    ```
-
-4. Check the format of the markdown file locally using `hub/mshub_res/tools/md_validator.py` by running the following command. The output is `All Passed`, which indicates that the format and content of the `.md` file meet the requirements.
-
-    ```bash
-    python md_validator.py --check_path ../assets/mindspore/1.6/googlenet_cifar10.md
-    ```
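-
-    If you want to double-check the digest without the helper script, the value can also be computed with Python's standard `hashlib` module (a minimal sketch equivalent to what `get_sha256.py` produces, assuming the checkpoint sits in the same location as above):
-
-    ```python
-    import hashlib
-
-    # Compute the SHA-256 digest of a checkpoint file in chunks to limit memory use.
-    def sha256_of(path, chunk_size=1 << 20):
-        digest = hashlib.sha256()
-        with open(path, "rb") as f:
-            for chunk in iter(lambda: f.read(chunk_size), b""):
-                digest.update(chunk)
-        return digest.hexdigest()
-
-    print(sha256_of("../googlenet.ckpt"))  # should match the asset-sha256 field
-    ```
-
-5. 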
Create a PR in `mindspore/hub` repo. See our [Contributor Wiki](https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md#) for more information about creating a PR. - -Once your PR is merged into master branch here, your model will show up in [MindSpore Hub Website](https://www.mindspore.cn/resources/hub) within 24 hours. Please refer to [README](https://gitee.com/mindspore/hub/blob/master/mshub_res/README.md#) for more information about model submission. diff --git a/docs/hub/docs/source_zh_cn/conf.py b/docs/hub/docs/source_zh_cn/conf.py deleted file mode 100644 index b4697de34479918cd800e533f1ff3cde9817da2a..0000000000000000000000000000000000000000 --- a/docs/hub/docs/source_zh_cn/conf.py +++ /dev/null @@ -1,225 +0,0 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import os -import IPython -import re -import sys -from sphinx.ext import autodoc as sphinx_autodoc -import shutil - -import mindspore_hub - -# -- Project information ----------------------------------------------------- - -project = 'MindSpore' -copyright = 'MindSpore' -author = 'MindSpore' - -# The full version, including alpha/beta/rc tags -release = 'master' - - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -myst_enable_extensions = ["dollarmath", "amsmath"] - - -myst_heading_anchors = 5 -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'myst_parser', - 'sphinx.ext.mathjax', - 'IPython.sphinxext.ipython_console_highlighting' -] - -source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', -} - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -mathjax_path = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/mathjax/MathJax-3.2.2/es5/tex-mml-chtml.js' - -mathjax_options = { - 'async':'async' -} - -smartquotes_action = 'De' - -exclude_patterns = [] - -pygments_style = 'sphinx' - -autodoc_inherit_docstrings = False - -# -- Options for HTML output ------------------------------------------------- - -# Reconstruction of sphinx auto generated document translation. -language = 'zh_CN' -locale_dirs = ['../../../../resource/locale/'] -gettext_compact = False - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'sphinx_rtd_theme' - -html_search_language = 'zh' - -html_search_options = {'dict': '../../../resource/jieba.txt'} - -# Example configuration for intersphinx: refer to the Python standard library. 
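-# (Illustrative note, not in the original file: in each entry below, the first
-# element is the docs base URL that generated links point to, and the second is
-# a local objects.inv inventory that Sphinx reads instead of fetching it online.)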
-intersphinx_mapping = { - 'python': ('https://docs.python.org/', '../../../../resource/python_objects.inv'), - 'numpy': ('https://docs.scipy.org/doc/numpy/', '../../../../resource/numpy_objects.inv'), -} - -from sphinx import directives -with open('../_ext/overwriteobjectiondirective.txt', 'r', encoding="utf8") as f: - exec(f.read(), directives.__dict__) - -from sphinx.ext import viewcode -with open('../_ext/overwriteviewcode.txt', 'r', encoding="utf8") as f: - exec(f.read(), viewcode.__dict__) - -# Modify default signatures for autodoc. -autodoc_source_path = os.path.abspath(sphinx_autodoc.__file__) -autodoc_source_re = re.compile(r'stringify_signature\(.*?\)') -get_param_func_str = r"""\ -import re -import inspect as inspect_ - -def get_param_func(func): - try: - source_code = inspect_.getsource(func) - if func.__doc__: - source_code = source_code.replace(func.__doc__, '') - all_params_str = re.findall(r"def [\w_\d\-]+\(([\S\s]*?)(\):|\) ->.*?:)", source_code) - all_params = re.sub("(self|cls)(,|, )?", '', all_params_str[0][0].replace("\n", "").replace("'", "\"")) - return all_params - except: - return '' - -def get_obj(obj): - if isinstance(obj, type): - return obj.__init__ - - return obj -""" - -with open(autodoc_source_path, "r+", encoding="utf8") as f: - code_str = f.read() - code_str = autodoc_source_re.sub('"(" + get_param_func(get_obj(self.object)) + ")"', code_str, count=0) - exec(get_param_func_str, sphinx_autodoc.__dict__) - exec(code_str, sphinx_autodoc.__dict__) - -# Copy source files of chinese python api from hub repository. -from sphinx.util import logging -logger = logging.getLogger(__name__) - -copy_path = 'docs/api_python' -src_dir = os.path.join(os.getenv("HB_PATH"), copy_path) - -copy_list = [] - -present_path = os.path.dirname(__file__) - -for i in os.listdir(src_dir): - if os.path.isfile(os.path.join(src_dir,i)): - if os.path.exists('./'+i): - os.remove('./'+i) - shutil.copy(os.path.join(src_dir,i),'./'+i) - copy_list.append(os.path.join(present_path,i)) - else: - if os.path.exists('./'+i): - shutil.rmtree('./'+i) - shutil.copytree(os.path.join(src_dir,i),'./'+i) - copy_list.append(os.path.join(present_path,i)) - -# add view -import json - -if os.path.exists('../../../../tools/generate_html/version.json'): - with open('../../../../tools/generate_html/version.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) -elif os.path.exists('../../../../tools/generate_html/daily_dev.json'): - with open('../../../../tools/generate_html/daily_dev.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) -elif os.path.exists('../../../../tools/generate_html/daily.json'): - with open('../../../../tools/generate_html/daily.json', 'r+', encoding='utf-8') as f: - version_inf = json.load(f) - -if os.getenv("HB_PATH").split('/')[-1]: - copy_repo = os.getenv("HB_PATH").split('/')[-1] -else: - copy_repo = os.getenv("HB_PATH").split('/')[-2] - -branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if version_inf[i]['name'] == copy_repo][0] -docs_branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if version_inf[i]['name'] == 'tutorials'][0] - -re_view = f"\n.. 
image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/{docs_branch}/" + \ - f"resource/_static/logo_source.svg\n :target: https://gitee.com/mindspore/{copy_repo}/blob/{branch}/" - -for cur, _, files in os.walk(present_path): - for i in files: - flag_copy = 0 - if i.endswith('.rst'): - for j in copy_list: - if j in cur: - flag_copy = 1 - break - if os.path.join(cur, i) in copy_list or flag_copy: - try: - with open(os.path.join(cur, i), 'r+', encoding='utf-8') as f: - content = f.read() - new_content = content - if '.. include::' in content and '.. automodule::' in content: - continue - if 'autosummary::' not in content and "\n=====" in content: - re_view_ = re_view + copy_path + cur.split(present_path)[-1] + '/' + i + \ - '\n :alt: 查看源文件\n\n' - new_content = re.sub('([=]{5,})\n', r'\1\n' + re_view_, content, 1) - if new_content != content: - f.seek(0) - f.truncate() - f.write(new_content) - except Exception: - print(f'打开{i}文件失败') - - -sys.path.append(os.path.abspath('../../../../resource/sphinx_ext')) -# import anchor_mod -import nbsphinx_mod - -sys.path.append(os.path.abspath('../../../../resource/search')) -import search_code - -sys.path.append(os.path.abspath('../../../../resource/custom_directives')) -from custom_directives import IncludeCodeDirective - -def setup(app): - app.add_directive('includecode', IncludeCodeDirective) diff --git a/docs/hub/docs/source_zh_cn/hub_installation.md b/docs/hub/docs/source_zh_cn/hub_installation.md deleted file mode 100644 index c2641d7577eb939cec863dc331cbdda598a8f467..0000000000000000000000000000000000000000 --- a/docs/hub/docs/source_zh_cn/hub_installation.md +++ /dev/null @@ -1,114 +0,0 @@ -# 安装MindSpore Hub - -- [安装MindSpore Hub](#安装mindspore-hub) - - [确认系统环境信息](#确认系统环境信息) - - [安装方式](#安装方式) - - [pip安装](#pip安装) - - [源码安装](#源码安装) - - [验证是否成功安装](#验证是否成功安装) - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/hub/docs/source_zh_cn/hub_installation.md) - -## 确认系统环境信息 - -- 硬件平台支持Ascend、GPU和CPU。 -- 确认安装[Python](https://www.python.org/ftp/python/3.7.5/Python-3.7.5.tgz) 3.7.5版本。 -- MindSpore Hub与MindSpore的版本需保持一致。 -- MindSpore Hub支持使用x86 64位或ARM 64位架构的Linux发行版系统。 -- 在联网状态下,安装whl包时会自动下载`setup.py`中的依赖项,其余情况需自行安装。 - -## 安装方式 - -可以采用pip安装或者源码安装两种方式。 - -### pip安装 - -下载并安装[发布版本列表](https://www.mindspore.cn/versions)中的MindSpore Hub whl包。 - -```shell -pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/{version}/Hub/any/mindspore_hub-{version}-py3-none-any.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple -``` - -> - `{version}`表示MindSpore Hub版本号,例如下载1.3.0版本MindSpore Hub时,`{version}`应写为1.3.0。 - -### 源码安装 - -1. 从Gitee下载源码。 - - ```bash - git clone https://gitee.com/mindspore/hub.git - ``` - -2. 编译安装MindSpore Hub。 - - ```bash - cd hub - python setup.py install - ``` - -## 验证是否成功安装 - -在能联网的环境中执行以下命令,验证安装结果。 - -```python -import mindspore_hub as mshub - -model = mshub.load("mindspore/1.6/lenet_mnist", num_class=10) -``` - -如果出现下列提示,说明安装成功: - -```text -Downloading data from url https://gitee.com/mindspore/hub/raw/master/mshub_res/assets/mindspore/1.6/lenet_mnist.md - -Download finished! -File size = 0.00 Mb -Checking /home/ma-user/.mscache/mindspore/1.6/lenet_mnist.md...Passed! 
-``` - -## FAQ - -**Q: 遇到`SSL: CERTIFICATE_VERIFY_FAILED`怎么办?** - -A: 由于你的网络环境,例如你使用代理连接互联网,往往会由于证书配置问题导致python出现ssl verification failed的问题,此时有两种解决方法: - -- 配置好SSL证书 **(推荐)** -- 在加载mindspore_hub前增加如下代码进行解决(最快) - - ```python - import ssl - ssl._create_default_https_context = ssl._create_unverified_context - - import mindspore_hub as mshub - model = mshub.load("mindspore/1.6/lenet_mnist", num_classes=10) - ``` - -**Q: 遇到`No module named src.*`怎么办?** - -A: 同一进程中使用load接口加载不同的模型,由于每次加载模型需要将模型文件目录插入到环境变量中,经测试发现:Python只会去最开始插入的目录下查找src.*,尽管你将最开始插入的目录删除,Python还是会去这个目录下查找。解决办法:不添加环境变量,将模型目录下的所有文件都复制到当前工作目录下。代码如下: - -```python -# mindspore_hub_install_path/load.py -def _copy_all_file_to_target_path(path, target_path): - if not os.path.exists(target_path): - os.makedirs(target_path) - path = os.path.realpath(path) - target_path = os.path.realpath(target_path) - for p in os.listdir(path): - copy_path = os.path.join(path, p) - target_dir = os.path.join(target_path, p) - _delete_if_exist(target_dir) - if os.path.isdir(copy_path): - _copy_all_file_to_target_path(copy_path, target_dir) - else: - shutil.copy(copy_path, target_dir) - -def _get_network_from_cache(name, path, *args, **kwargs): - _copy_all_file_to_target_path(path, os.getcwd()) - config_path = os.path.join(os.getcwd(), HUB_CONFIG_FILE) - if not os.path.exists(config_path): - raise ValueError('{} not exists.'.format(config_path)) - ...... -``` - -**注意**:在load后一个模型时可能会将前一个模型的一些文件替换掉,但是模型训练需保证必要模型文件存在,你必须在加载新模型之前完成对前一个模型的训练。 diff --git a/docs/hub/docs/source_zh_cn/index.rst b/docs/hub/docs/source_zh_cn/index.rst deleted file mode 100644 index e708a5a2608d387148492c75c1ec37afe2b06319..0000000000000000000000000000000000000000 --- a/docs/hub/docs/source_zh_cn/index.rst +++ /dev/null @@ -1,71 +0,0 @@ -MindSpore Hub 文档 -========================= - -MindSpore Hub是MindSpore生态的预训练模型应用工具。 - -MindSpore Hub包含以下功能: - -- 即插即用的模型加载 -- 简单易用的迁移学习 - -.. code-block:: - - import mindspore - import mindspore_hub as mshub - from mindspore import set_context, GRAPH_MODE - - set_context(mode=GRAPH_MODE, - device_target="Ascend", - device_id=0) - - model = "mindspore/1.6/googlenet_cifar10" - - # Initialize the number of classes based on the pre-trained model. - network = mshub.load(model, num_classes=10) - network.set_train(False) - - # ... - -代码仓地址: - -使用MindSpore Hub的典型场景 ----------------------------- - -1. `推理验证 `_ - - mindspore_hub.load用于加载预训练模型,可以实现一行代码完成模型的加载。 - -2. `迁移学习 `_ - - 通过mindspore_hub.load完成模型加载后,可以增加一个额外的参数项只加载神经网络的特征提取部分,这样就能很容易地在之后增加一些新的层进行迁移学习。 - -3. `发布模型 `_ - - 可以将自己训练好的模型按照指定的步骤发布到MindSpore Hub中,以供其他用户进行下载和使用。 - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: 安装部署 - - hub_installation - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: 使用指南 - - loading_model_from_hub - publish_model - -.. toctree:: - :maxdepth: 1 - :caption: API参考 - - hub - -.. 
toctree:: - :maxdepth: 1 - :caption: 模型 - - MindSpore Hub↗ diff --git a/docs/hub/docs/source_zh_cn/loading_model_from_hub.md b/docs/hub/docs/source_zh_cn/loading_model_from_hub.md deleted file mode 100644 index 2d990e01e863cae3644927fb906792fab34eefae..0000000000000000000000000000000000000000 --- a/docs/hub/docs/source_zh_cn/loading_model_from_hub.md +++ /dev/null @@ -1,212 +0,0 @@ -# 从Hub加载模型 - -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/hub/docs/source_zh_cn/loading_model_from_hub.md) - -## 概述 - -对于个人开发者来说,从零开始训练一个较好模型,需要大量的标注完备的数据、足够的计算资源和大量训练调试时间。使得模型训练非常消耗资源,提升了AI开发的门槛,针对以上问题,MindSpore Hub提供了很多训练完成的模型权重文件,可以使得开发者在拥有少量数据的情况下,只需要花费少量训练时间,即可快速训练出一个较好的模型。 - -本文档从推理验证和迁移学习两种用途,展示使用MindSpore Hub提供的模型,用少量数据快速完成训练得到较好的模型。 - -## 用于推理验证 - -`mindspore_hub.load` API用于加载预训练模型,可以实现一行代码完成模型的加载。主要的模型加载流程如下: - -1. 在[MindSpore Hub官网](https://www.mindspore.cn/resources/hub)上搜索感兴趣的模型。 - - 例如,想使用GoogleNet对CIFAR-10数据集进行分类,可以在[MindSpore Hub官网](https://www.mindspore.cn/resources/hub)上使用关键词`GoogleNet`进行搜索。页面将会返回与GoogleNet相关的所有模型。进入相关模型页面之后,查看`Usage`。**注意**:如果页面没有`Usage`表示当前模型暂不支持使用MindSpore Hub加载。 - -2. 根据`Usage`完成模型的加载,示例代码如下: - - ```python - import mindspore_hub as mshub - import mindspore - from mindspore import Tensor, nn, Model, set_context, GRAPH_MODE - from mindspore import dtype as mstype - import mindspore.dataset.vision as vision - - set_context(mode=GRAPH_MODE, - device_target="Ascend", - device_id=0) - - model = "mindspore/1.6/googlenet_cifar10" - - # Initialize the number of classes based on the pre-trained model. - network = mshub.load(model, num_classes=10) - network.set_train(False) - - # ... - - ``` - -3. 完成模型加载后,可以使用MindSpore进行推理,参考[推理模型总览](https://www.mindspore.cn/tutorials/zh-CN/master/model_infer/ms_infer/llm_inference_overview.html)。 - -## 用于迁移学习 - -通过`mindspore_hub.load`完成模型加载后,可以增加一个额外的参数项只加载神经网络的特征提取部分,这样我们就能很容易地在之后增加一些新的层进行迁移学习。当模型开发者将额外的参数(例如 `include_top`)添加到模型构造中时,可以在模型的详情页中找到这个功能。`include_top`取值为True或者False,表示是否保留顶层的全连接网络。* - -下面我们以[MobileNetV2](https://gitee.com/mindspore/models/tree/master/research/cv/centerface)为例,说明如何加载一个基于OpenImage的预训练模型,并在特定的子任务数据集上进行迁移学习(重训练)。主要的步骤如下: - -1. 在[MindSpore Hub官网](https://www.mindspore.cn/resources/hub/)上搜索感兴趣的模型,查看对应的`Usage`。 - -2. 根据`Usage`进行MindSpore Hub模型的加载,注意:`include_top`参数需要模型开发者提供。 - - ```python - import os - import mindspore_hub as mshub - import mindspore - from mindspore import Tensor, nn, set_context, GRAPH_MODE, train - from mindspore.nn import Momentum - from mindspore import save_checkpoint, load_checkpoint,load_param_into_net - from mindspore import ops - import mindspore.dataset as ds - import mindspore.dataset.transforms as transforms - import mindspore.dataset.vision as vision - from mindspore import dtype as mstype - from mindspore import Model - set_context(mode=GRAPH_MODE, device_target="Ascend", device_id=0) - - model = "mindspore/1.6/mobilenetv2_imagenet2012" - network = mshub.load(model, num_classes=500, include_top=False, activation="Sigmoid") - network.set_train(False) - ``` - -3. 在现有模型结构基础上,增加一个与新任务相关的分类层。 - - ```python - class ReduceMeanFlatten(nn.Cell): - def __init__(self): - super(ReduceMeanFlatten, self).__init__() - self.mean = ops.ReduceMean(keep_dims=True) - self.flatten = nn.Flatten() - - def construct(self, x): - x = self.mean(x, (2, 3)) - x = self.flatten(x) - return x - - # Check MindSpore Hub website to conclude that the last output shape is 1280. 
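-    # (Illustrative note, not in the original: ops.ReduceMean over axes (2, 3)
-    # turns the backbone's (N, 1280, H, W) feature map into (N, 1280, 1, 1), and
-    # nn.Flatten then yields (N, 1280), which matches last_channel below.)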
- last_channel = 1280 - - # The number of classes in target task is 10. - num_classes = 10 - - reducemean_flatten = ReduceMeanFlatten() - - classification_layer = nn.Dense(last_channel, num_classes) - classification_layer.set_train(True) - - train_network = nn.SequentialCell([network, reducemean_flatten, classification_layer]) - ``` - -4. 定义数据集加载函数。 - - 如下所示,进行微调任务的数据集为[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html),注意此处需要下载二进制版本(`binary version`)的数据。下载解压后可以通过如下所示代码加载和处理数据。`dataset_path`是数据集的保存路径,由用户给定。 - - ```python - def create_cifar10dataset(dataset_path, batch_size, usage='train', shuffle=True): - data_set = ds.Cifar10Dataset(dataset_dir=dataset_path, usage=usage, shuffle=shuffle) - - # define map operations - trans = [ - vision.Resize((256, 256)), - vision.RandomHorizontalFlip(prob=0.5), - vision.Rescale(1.0 / 255.0, 0.0), - vision.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), - vision.HWC2CHW() - ] - - type_cast_op = transforms.TypeCast(mstype.int32) - - data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8) - data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8) - - # apply batch operations - data_set = data_set.batch(batch_size, drop_remainder=True) - return data_set - - # Create Dataset - dataset_path = "/path_to_dataset/cifar-10-batches-bin" - dataset = create_cifar10dataset(dataset_path, batch_size=32, usage='train', shuffle=True) - ``` - -5. 为模型训练选择损失函数、优化器和学习率。 - - ```python - def generate_steps_lr(lr_init, steps_per_epoch, total_epochs): - total_steps = total_epochs * steps_per_epoch - decay_epoch_index = [0.3*total_steps, 0.6*total_steps, 0.8*total_steps] - lr_each_step = [] - for i in range(total_steps): - if i < decay_epoch_index[0]: - lr = lr_init - elif i < decay_epoch_index[1]: - lr = lr_init * 0.1 - elif i < decay_epoch_index[2]: - lr = lr_init * 0.01 - else: - lr = lr_init * 0.001 - lr_each_step.append(lr) - return lr_each_step - - # Set epoch size - epoch_size = 60 - - # Wrap the backbone network with loss. - loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") - loss_net = nn.WithLossCell(train_network, loss_fn) - steps_per_epoch = dataset.get_dataset_size() - lr = generate_steps_lr(lr_init=0.01, steps_per_epoch=steps_per_epoch, total_epochs=epoch_size) - - # Create an optimizer. - optim = Momentum(filter(lambda x: x.requires_grad, classification_layer.get_parameters()), Tensor(lr, mindspore.float32), 0.9, 4e-5) - train_net = nn.TrainOneStepCell(loss_net, optim) - ``` - -6. 开始重训练。 - - ```python - for epoch in range(epoch_size): - for i, items in enumerate(dataset): - data, label = items - data = mindspore.Tensor(data) - label = mindspore.Tensor(label) - - loss = train_net(data, label) - print(f"epoch: {epoch}/{epoch_size}, loss: {loss}") - # Save the ckpt file for each epoch. - if not os.path.exists('ckpt'): - os.mkdir('ckpt') - ckpt_path = f"./ckpt/cifar10_finetune_epoch{epoch}.ckpt" - save_checkpoint(train_network, ckpt_path) - ``` - -7. 在测试集上测试模型精度。 - - ```python - model = "mindspore/1.6/mobilenetv2_imagenet2012" - - network = mshub.load(model, num_classes=500, pretrained=True, include_top=False, activation="Sigmoid") - network.set_train(False) - reducemean_flatten = ReduceMeanFlatten() - classification_layer = nn.Dense(last_channel, num_classes) - classification_layer.set_train(False) - softmax = nn.Softmax() - network = nn.SequentialCell([network, reducemean_flatten, classification_layer, softmax]) - - # Load a pre-trained ckpt file. 
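-    # (Illustrative note, not in the original: the checkpoint saved during
-    # fine-tuning contains the whole train_network, but only the classification
-    # layer's parameters need to be restored here, because the frozen backbone
-    # is re-created with pretrained=True.)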
-    ckpt_path = "./ckpt/cifar10_finetune_epoch59.ckpt"
-    trained_ckpt = load_checkpoint(ckpt_path)
-    load_param_into_net(classification_layer, trained_ckpt)
-
-    # Define the loss, then create the evaluation dataset and model.
-    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
-
-    eval_dataset = create_cifar10dataset(dataset_path, batch_size=32, usage='test', shuffle=False)
-    eval_metrics = {'Loss': train.Loss(),
-                    'Top1-Acc': train.Top1CategoricalAccuracy(),
-                    'Top5-Acc': train.Top5CategoricalAccuracy()}
-    model = Model(network, loss_fn=loss, optimizer=None, metrics=eval_metrics)
-    metrics = model.eval(eval_dataset)
-    print("metric: ", metrics)
-    ```
diff --git a/docs/hub/docs/source_zh_cn/publish_model.md b/docs/hub/docs/source_zh_cn/publish_model.md
deleted file mode 100644
index 37ffafbe9ac93448884d0a0206416ad9f5a825f9..0000000000000000000000000000000000000000
--- a/docs/hub/docs/source_zh_cn/publish_model.md
+++ /dev/null
@@ -1,74 +0,0 @@
-# 发布模型
-
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/hub/docs/source_zh_cn/publish_model.md)
-
-## 概述
-
-[MindSpore Hub](https://www.mindspore.cn/resources/hub/)是存放MindSpore官方或者第三方开发者提供的预训练模型的平台。它向应用开发者提供了简单易用的模型加载和微调APIs,使得用户可以基于预训练模型进行推理或者微调,并部署到自己的应用中。用户也可以将自己训练好的模型按照指定的步骤发布到MindSpore Hub中,以供其他用户进行下载和使用。
-
-本教程以GoogleNet为例,对想要将模型发布到MindSpore Hub的模型开发者介绍了模型上传步骤。
-
-## 发布模型到MindSpore Hub
-
-用户可通过向[hub](https://gitee.com/mindspore/hub)仓提交PR的方式向MindSpore Hub发布模型。这里我们以GoogleNet为例,列出模型提交到MindSpore Hub的步骤。
-
-1. 将你的预训练模型托管在可以访问的存储位置。
-
-2. 参照[模板](https://gitee.com/mindspore/models/blob/master/research/cv/SE_ResNeXt50/mindspore_hub_conf.py),在你自己的代码仓中添加模型生成文件`mindspore_hub_conf.py`,文件放置的位置如下:
-
-    ```text
-    googlenet
-    ├── src
-    │   ├── googlenet.py
-    ├── script
-    │   ├── run_train.sh
-    ├── train.py
-    ├── test.py
-    ├── mindspore_hub_conf.py
-    ```
-
-3. 参照[模板](https://gitee.com/mindspore/hub/blob/master/mshub_res/assets/mindspore/1.6/googlenet_cifar10.md#),在`hub/mshub_res/assets/mindspore/1.6`文件夹下创建`{model_name}_{dataset}.md`文件,其中`1.6`为MindSpore的版本号,`hub/mshub_res`的目录结构为:
-
-    ```text
-    hub
-    ├── mshub_res
-    │   ├── assets
-    │   │   ├── mindspore
-    │   │       ├── 1.6
-    │   │           ├── googlenet_cifar10.md
-    │   ├── tools
-    │       ├── get_sha256.py
-    │       ├── load_markdown.py
-    │       └── md_validator.py
-    ```
-
-    注意,`{model_name}_{dataset}.md`文件中需要补充如下所示的`file-format`、`asset-link` 和 `asset-sha256`信息,它们分别表示模型文件格式、模型存储位置(步骤1所得)和模型哈希值。
-
-    ```text
-    file-format: ckpt
-    asset-link: https://download.mindspore.cn/models/r1.6/googlenet_ascend_v160_cifar10_official_cv_acc92.53.ckpt
-    asset-sha256: b2f7fe14782a3ab88ad3534ed5f419b4bbc3b477706258bd6ed8f90f529775e7
-    ```
-
-    其中,MindSpore Hub支持的模型文件格式有:
-
-    - [MindSpore CKPT](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/save_load.html#保存与加载)
-    - [MINDIR](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/save_load.html#保存和加载mindir)
-    - [AIR](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.export.html#mindspore.export)
-    - [ONNX](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.export.html#mindspore.export)
-
-    对于每个预训练模型,执行以下命令,用来获得`.md`文件`asset-sha256`处所需的哈希值,其中`googlenet.ckpt`是从步骤1的存储位置处下载并保存到`tools`文件夹的预训练模型,运行后输出的哈希值为`b2f7fe14782a3ab88ad3534ed5f419b4bbc3b477706258bd6ed8f90f529775e7`。
-
-    ```bash
-    cd /hub/mshub_res/tools
-    python get_sha256.py --file ../googlenet.ckpt
-    ```
-
-4. 
使用`hub/mshub_res/tools/md_validator.py`在本地核对`.md`文件的格式,执行以下命令,输出结果为`All Passed`,表示`.md`文件的格式和内容均符合要求。 - - ```bash - python md_validator.py --check_path ../assets/mindspore/1.6/googlenet_cifar10.md - ``` - -5. 在`mindspore/hub`仓创建PR,详细创建方式可以参考[贡献者Wiki](https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md#)。 - -一旦你的PR合入到`mindspore/hub`的master分支,你的模型将于24小时内在[MindSpore Hub 网站](https://www.mindspore.cn/resources/hub)上显示。有关模型上传的更多详细信息,请参考[README](https://gitee.com/mindspore/hub/blob/master/mshub_res/README.md#)。 diff --git a/docs/lite/api/_custom/sphinx_builder_html b/docs/lite/api/_custom/sphinx_builder_html index 453a52fea86bbe95cd49a2090bc25b2db53c3901..f288a881a6618d821a9cdf6a1b678e7d3d7de21c 100644 --- a/docs/lite/api/_custom/sphinx_builder_html +++ b/docs/lite/api/_custom/sphinx_builder_html @@ -1116,7 +1116,7 @@ class StandaloneHTMLBuilder(Builder): # Add links to the Python operator interface. if "mindspore.ops." in output: - output = re.sub(r'(mindspore\.ops\.\w+) ', r'\1 ', output, count=0) + output = re.sub(r'(mindspore\.ops\.\w+) ', r'\1 ', output, count=0) except UnicodeError: logger.warning(__("a Unicode error occurred when rendering the page %s. " diff --git a/docs/lite/api/source_en/api_c/lite_c_example.rst b/docs/lite/api/source_en/api_c/lite_c_example.rst index c4588f1ec970840d314f8e45510ff1786f997a4d..31bc6ee3e9f9ac475cd1de07fe316bf18fbe8d74 100644 --- a/docs/lite/api/source_en/api_c/lite_c_example.rst +++ b/docs/lite/api/source_en/api_c/lite_c_example.rst @@ -4,4 +4,4 @@ Example .. toctree:: :maxdepth: 1 - Simple Demo↗ + Simple Demo↗ diff --git a/docs/lite/api/source_en/api_cpp/lite_cpp_example.rst b/docs/lite/api/source_en/api_cpp/lite_cpp_example.rst index 41711025f8bb1b2b6ac69fe714c5f4a3c7612e20..224e43a200bdb29251d0dc2b9c5f6a62e098212d 100644 --- a/docs/lite/api/source_en/api_cpp/lite_cpp_example.rst +++ b/docs/lite/api/source_en/api_cpp/lite_cpp_example.rst @@ -4,6 +4,6 @@ Example .. toctree:: :maxdepth: 1 - Simple Demo↗ - Android Application Development Based on JNI Interface↗ - High-level Usage↗ \ No newline at end of file + Simple Demo↗ + Android Application Development Based on JNI Interface↗ + High-level Usage↗ \ No newline at end of file diff --git a/docs/lite/api/source_en/api_java/class_list.md b/docs/lite/api/source_en/api_java/class_list.md index e70109e14f9671b5f39663731c5f3d7d2e411aba..d84787f0134b9be945b9060f2b32c017475b787d 100644 --- a/docs/lite/api/source_en/api_java/class_list.md +++ b/docs/lite/api/source_en/api_java/class_list.md @@ -1,17 +1,17 @@ # Class List -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_en/api_java/class_list.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_en/api_java/class_list.md) | Package | Class Name | Description | Supported At Cloud-side Inference | Supported At Device-side Inference | | ------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |--------|--------| -| com.mindspore | [Model](https://www.mindspore.cn/lite/api/en/master/api_java/model.html) | Model defines model in MindSpore for compiling and running compute graph. 
| √ | √ | -| com.mindspore.config | [MSContext](https://www.mindspore.cn/lite/api/en/master/api_java/mscontext.html) | MSContext defines for holding environment variables during runtime. | √ | √ | -| com.mindspore | [MSTensor](https://www.mindspore.cn/lite/api/en/master/api_java/mstensor.html) | MSTensor defines the tensor in MindSpore. | √ | √ | -| com.mindspore | [ModelParallelRunner](https://www.mindspore.cn/lite/api/en/master/api_java/model_parallel_runner.html) | Defines MindSpore Lite concurrent inference. | √ | ✕ | -| com.mindspore.config | [RunnerConfig](https://www.mindspore.cn/lite/api/en/master/api_java/runner_config.html) | RunnerConfig defines configuration parameters for concurrent inference. | √ | ✕ | -| com.mindspore | [Graph](https://www.mindspore.cn/lite/api/en/master/api_java/graph.html) | Graph defines the compute graph in MindSpore. | ✕ | √ | -| com.mindspore.config | [CpuBindMode](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java) | CpuBindMode defines the CPU binding mode. | √ | √ | -| com.mindspore.config | [DeviceType](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java) | DeviceType defines the back-end device type. | √ | √ | -| com.mindspore.config | [DataType](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/DataType.java) | DataType defines the supported data types. | √ | √ | -| com.mindspore.config | [Version](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/Version.java) | Version is used to obtain the version information of MindSpore. | ✕ | √ | -| com.mindspore.config | [ModelType](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/ModelType.java) | ModelType defines the model file type. | √ | √ | +| com.mindspore | [Model](https://www.mindspore.cn/lite/api/en/r2.6.0/api_java/model.html) | Model defines model in MindSpore for compiling and running compute graph. | √ | √ | +| com.mindspore.config | [MSContext](https://www.mindspore.cn/lite/api/en/r2.6.0/api_java/mscontext.html) | MSContext defines for holding environment variables during runtime. | √ | √ | +| com.mindspore | [MSTensor](https://www.mindspore.cn/lite/api/en/r2.6.0/api_java/mstensor.html) | MSTensor defines the tensor in MindSpore. | √ | √ | +| com.mindspore | [ModelParallelRunner](https://www.mindspore.cn/lite/api/en/r2.6.0/api_java/model_parallel_runner.html) | Defines MindSpore Lite concurrent inference. | √ | ✕ | +| com.mindspore.config | [RunnerConfig](https://www.mindspore.cn/lite/api/en/r2.6.0/api_java/runner_config.html) | RunnerConfig defines configuration parameters for concurrent inference. | √ | ✕ | +| com.mindspore | [Graph](https://www.mindspore.cn/lite/api/en/r2.6.0/api_java/graph.html) | Graph defines the compute graph in MindSpore. | ✕ | √ | +| com.mindspore.config | [CpuBindMode](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java) | CpuBindMode defines the CPU binding mode. | √ | √ | +| com.mindspore.config | [DeviceType](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java) | DeviceType defines the back-end device type. 
| √ | √ | +| com.mindspore.config | [DataType](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/DataType.java) | DataType defines the supported data types. | √ | √ | +| com.mindspore.config | [Version](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/Version.java) | Version is used to obtain the version information of MindSpore. | ✕ | √ | +| com.mindspore.config | [ModelType](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/ModelType.java) | ModelType defines the model file type. | √ | √ | diff --git a/docs/lite/api/source_en/api_java/graph.md b/docs/lite/api/source_en/api_java/graph.md index a5892222adf05943482d7afdf0b5f88e0409c564..f70ceda5a5c55c58c396648d465fba22bf0689f9 100644 --- a/docs/lite/api/source_en/api_java/graph.md +++ b/docs/lite/api/source_en/api_java/graph.md @@ -1,6 +1,6 @@ # Graph -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_en/api_java/graph.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_en/api_java/graph.md) ```java import com.mindspore.Graph; diff --git a/docs/lite/api/source_en/api_java/lite_java_example.rst b/docs/lite/api/source_en/api_java/lite_java_example.rst index 01f76f0495b7394007f45abde2213d365095ba6e..56e3f023560d97108f542b911b983953033cb3e3 100644 --- a/docs/lite/api/source_en/api_java/lite_java_example.rst +++ b/docs/lite/api/source_en/api_java/lite_java_example.rst @@ -4,6 +4,6 @@ Example .. 
toctree:: :maxdepth: 1 - Simple Demo↗ - Android Application Development Based on Java Interface↗ - High-level Usage↗ \ No newline at end of file + Simple Demo↗ + Android Application Development Based on Java Interface↗ + High-level Usage↗ \ No newline at end of file diff --git a/docs/lite/api/source_en/api_java/model.md b/docs/lite/api/source_en/api_java/model.md index ad5be2106b19f51651abc7da406d71acfca05337..376d2fce78a06c18f620af0cba0117a6b92527e0 100644 --- a/docs/lite/api/source_en/api_java/model.md +++ b/docs/lite/api/source_en/api_java/model.md @@ -1,6 +1,6 @@ # Model -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_en/api_java/model.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_en/api_java/model.md) ```java import com.mindspore.model; diff --git a/docs/lite/api/source_en/api_java/model_parallel_runner.md b/docs/lite/api/source_en/api_java/model_parallel_runner.md index 1525eaf8bf9f70b5e0c30e9a3cfcf12a53d22987..451d2d7fc439501753d00adae155f6201303c42c 100644 --- a/docs/lite/api/source_en/api_java/model_parallel_runner.md +++ b/docs/lite/api/source_en/api_java/model_parallel_runner.md @@ -1,6 +1,6 @@ # ModelParallelRunner -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_en/api_java/model_parallel_runner.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_en/api_java/model_parallel_runner.md) ```java import com.mindspore.config.RunnerConfig; diff --git a/docs/lite/api/source_en/api_java/mscontext.md b/docs/lite/api/source_en/api_java/mscontext.md index 950f839859aaea7d0febc3bd569bdd932d023443..3f732710deaccc37f645e39273e19c03bab72b21 100644 --- a/docs/lite/api/source_en/api_java/mscontext.md +++ b/docs/lite/api/source_en/api_java/mscontext.md @@ -1,6 +1,6 @@ # MSContext -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_en/api_java/mscontext.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_en/api_java/mscontext.md) ```java import com.mindspore.config.MSContext; @@ -53,7 +53,7 @@ Initialize MSContext for cpu. - Parameters - `threadNum`: Thread number config for thread pool. - - `cpuBindMode`: A **[CpuBindMode](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java)** **enum** variable. + - `cpuBindMode`: A **[CpuBindMode](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java)** **enum** variable. - Returns @@ -68,7 +68,7 @@ Initialize MSContext. - Parameters - `threadNum`: Thread number config for thread pool. 
- - `cpuBindMode`: A **[CpuBindMode](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java)** **enum** variable. + - `cpuBindMode`: A **[CpuBindMode](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java)** **enum** variable. - `isEnableParallel`: Is enable parallel in different device. - Returns @@ -85,7 +85,7 @@ Add device info for mscontext. - Parameters - - `deviceType`: A **[DeviceType](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java)** **enum** type. + - `deviceType`: A **[DeviceType](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java)** **enum** type. - `isEnableFloat16`: Is enable fp16. - Returns @@ -100,7 +100,7 @@ Add device info for mscontext. - Parameters - - `deviceType`: A **[DeviceType](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java)** **enum** type. + - `deviceType`: A **[DeviceType](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java)** **enum** type. - `isEnableFloat16`: is enable fp16. - `npuFreq`: Npu frequency. diff --git a/docs/lite/api/source_en/api_java/mstensor.md b/docs/lite/api/source_en/api_java/mstensor.md index 5290aa66d9db6b50bb1dd49c723bcf0d6d51538d..878eea71cdfc2294ea1979b7cfb07a5bf92cc596 100644 --- a/docs/lite/api/source_en/api_java/mstensor.md +++ b/docs/lite/api/source_en/api_java/mstensor.md @@ -1,6 +1,6 @@ # MSTensor -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_en/api_java/mstensor.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_en/api_java/mstensor.md) ```java import com.mindspore.MSTensor; @@ -83,7 +83,7 @@ Get the shape of the MindSpore MSTensor. public int getDataType() ``` -DataType is defined in [com.mindspore.DataType](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/DataType.java). +DataType is defined in [com.mindspore.DataType](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/DataType.java). 
- Returns diff --git a/docs/lite/api/source_en/api_java/runner_config.md b/docs/lite/api/source_en/api_java/runner_config.md index 8e332ac16da7aae3a03f59dbe6afc1c3c4f7ce85..3d52e62b7254a28d14f1b92cbfd3cc54115d051b 100644 --- a/docs/lite/api/source_en/api_java/runner_config.md +++ b/docs/lite/api/source_en/api_java/runner_config.md @@ -1,6 +1,6 @@ # RunnerConfig -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_en/api_java/runner_config.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_en/api_java/runner_config.md) RunnerConfig defines the configuration parameters of MindSpore Lite concurrent inference. diff --git a/docs/lite/api/source_en/conf.py b/docs/lite/api/source_en/conf.py index 408e41a12c0d7928efab0d519b43b62841af0b91..de98e4c2159611f05dc277c3ab9ca594a1729068 100644 --- a/docs/lite/api/source_en/conf.py +++ b/docs/lite/api/source_en/conf.py @@ -26,12 +26,12 @@ from exhale import graph as exh_graph # -- Project information ----------------------------------------------------- -project = 'MindSpore' +project = 'MindSpore Lite' copyright = 'MindSpore' author = 'MindSpore' # The full version, including alpha/beta/rc tags -release = 'master' +release = '2.6.0' # -- General configuration --------------------------------------------------- @@ -322,6 +322,43 @@ try: except: pass +# modify urls +re_url = r"(((gitee.com/mindspore/docs)|(github.com/mindspore-ai/(mindspore|docs))|" + \ + r"(mindspore.cn/(docs|tutorials|lite))|(obs.dualstack.cn-north-4.myhuaweicloud)|" + \ + r"(mindspore-website.obs.cn-north-4.myhuaweicloud))[\w\d/_.-]*?)/(master)" + +re_url2 = r"(gitee.com/mindspore/mindspore[\w\d/_.-]*?)/(master)" + +re_url3 = r"(((gitee.com/mindspore/golden-stick)|(mindspore.cn/golden_stick))[\w\d/_.-]*?)/(master)" + +re_url4 = r"(((gitee.com/mindspore/mindformers)|(mindspore.cn/mindformers))[\w\d/_.-]*?)/(dev)" + +with open(os.path.join('./mindspore_lite.rst'), 'r+', encoding='utf-8') as f: + content = f.read() + new_content = re.sub(re_url, r'\1/r2.6.0', content) + new_content = re.sub(re_url2, r'\1/v2.6.0', new_content) + # new_content = re.sub(re_url3, r'\1/r1.1.0', new_content) + new_content = re.sub(re_url4, r'\1/r1.5.0', new_content) + if new_content != content: + f.seek(0) + f.truncate() + f.write(new_content) + +base_path = os.path.dirname(os.path.dirname(sphinx.__file__)) +for cur, _, files in os.walk(os.path.join(base_path, 'mindspore_lite')): + for i in files: + if i.endswith('.py'): + with open(os.path.join(cur, i), 'r+', encoding='utf-8') as f: + content = f.read() + new_content = re.sub(re_url, r'\1/r2.6.0', content) + new_content = re.sub(re_url2, r'\1/v2.6.0', new_content) + # new_content = re.sub(re_url3, r'\1/r1.1.0', new_content) + new_content = re.sub(re_url4, r'\1/r1.5.0', new_content) + if new_content != content: + f.seek(0) + f.truncate() + f.write(new_content) + # modify urls import json diff --git a/docs/lite/api/source_en/index.rst b/docs/lite/api/source_en/index.rst index 522bcff0e167aa1bb42a2ec22d9c68a77aabe531..1f6b95ca47d2ba981c0652fa0351fa295ba440d2 100644 --- a/docs/lite/api/source_en/index.rst +++ b/docs/lite/api/source_en/index.rst @@ -12,21 +12,21 @@ Summary of MindSpore Lite API support 
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | Class | Description | C++ API | Python API | +=========================================================+===================================================================================================================================+==========================================================================================================================================================================================================================+============================================================================================================================================================================================================================================================================================================================================================================+ -| Context | Set the number of threads at runtime | void SetThreadNum(int32_t thread_num) | `Context.cpu.thread_num `__ | +| Context | Set the number of threads at runtime | void SetThreadNum(int32_t thread_num) | `Context.cpu.thread_num `__ | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | Get the current thread number setting | int32_t GetThreadNum() const | `Context.cpu.thread_num `__ | +| Context | Get the current thread number setting | int32_t GetThreadNum() const | `Context.cpu.thread_num `__ | 
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | Set the parallel number of operators at runtime | void SetInterOpParallelNum(int32_t parallel_num) | `Context.cpu.inter_op_parallel_num `__ | +| Context | Set the parallel number of operators at runtime | void SetInterOpParallelNum(int32_t parallel_num) | `Context.cpu.inter_op_parallel_num `__ | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | Get the current operators parallel number setting | int32_t GetInterOpParallelNum() const | `Context.cpu.inter_op_parallel_num `__ | +| Context | Get the current operators parallel number setting | int32_t GetInterOpParallelNum() const | `Context.cpu.inter_op_parallel_num `__ | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | Set the thread affinity to CPU cores | void SetThreadAffinity(int mode) | `Context.cpu.thread_affinity_mode `__ | +| Context | Set the thread affinity to CPU cores | void SetThreadAffinity(int mode) | `Context.cpu.thread_affinity_mode `__ | 
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | Get the thread affinity of CPU cores | int GetThreadAffinityMode() const | `Context.cpu.thread_affinity_mode `__ | +| Context | Get the thread affinity of CPU cores | int GetThreadAffinityMode() const | `Context.cpu.thread_affinity_mode `__ | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | Set the thread lists to CPU cores | void SetThreadAffinity(const std::vector &core_list) | `Context.cpu.thread_affinity_core_list `__ | +| Context | Set the thread lists to CPU cores | void SetThreadAffinity(const std::vector &core_list) | `Context.cpu.thread_affinity_core_list `__ | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | Get the thread lists of CPU cores | std::vector GetThreadAffinityCoreList() const | `Context.cpu.thread_affinity_core_list `__ | +| Context | Get the thread lists of CPU cores | std::vector GetThreadAffinityCoreList() const | `Context.cpu.thread_affinity_core_list `__ | 
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | Context | Set whether to perform model inference or training in parallel | void SetEnableParallel(bool is_parallel) | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
@@ -44,7 +44,7 @@ Summary of MindSpore Lite API support
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | Context | Get the mode of the model run | bool GetMultiModalHW() const | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| Context | Get a mutable reference of DeviceInfoContext vector in this context | std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo() | Wrapped in `Context.target `__ |
+| Context | Get a mutable reference of DeviceInfoContext vector in this context | std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo() | Wrapped in `Context.target `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | DeviceInfoContext | Get the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
@@ -62,29 +62,29 @@ Summary of MindSpore Lite API support
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | DeviceInfoContext | Obtain the memory allocator | std::shared_ptr<Allocator> GetAllocator() const | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| CPUDeviceInfo | Get the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | `context.cpu `__ |
+| CPUDeviceInfo | Get the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | `context.cpu `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| CPUDeviceInfo | Set whether to enable float16 inference | void SetEnableFP16(bool is_fp16) | `Context.cpu.precision_mode `__ |
+| CPUDeviceInfo | Set whether to enable float16 inference | void SetEnableFP16(bool is_fp16) | `Context.cpu.precision_mode `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| CPUDeviceInfo | Get whether float16 inference is enabled | bool GetEnableFP16() const | `Context.cpu.precision_mode `__ |
+| CPUDeviceInfo | Get whether float16 inference is enabled | bool GetEnableFP16() const | `Context.cpu.precision_mode `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| GPUDeviceInfo | Get the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | `Context.gpu `__ |
+| GPUDeviceInfo | Get the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | `Context.gpu `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| GPUDeviceInfo | Set device id | void SetDeviceID(uint32_t device_id) | `Context.gpu.device_id `__ |
+| GPUDeviceInfo | Set device id | void SetDeviceID(uint32_t device_id) | `Context.gpu.device_id `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| GPUDeviceInfo | Get the device id | uint32_t GetDeviceID() const | `Context.gpu.device_id `__ |
+| GPUDeviceInfo | Get the device id | uint32_t GetDeviceID() const | `Context.gpu.device_id `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| GPUDeviceInfo | Get the distribution rank id | int GetRankID() const | `Context.gpu.rank_id `__ |
+| GPUDeviceInfo | Get the distribution rank id | int GetRankID() const | `Context.gpu.rank_id `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| GPUDeviceInfo | Get the distribution group size | int GetGroupSize() const | `Context.gpu.group_size `__ |
+| GPUDeviceInfo | Get the distribution group size | int GetGroupSize() const | `Context.gpu.group_size `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | GPUDeviceInfo | Set the precision mode | void SetPrecisionMode(const std::string &precision_mode) | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | GPUDeviceInfo | Get the precision mode | std::string GetPrecisionMode() const | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| GPUDeviceInfo | Set whether to enable float16 inference | void SetEnableFP16(bool is_fp16) | `Context.gpu.precision_mode `__ |
+| GPUDeviceInfo | Set whether to enable float16 inference | void SetEnableFP16(bool is_fp16) | `Context.gpu.precision_mode `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| GPUDeviceInfo | Get whether float16 inference is enabled | bool GetEnableFP16() const | `Context.gpu.precision_mode `__ |
+| GPUDeviceInfo | Get whether float16 inference is enabled | bool GetEnableFP16() const | `Context.gpu.precision_mode `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | GPUDeviceInfo | Set whether to share memory with OpenGL | void SetEnableGLTexture(bool is_enable_gl_texture) | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
@@ -98,11 +98,11 @@ Summary of MindSpore Lite API support
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | GPUDeviceInfo | Get current OpenGL display | void \*GetGLDisplay() const | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| AscendDeviceInfo | Get the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | `Context.ascend `__ |
+| AscendDeviceInfo | Get the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | `Context.ascend `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| AscendDeviceInfo | Set device id | void SetDeviceID(uint32_t device_id) | `Context.ascend.device_id `__ |
+| AscendDeviceInfo | Set device id | void SetDeviceID(uint32_t device_id) | `Context.ascend.device_id `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| AscendDeviceInfo | Get the device id | uint32_t GetDeviceID() const | `Context.ascend.device_id `__ |
+| AscendDeviceInfo | Get the device id | uint32_t GetDeviceID() const | `Context.ascend.device_id `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | AscendDeviceInfo | Set AIPP configuration file path | void SetInsertOpConfigPath(const std::string &cfg_path) | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
@@ -132,9 +132,9 @@ Summary of MindSpore Lite API support
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | AscendDeviceInfo | Get type of model outputs | enum DataType GetOutputType() const | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| AscendDeviceInfo | Set precision mode of model | void SetPrecisionMode(const std::string &precision_mode) | `Context.ascend.precision_mode `__ |
+| AscendDeviceInfo | Set precision mode of model | void SetPrecisionMode(const std::string &precision_mode) | `Context.ascend.precision_mode `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| AscendDeviceInfo | Get precision mode of model | std::string GetPrecisionMode() const | `Context.ascend.precision_mode `__ |
+| AscendDeviceInfo | Get precision mode of model | std::string GetPrecisionMode() const | `Context.ascend.precision_mode `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | AscendDeviceInfo | Set op select implementation mode | void SetOpSelectImplMode(const std::string &op_select_impl_mode) | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
@@ -160,7 +160,7 @@ Summary of MindSpore Lite API support
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | Model | Build a model from model buffer so that it can run on a device | Status Build(const void \*model_data, size_t data_size, ModelType model_type, const std::shared_ptr<Context> &model_context = nullptr) | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| Model | Load and build a model from a model file so that it can run on a device | Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context = nullptr) | `Model.build_from_file `__ |
+| Model | Load and build a model from a model file so that it can run on a device | Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context = nullptr) | `Model.build_from_file `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | Model | Build a model from model buffer so that it can run on a device | Status Build(const void \*model_data, size_t data_size, ModelType model_type, const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode, const std::string &cropto_lib_path) | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
@@ -172,11 +172,11 @@ Summary of MindSpore Lite API support
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | Model | Build a Transfer Learning model where the backbone weights are fixed and the head weights are trainable | Status BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context, const std::shared_ptr<TrainCfg> &train_cfg = nullptr) | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| Model | Resize the shapes of inputs | Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) | `Model.resize `__ |
+| Model | Resize the shapes of inputs | Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) | `Model.resize `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | Model | Change the size and/or content of weight tensors | Status UpdateWeights(const std::vector<MSTensor> &new_weights) | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| Model | Inference model API | Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> \*outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) | `Model.predict `__ |
+| Model | Inference model API | Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> \*outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) | `Model.predict `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | Model | Inference model API only with callback | Status Predict(const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
@@ -188,11 +188,11 @@ Summary of MindSpore Lite API support
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | Model | Check if data preprocess exists in model | bool HasPreprocess() | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| Model | Load config file | Status LoadConfig(const std::string &config_path) | Wrapped in the parameter `config_path` of `Model.build_from_file `__ |
+| Model | Load config file | Status LoadConfig(const std::string &config_path) | Wrapped in the parameter `config_path` of `Model.build_from_file `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | Model | Update config | Status UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config) | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| Model | Obtains all input tensors of the model | std::vector<MSTensor> GetInputs() | `Model.get_inputs `__ |
+| Model | Obtains all input tensors of the model | std::vector<MSTensor> GetInputs() | `Model.get_inputs `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | Model | Obtains the input tensor of the model by name | MSTensor GetInputByTensorName(const std::string &tensor_name) | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
@@ -220,7 +220,7 @@ Summary of MindSpore Lite API support
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | Model | Accessor to TrainLoop metric objects | std::vector<Metrics \*> GetMetrics() | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| Model | Obtains all output tensors of the model | std::vector<MSTensor> GetOutputs() | Wrapped in the return value of `Model.predict `__ |
+| Model | Obtains all output tensors of the model | std::vector<MSTensor> GetOutputs() | Wrapped in the return value of `Model.predict `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | Model | Obtains names of all output tensors of the model | std::vector<std::string> GetOutputTensorNames() | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
@@ -240,33 +240,33 @@ Summary of MindSpore Lite API support
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | Model | Check if the device supports the model | static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type) | |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| RunnerConfig | Set the number of workers at runtime | void SetWorkersNum(int32_t workers_num) | `Context.parallel.workers_num `__ |
+| RunnerConfig | Set the number of workers at runtime | void SetWorkersNum(int32_t workers_num) | `Context.parallel.workers_num `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| RunnerConfig | Get the current number of parallel workers | int32_t GetWorkersNum() const | `Context.parallel.workers_num `__ |
+| RunnerConfig | Get the current number of parallel workers | int32_t GetWorkersNum() const | `Context.parallel.workers_num `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| RunnerConfig | Set the context at runtime | void SetContext(const std::shared_ptr<Context> &context) | Wrapped in `Context.parallel `__ |
+| RunnerConfig | Set the context at runtime | void SetContext(const std::shared_ptr<Context> &context) | Wrapped in `Context.parallel `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| RunnerConfig | Get the current context setting | std::shared_ptr<Context> GetContext() const | Wrapped in `Context.parallel `__ |
+| RunnerConfig | Get the current context setting | std::shared_ptr<Context> GetContext() const | Wrapped in `Context.parallel `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| RunnerConfig | Set the config before runtime | void SetConfigInfo(const std::string &section, const std::map<std::string, std::string> &config) | `Context.parallel.config_info `__ |
+| RunnerConfig | Set the config before runtime | void SetConfigInfo(const std::string &section, const std::map<std::string, std::string> &config) | `Context.parallel.config_info `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| RunnerConfig | Get the current config setting | std::map<std::string, std::map<std::string, std::string>> GetConfigInfo() const | `Context.parallel.config_info `__ |
+| RunnerConfig | Get the current config setting | std::map<std::string, std::map<std::string, std::string>> GetConfigInfo() const | `Context.parallel.config_info `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| RunnerConfig | Set the config path before runtime | void SetConfigPath(const std::string &config_path) | `Context.parallel.config_path `__ |
+| RunnerConfig | Set the config path before runtime | void SetConfigPath(const std::string &config_path) | `Context.parallel.config_path `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| RunnerConfig | Get the current config path | std::string GetConfigPath() const | `Context.parallel.config_path `__ |
+| RunnerConfig | Get the current config path | std::string GetConfigPath() const | `Context.parallel.config_path `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
-| ModelParallelRunner | Build a model parallel runner from a model path so that it can run on a device | Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr) | `Model.parallel_runner.build_from_file `__ |
+| ModelParallelRunner | Build a model parallel runner from a model path so that it can run on a device | Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr) | `Model.parallel_runner.build_from_file `__ |
 +---------------------+-----------------------------------------------------------+--------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------+
 | ModelParallelRunner | Build a model parallel runner from a model buffer so that it can run on a device | Status Init(const void \*model_data, const size_t data_size, const std::shared_ptr<RunnerConfig> &runner_config = nullptr) | |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| ModelParallelRunner | Obtains all input tensors information of the model | std::vector<MSTensor> GetInputs() | `Model.parallel_runner.get_inputs `__ |
+| ModelParallelRunner | Obtains all input tensors information of the model | std::vector<MSTensor> GetInputs() | `Model.parallel_runner.get_inputs `__ |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| ModelParallelRunner | Obtains all output tensors information of the model | std::vector<MSTensor> GetOutputs() | Wrapped in the return value of `Model.parallel_runner.predict `__ |
+| ModelParallelRunner | Obtains all output tensors information of the model | std::vector<MSTensor> GetOutputs() | Wrapped in the return value of `Model.parallel_runner.predict `__ |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| ModelParallelRunner | Inference ModelParallelRunner | Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> \*outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) | `Model.parallel_runner.predict `__ |
+| ModelParallelRunner | Inference ModelParallelRunner | Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> \*outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) | `Model.parallel_runner.predict `__ |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| MSTensor | Creates a MSTensor object, whose data need to be copied before accessed by Model | static inline MSTensor \*CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void \*data, size_t data_len) noexcept | `Tensor `__ |
+| MSTensor | Creates a MSTensor object, whose data need to be copied before accessed by Model | static inline MSTensor \*CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void \*data, size_t data_len) noexcept | `Tensor `__ |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| MSTensor | Creates a MSTensor object, whose data can be directly accessed by Model | static inline MSTensor \*CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void \*data, size_t data_len, bool own_data = true) noexcept | |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
@@ -280,19 +280,19 @@ Summary of MindSpore Lite API support
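Editor's note: the CreateTensor/CreateRefTensor rows above pair with the DestroyTensorPtr row in the next hunk. A minimal sketch of that ownership contract, assuming the MindSpore Lite C++ package (the header path, tensor name, and shape are illustrative, not from the patch):

```cpp
#include <vector>

#include "include/api/types.h"

int main() {
  std::vector<float> host_data(1 * 3 * 224 * 224, 0.0f);
  // CreateTensor copies host_data, so the model never aliases this buffer.
  mindspore::MSTensor *tensor = mindspore::MSTensor::CreateTensor(
      "input", mindspore::DataType::kNumberTypeFloat32, {1, 3, 224, 224},
      host_data.data(), host_data.size() * sizeof(float));
  if (tensor == nullptr) {
    return 1;
  }
  // Anything made by CreateTensor/CreateRefTensor/Clone/StringsToTensor must
  // be released through DestroyTensorPtr, per the table.
  mindspore::MSTensor::DestroyTensorPtr(tensor);
  return 0;
}
```

CreateRefTensor differs only in that the model reads the caller's buffer in place, so that buffer has to outlive the tensor (its `own_data` flag then decides who frees it).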
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | MSTensor | Destroy an object created by `Clone` , `StringsToTensor` , `CreateRefTensor` or `CreateTensor` | static void DestroyTensorPtr(MSTensor \*tensor) noexcept | | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | Obtains the name of the MSTensor | std::string Name() const | `Tensor.name `__ | +| MSTensor | Obtains the name of the MSTensor | std::string Name() const | `Tensor.name `__ | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | Obtains the data type of the MSTensor | enum DataType DataType() const | `Tensor.dtype `__ | +| MSTensor | Obtains the data type of the MSTensor | enum DataType DataType() const | `Tensor.dtype `__ | 
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | Obtains the shape of the MSTensor | const std::vector &Shape() const | `Tensor.shape `__ | +| MSTensor | Obtains the shape of the MSTensor | const std::vector &Shape() const | `Tensor.shape `__ | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | Obtains the number of elements of the MSTensor | int64_t ElementNum() const | `Tensor.element_num `__ | +| MSTensor | Obtains the number of elements of the MSTensor | int64_t ElementNum() const | `Tensor.element_num `__ | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | MSTensor | Obtains a shared pointer to the copy of data of the MSTensor | std::shared_ptr Data() const | | 
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | Obtains the pointer to the data of the MSTensor | void \*MutableData() | Wrapped in `Tensor.get_data_to_numpy `__ and `Tensor.set_data_from_numpy `__ | +| MSTensor | Obtains the pointer to the data of the MSTensor | void \*MutableData() | Wrapped in `Tensor.get_data_to_numpy `__ and `Tensor.set_data_from_numpy `__ | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | Obtains the length of the data of the MSTensor, in bytes | size_t DataSize() const | `Tensor.data_size `__ | +| MSTensor | Obtains the length of the data of the MSTensor, in bytes | size_t DataSize() const | `Tensor.data_size `__ | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | MSTensor | Get whether the MSTensor data is const data | bool IsConst() const | | 
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
@@ -308,19 +308,19 @@ Summary of MindSpore Lite API support
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| MSTensor | Get the boolean value that indicates whether the MSTensor not equals tensor | bool operator!=(const MSTensor &tensor) const | |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| MSTensor | Set the shape for the MSTensor | void SetShape(const std::vector<int64_t> &shape) | `Tensor.shape `__ |
+| MSTensor | Set the shape for the MSTensor | void SetShape(const std::vector<int64_t> &shape) | `Tensor.shape `__ |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| MSTensor | Set the data type for the MSTensor | void SetDataType(enum DataType data_type) | `Tensor.dtype `__ |
+| MSTensor | Set the data type for the MSTensor | void SetDataType(enum DataType data_type) | `Tensor.dtype `__ |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| MSTensor | Set the name for the MSTensor | void SetTensorName(const std::string &name) | `Tensor.name `__ |
+| MSTensor | Set the name for the MSTensor | void SetTensorName(const std::string &name) | `Tensor.name `__ |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| MSTensor | Set the Allocator for the MSTensor | void SetAllocator(std::shared_ptr<Allocator> allocator) | |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| MSTensor | Obtain the Allocator of the MSTensor | std::shared_ptr<Allocator> allocator() const | |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | Set the format for the MSTensor | void SetFormat(mindspore::Format format) | `Tensor.format `__ | +| MSTensor | Set the format for the MSTensor | void SetFormat(mindspore::Format format) | `Tensor.format `__ | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | Obtain the format of the MSTensor | mindspore::Format format() const | `Tensor.format `__ | +| MSTensor | Obtain the format of the MSTensor | mindspore::Format format() const | `Tensor.format `__ | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | MSTensor | Set the data for the MSTensor | void SetData(void \*data, bool own_data = true) | | 
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ @@ -332,15 +332,15 @@ Summary of MindSpore Lite API support +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | MSTensor | Set the quantization parameters for the MSTensor | void SetQuantParams(std::vector quant_params) | | +---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| ModelGroup | Construct a ModelGroup object and indicate shared workspace memory or shared weight memory, with default shared workspace memory | ModelGroup(ModelGroupFlag flags = ModelGroupFlag::kShareWorkspace) | `ModelGroup `__ | +| ModelGroup | Construct a ModelGroup object and indicate shared workspace memory or shared weight memory, with default shared workspace memory | ModelGroup(ModelGroupFlag flags = ModelGroupFlag::kShareWorkspace) | `ModelGroup `__ | 
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| ModelGroup | When sharing weight memory, add model objects that require shared weight memory | Status AddModel(const std::vector<Model \*> &model_list) | `ModelGroup.add_model `__ |
+| ModelGroup | When sharing weight memory, add model objects that require shared weight memory | Status AddModel(const std::vector<Model \*> &model_list) | `ModelGroup.add_model `__ |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| ModelGroup | When sharing workspace memory, add the path of the model that requires shared workspace memory | Status AddModel(const std::vector<std::string> &model_path_list) | `ModelGroup.add_model `__ |
+| ModelGroup | When sharing workspace memory, add the path of the model that requires shared workspace memory | Status AddModel(const std::vector<std::string> &model_path_list) | `ModelGroup.add_model `__ |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| ModelGroup | When sharing workspace memory, add a model buffer that requires shared workspace memory | Status AddModel(const std::vector<std::pair<const void \*, size_t>> &model_buff_list) | |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| ModelGroup | When sharing workspace memory, calculate the maximum workspace memory size | Status CalMaxSizeOfWorkspace(ModelType model_type, const std::shared_ptr<Context> &ms_context) | `ModelGroup.cal_max_size_of_workspace `__ |
+| ModelGroup | When sharing workspace memory, calculate the maximum workspace memory size | Status CalMaxSizeOfWorkspace(ModelType model_type, const std::shared_ptr<Context> &ms_context) | `ModelGroup.cal_max_size_of_workspace `__ |
+---------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/docs/lite/api/source_zh_cn/api_c/context_c.md b/docs/lite/api/source_zh_cn/api_c/context_c.md
index 5e265507a5a4c1cc37fd482fbc57e7a65742b397..d7994b9a05e7c70c60f9f15ade1fad26f9dac4d0 100644
--- a/docs/lite/api/source_zh_cn/api_c/context_c.md
+++ b/docs/lite/api/source_zh_cn/api_c/context_c.md
@@ -1,6 +1,6 @@
# context_c
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_c/context_c.md)
+[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_c/context_c.md)
```c
#include
@@ -198,7 +198,7 @@ MSDeviceInfoHandle MSDeviceInfoCreate(MSDeviceType device_type)
新建运行设备信息,若创建失败则会返回`nullptr`,并日志中输出信息。
- 参数
-    - `device_type`: 设备类型,具体见[MSDeviceType](https://www.mindspore.cn/lite/api/zh-CN/master/api_c/types_c.html#msdevicetype)。
+    - `device_type`: 设备类型,具体见[MSDeviceType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_c/types_c.html#msdevicetype)。
- 返回值
diff --git a/docs/lite/api/source_zh_cn/api_c/data_type_c.md b/docs/lite/api/source_zh_cn/api_c/data_type_c.md
index ac6c4b6384887fd3d36aa458e855831904af7d07..4f4dee3adc7adc14826432f837412c716c310462 100644
--- a/docs/lite/api/source_zh_cn/api_c/data_type_c.md
+++ b/docs/lite/api/source_zh_cn/api_c/data_type_c.md
@@ -1,6 +1,6 @@
# data_type_c
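Editor's note: before the remaining C API files, here is a hedged sketch of the ModelGroup workspace-sharing flow that the table above summarizes (model paths and the Ascend device target are placeholders, not taken from the patch):

```cpp
#include <memory>
#include <string>
#include <vector>

#include "include/api/context.h"
#include "include/api/model_group.h"

int main() {
  auto context = std::make_shared<mindspore::Context>();
  context->MutableDeviceInfo().push_back(
      std::make_shared<mindspore::AscendDeviceInfo>());

  // kShareWorkspace is also the constructor's default, per the table.
  mindspore::ModelGroup model_group(mindspore::ModelGroupFlag::kShareWorkspace);
  std::vector<std::string> model_paths = {"net_a.mindir", "net_b.mindir"};
  if (model_group.AddModel(model_paths) != mindspore::kSuccess) {
    return 1;
  }
  // Sizes one workspace arena large enough for every model in the group.
  if (model_group.CalMaxSizeOfWorkspace(mindspore::kMindIR, context) !=
      mindspore::kSuccess) {
    return 1;
  }
  return 0;
}
```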
-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_c/data_type_c.md) +[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_c/data_type_c.md) ```C #include diff --git a/docs/lite/api/source_zh_cn/api_c/format_c.md b/docs/lite/api/source_zh_cn/api_c/format_c.md index 3b57375f73ee2225e474b006a018edc8a550f79b..7526bfee61ede3700bdd23e28e209715e9f531bd 100644 --- a/docs/lite/api/source_zh_cn/api_c/format_c.md +++ b/docs/lite/api/source_zh_cn/api_c/format_c.md @@ -1,6 +1,6 @@ # format_c -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_c/format_c.md) +[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_c/format_c.md) ```C #include diff --git a/docs/lite/api/source_zh_cn/api_c/lite_c_example.rst b/docs/lite/api/source_zh_cn/api_c/lite_c_example.rst index 9def15a73ba9657997156d0608ff819af596b3df..c2a44e78c9f15c880f2da758383d7c57c489f324 100644 --- a/docs/lite/api/source_zh_cn/api_c/lite_c_example.rst +++ b/docs/lite/api/source_zh_cn/api_c/lite_c_example.rst @@ -4,4 +4,4 @@ .. toctree:: :maxdepth: 1 - 极简Demo↗ + 极简Demo↗ diff --git a/docs/lite/api/source_zh_cn/api_c/model_c.md b/docs/lite/api/source_zh_cn/api_c/model_c.md index 30e786fc14f4dc11073a00aeebb437efffc47060..47e56c30ac164ffa56e649df0baf46dabd0fead1 100644 --- a/docs/lite/api/source_zh_cn/api_c/model_c.md +++ b/docs/lite/api/source_zh_cn/api_c/model_c.md @@ -1,6 +1,6 @@ # model_c -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_c/model_c.md) +[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_c/model_c.md) ```C #include @@ -91,8 +91,8 @@ MSStatus MSModelBuild(MSModelHandle model, const void* model_data, size_t data_s - `model`: 指向模型对象的指针。 - `model_data`: 内存中已经加载的模型数据地址。 - `data_size`: 模型数据的长度。 - - `model_type`: 模型文件类型,具体见: [MSModelType](https://mindspore.cn/lite/api/zh-CN/master/api_c/types_c.html#msmodeltype)。 - - `model_context`: 模型的上下文环境,具体见: [Context](https://mindspore.cn/lite/api/zh-CN/master/api_c/context_c.html)。 + - `model_type`: 模型文件类型,具体见: [MSModelType](https://mindspore.cn/lite/api/zh-CN/r2.6.0/api_c/types_c.html#msmodeltype)。 + - `model_context`: 模型的上下文环境,具体见: [Context](https://mindspore.cn/lite/api/zh-CN/r2.6.0/api_c/context_c.html)。 - 返回值 @@ -111,8 +111,8 @@ MSStatus MSModelBuildFromFile(MSModelHandle model, const char* model_path, MSMod - `model`: 指向模型对象的指针。 - `model_path`: 模型文件路径。 - - `model_type`: 模型文件类型,具体见: [MSModelType](https://mindspore.cn/lite/api/zh-CN/master/api_c/types_c.html#msmodeltype)。 - - `model_context`: 模型的上下文环境,具体见: [Context](https://mindspore.cn/lite/api/zh-CN/master/api_c/context_c.html)。 + - `model_type`: 模型文件类型,具体见: [MSModelType](https://mindspore.cn/lite/api/zh-CN/r2.6.0/api_c/types_c.html#msmodeltype)。 + - 
`model_context`: 模型的上下文环境,具体见: [Context](https://mindspore.cn/lite/api/zh-CN/r2.6.0/api_c/context_c.html)。 - 返回值 diff --git a/docs/lite/api/source_zh_cn/api_c/tensor_c.md b/docs/lite/api/source_zh_cn/api_c/tensor_c.md index bf50fa6d563e667d5c74de6af3a7ad405e85db39..56c901c6e58e57944ba8c695a2e831064f1b5189 100644 --- a/docs/lite/api/source_zh_cn/api_c/tensor_c.md +++ b/docs/lite/api/source_zh_cn/api_c/tensor_c.md @@ -1,6 +1,6 @@ # tensor_c -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_c/tensor_c.md) +[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_c/tensor_c.md) ```C #include @@ -123,7 +123,7 @@ void MSTensorSetDataType(MSTensorHandle tensor, MSDataType type) MSDataType MSTensorGetDataType(const MSTensorHandle tensor) ``` -获取MSTensor的数据类型,具体数据类型见[MSDataType](https://www.mindspore.cn/lite/api/zh-CN/master/api_c/data_type_c.html#msdatatype)。 +获取MSTensor的数据类型,具体数据类型见[MSDataType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_c/data_type_c.html#msdatatype)。 - 参数 - `tensor`: 指向MSTensor的指针。 @@ -171,7 +171,7 @@ void MSTensorSetFormat(MSTensorHandle tensor, MSFormat format) - 参数 - `tensor`: 指向MSTensor的指针。 - - `format`: 张量的数据排列,具体见[MSFormat](https://www.mindspore.cn/lite/api/zh-CN/master/api_c/format_c.html#msformat)。 + - `format`: 张量的数据排列,具体见[MSFormat](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_c/format_c.html#msformat)。 ### MSTensorGetFormat @@ -183,7 +183,7 @@ MSFormat MSTensorGetFormat(const MSTensorHandle tensor) - 返回值 - 张量的数据排列,具体见[MSFormat](https://www.mindspore.cn/lite/api/zh-CN/master/api_c/format_c.html#msformat)。 + 张量的数据排列,具体见[MSFormat](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_c/format_c.html#msformat)。 ### MSTensorSetData diff --git a/docs/lite/api/source_zh_cn/api_c/types_c.md b/docs/lite/api/source_zh_cn/api_c/types_c.md index 335996f20e37006db67221b0610e4f9253e57432..3e3b17d2fd1e6d70dcf93282af91393f3087b83e 100644 --- a/docs/lite/api/source_zh_cn/api_c/types_c.md +++ b/docs/lite/api/source_zh_cn/api_c/types_c.md @@ -1,6 +1,6 @@ # types_c -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_c/types_c.md) +[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_c/types_c.md) ```C #include diff --git a/docs/lite/api/source_zh_cn/api_cpp/lite_cpp_example.rst b/docs/lite/api/source_zh_cn/api_cpp/lite_cpp_example.rst index ecdf9d26248719343aa45cd2d7217615cced2eb9..708cd5aad2b2dede30479296ed4de7d5655fb33a 100644 --- a/docs/lite/api/source_zh_cn/api_cpp/lite_cpp_example.rst +++ b/docs/lite/api/source_zh_cn/api_cpp/lite_cpp_example.rst @@ -4,6 +4,6 @@ .. 
toctree:: :maxdepth: 1 - 极简Demo↗ - 基于JNI接口的Android应用开发↗ - 高阶用法↗ \ No newline at end of file + 极简Demo↗ + 基于JNI接口的Android应用开发↗ + 高阶用法↗ \ No newline at end of file diff --git a/docs/lite/api/source_zh_cn/api_cpp/mindspore.md b/docs/lite/api/source_zh_cn/api_cpp/mindspore.md index 8e52b4ae9421445cfc06f486193ef7de697abeef..09eb24b87ab14701e305533ce8221a82c746f056 100644 --- a/docs/lite/api/source_zh_cn/api_cpp/mindspore.md +++ b/docs/lite/api/source_zh_cn/api_cpp/mindspore.md @@ -1,6 +1,6 @@ # mindspore -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_cpp/mindspore.md) +[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_cpp/mindspore.md) ## 接口汇总 @@ -34,8 +34,8 @@ |--------------------------------------------------|---------------------------------------------------|--------|--------| | [MSTensor](#mstensor) | MindSpore中的张量。 | √ | √ | | [QuantParam](#quantparam) | MSTensor中的一组量化参数。 | √ | √ | -| [mindspore::DataType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_datatype.html) | MindSpore MSTensor保存的数据支持的类型。 | √ | √ | -| [mindspore::Format](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_format.html) | MindSpore MSTensor保存的数据支持的排列格式。 | √ | √ | +| [mindspore::DataType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_datatype.html) | MindSpore MSTensor保存的数据支持的类型。 | √ | √ | +| [mindspore::Format](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_format.html) | MindSpore MSTensor保存的数据支持的排列格式。 | √ | √ | | [Allocator](#allocator-1) | 内存管理基类。 | √ | √ | ### 模型分组 @@ -110,7 +110,7 @@ ## Context -\#include <[context.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/context.h)> +\#include <[context.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/context.h)> Context类用于保存执行中的环境变量。 @@ -139,8 +139,8 @@ Context() | [DelegateMode GetBuiltInDelegate() const](#getbuiltindelegate) | ✕ | √ | | [void SetDelegate(const std::shared_ptr &delegate)](#setdelegate) | ✕ | √ | | [std::shared_ptr GetDelegate() const](#getdelegate) | ✕ | √ | -| [void set_delegate(const std::shared_ptr &delegate)](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#set-delegate) | ✕ | √ | -| [std::shared_ptr get_delegate() const](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#get-delegate) | ✕ | √ | +| [void set_delegate(const std::shared_ptr &delegate)](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#set-delegate) | ✕ | √ | +| [std::shared_ptr get_delegate() const](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#get-delegate) | ✕ | √ | | [void SetMultiModalHW(bool float_mode)](#setmultimodalhw) | ✕ | √ | | [bool GetMultiModalHW() const](#getmultimodalhw) | ✕ | √ | | [std::vector> &MutableDeviceInfo()](#mutabledeviceinfo) | √ | √ | @@ -375,7 +375,7 @@ std::vector> &MutableDeviceInfo() ## DeviceInfoContext -\#include <[context.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/context.h)> +\#include <[context.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/context.h)> DeviceInfoContext类定义不同硬件设备的环境信息。 @@ -497,7 +497,7 @@ std::shared_ptr GetAllocator() const; ## CPUDeviceInfo -\#include 
<[context.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/context.h)>
+\#include <[context.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/context.h)>
派生自[DeviceInfoContext](#deviceinfocontext),模型运行在CPU上的配置。
@@ -511,7 +511,7 @@ std::shared_ptr<Allocator> GetAllocator() const;
## GPUDeviceInfo
-\#include <[context.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/context.h)>
+\#include <[context.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/context.h)>
派生自[DeviceInfoContext](#deviceinfocontext),模型运行在GPU上的配置。
@@ -537,7 +537,7 @@ std::shared_ptr<Allocator> GetAllocator() const;
## KirinNPUDeviceInfo
-\#include <[context.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/context.h)>
+\#include <[context.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/context.h)>
派生自[DeviceInfoContext](#deviceinfocontext),模型运行在NPU上的配置。
@@ -553,7 +553,7 @@ std::shared_ptr<Allocator> GetAllocator() const;
## AscendDeviceInfo
-\#include <[context.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/context.h)>
+\#include <[context.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/context.h)>
派生自[DeviceInfoContext](#deviceinfocontext),模型运行在Atlas 200/300/500推理产品、Atlas推理系列产品上的配置。
@@ -564,8 +564,8 @@ std::shared_ptr<Allocator> GetAllocator() const;
| `enum DeviceType GetDeviceType() const` | - 返回值: DeviceType::kAscend | √ | √ |
| `void SetDeviceID(uint32_t device_id)` | 用于指定设备ID<br><br>    - `device_id`: 设备ID | √ | √ |
| `uint32_t GetDeviceID() const` | - 返回值: 已配置的设备ID | √ | √ |
-| `void SetInsertOpConfigPath(const std::string &cfg_path)` | 模型插入[AIPP](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/devaids/auxiliarydevtool/atlasatc_16_0025.html)算子<br><br>    - `cfg_path`: [AIPP](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/devaids/auxiliarydevtool/atlasatc_16_0025.html)配置文件路径 | √ | √ |
-| `std::string GetInsertOpConfigPath()` | - 返回值: 已配置的[AIPP](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/devaids/auxiliarydevtool/atlasatc_16_0025.html) | √ | √ |
+| `void SetInsertOpConfigPath(const std::string &cfg_path)` | 模型插入[AIPP](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1alpha001/devaids/devtools/atc/atlasatc_16_0016.html)算子<br><br>    - `cfg_path`: [AIPP](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1alpha001/devaids/devtools/atc/atlasatc_16_0016.html)配置文件路径 | √ | √ |
+| `std::string GetInsertOpConfigPath()` | - 返回值: 已配置的[AIPP](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1alpha001/devaids/devtools/atc/atlasatc_16_0016.html) | √ | √ |
| `void SetInputFormat(const std::string &format)` | 指定模型输入format<br><br>    - `format`: 可选有`"NCHW"`,`"NHWC"`,`"ND"` | √ | √ |
| `std::string GetInputFormat()` | - 返回值: 已配置模型输入format | √ | √ |
| `void SetInputShape(const std::string &shape)` | 指定模型输入shape,为字符串形式,需指定输入名称,每个shape值由`,`隔开,不同输入由`;`隔开<br><br>    - `shape`: 如`"input_op_name1:1,2,3,4;input_op_name2:4,3,2,1"` | √ | √ |
@@ -589,7 +589,7 @@ std::shared_ptr<Allocator> GetAllocator() const;
## Serialization
-\#include <[serialization.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/serialization.h)>
+\#include <[serialization.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/serialization.h)>
Serialization类汇总了模型文件读写的方法。
@@ -859,7 +859,7 @@ Buffer Clone() const;
## Model
-\#include <[model.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/model.h)>
+\#include <[model.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/model.h)>
Model定义了MindSpore中的模型,便于计算图管理。
@@ -1591,7 +1591,7 @@ Status UpdateWeights(const std::vector<MSTensor> &new_weights)
## MSTensor
-\#include <[types.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/types.h)>
+\#include <[types.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/types.h)>
`MSTensor`定义了MindSpore中的张量。
@@ -1786,9 +1786,9 @@ void DestroyTensorPtr(MSTensor *tensor) noexcept;
| [bool IsConst() const](#isconst) | √ | √ |
| [bool IsDevice() const](#isdevice) | √ | ✕ |
| [MSTensor *Clone() const](#clone) | √ | √ |
-| [bool operator==(std::nullptr_t) const](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#operator==std-nullptr-t) | √ | √ |
-| [bool operator!=(std::nullptr_t) const](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#operator!=std-nullptr-t) | √ | √ |
-| [bool operator==(const MSTensor &tensor) const](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#operator==const-mstensor-tensor) | √ | √ |
+| [bool operator==(std::nullptr_t) const](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#operator==std-nullptr-t) | √ | √ |
+| [bool operator!=(std::nullptr_t) const](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#operator!=std-nullptr-t) | √ | √ |
+| [bool operator==(const MSTensor &tensor) const](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#operator==const-mstensor-tensor) | √ | √ |
| [void SetShape(const std::vector<int64_t> &shape)](#setshape) | √ | √ |
| [void SetDataType(enum DataType data_type)](#setdatatype) | √ | √ |
| [void SetTensorName(const std::string &name)](#settensorname) | √ | √ |
@@ -2102,7 +2102,7 @@ const std::shared_ptr<Impl> impl()
## QuantParam
-\#include <[types.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/types.h)>
+\#include <[types.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/types.h)>
一个结构体。QuantParam定义了MSTensor的一组量化参数。
@@ -2150,7 +2150,7 @@ max
## MSKernelCallBack
-\#include <[types.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/types.h)>
+\#include <[types.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/types.h)>
```cpp
using MSKernelCallBack = std::function<bool(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs, const MSCallBackParam &opInfo)>
@@ -2160,7 +2160,7 @@ using MSKernelCallBack = std::function<bool(const std::vector<MSTensor> &inputs,
## MSCallBackParam
-\#include <[types.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/types.h)>
+\#include <[types.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/types.h)>
一个结构体。MSCallBackParam定义了回调函数的输入参数。
@@ -2192,7 +2192,7 @@ execute_time
## Delegate
-\#include <[delegate.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/delegate.h)>
+\#include <[delegate.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/delegate.h)>
`Delegate`定义了第三方AI框架接入MindSpore Lite的代理接口。
@@
-2235,7 +2235,7 @@ Delegate在线构图。 ## CoreMLDelegate -\#include <[delegate.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/delegate.h)> +\#include <[delegate.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/delegate.h)> `CoreMLDelegate`继承自`Delegate`类,定义了CoreML框架接入MindSpore Lite的代理接口。 @@ -2277,7 +2277,7 @@ CoreMLDelegate在线构图,仅在内部图编译阶段调用。 ## SchemaVersion -\#include <[delegate.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/delegate.h)> +\#include <[delegate.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/delegate.h)> 定义了MindSpore Lite执行在线推理时模型文件的版本。 @@ -2291,9 +2291,9 @@ typedef enum { ## KernelIter -\#include <[delegate.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/delegate.h)> +\#include <[delegate.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/delegate.h)> -定义了MindSpore Lite [Kernel](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_kernel.html#mindspore-kernel)列表的迭代器。 +定义了MindSpore Lite [Kernel](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_kernel.html#mindspore-kernel)列表的迭代器。 ```cpp using KernelIter = std::vector::iterator; @@ -2301,7 +2301,7 @@ using KernelIter = std::vector::iterator; ## DelegateModel -\#include <[delegate.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/delegate.h)> +\#include <[delegate.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/delegate.h)> `DelegateModel`定义了MindSpore Lite Delegate机制操作的的模型对象。 @@ -2327,7 +2327,7 @@ DelegateModel(std::vector *kernels, const std::vector *kernels_; ``` -[**Kernel**](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_kernel.html#kernel)的列表,保存模型的所有算子。 +[**Kernel**](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_kernel.html#kernel)的列表,保存模型的所有算子。 #### inputs_ @@ -2335,7 +2335,7 @@ std::vector *kernels_; const std::vector &inputs_; ``` -[**MSTensor**](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)的列表,保存这个算子的输入tensor。 +[**MSTensor**](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)的列表,保存这个算子的输入tensor。 #### outputs_ @@ -2343,7 +2343,7 @@ const std::vector &inputs_; const std::vector &outputs; ``` -[**MSTensor**](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)的列表,保存这个算子的输出tensor。 +[**MSTensor**](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)的列表,保存这个算子的输出tensor。 #### primitives_ @@ -2351,7 +2351,7 @@ const std::vector &outputs; const std::map &primitives_; ``` -[**Kernel**](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_kernel.html#kernel)和**schema::Primitive**的Map,保存所有算子的属性。 +[**Kernel**](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_kernel.html#kernel)和**schema::Primitive**的Map,保存所有算子的属性。 #### version_ @@ -2431,7 +2431,7 @@ const std::vector &inputs() - 返回值 - [**MSTensor**](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)的列表。 + [**MSTensor**](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)的列表。 #### outputs @@ -2443,7 +2443,7 @@ const std::vector &outputs() - 返回值 - [**MSTensor**](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)的列表。 + [**MSTensor**](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)的列表。 #### GetVersion @@ -2459,7 +2459,7 @@ const SchemaVersion GetVersion() { return version_; } ## TrainCfg -\#include 
<[cfg.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/cfg.h)> +\#include <[cfg.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/cfg.h)> `TrainCfg`MindSpore Lite训练的相关配置参数。 @@ -2507,7 +2507,7 @@ bool accumulate_gradients_; ## MixPrecisionCfg -\#include <[cfg.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/cfg.h)> +\#include <[cfg.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/cfg.h)> `MixPrecisionCfg`MindSpore Lite训练混合精度配置类。 @@ -2549,7 +2549,7 @@ bool is_raw_mix_precision_; ## AccuracyMetrics -\#include <[accuracy.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/metrics/accuracy.h)> +\#include <[accuracy.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/metrics/accuracy.h)> `AccuracyMetrics`MindSpore Lite训练精度类。 @@ -2584,7 +2584,7 @@ float Eval() override; ## Metrics -\#include <[metrics.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/metrics/metrics.h)> +\#include <[metrics.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/metrics/metrics.h)> `Metrics`MindSpore Lite训练指标类。 @@ -2631,7 +2631,7 @@ virtual void Update(std::vector inputs, std::vector outp ## TrainCallBack -\#include <[callback.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/callback/callback.h)> +\#include <[callback.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/callback/callback.h)> `Metrics`MindSpore Lite训练回调类。 @@ -2730,7 +2730,7 @@ virtual void Begin(const TrainCallBackData &cb_data) {} ## TrainCallBackData -\#include <[callback.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/callback/callback.h)> +\#include <[callback.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/callback/callback.h)> 一个结构体。TrainCallBackData定义了训练回调的一组参数。 @@ -2770,7 +2770,7 @@ model_ ## CkptSaver -\#include <[ckpt_saver.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/callback/ckpt_saver.h)> +\#include <[ckpt_saver.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/callback/ckpt_saver.h)> `Metrics`MindSpore Lite训练模型文件保存类。 @@ -2783,7 +2783,7 @@ model_ ## LossMonitor -\#include <[loss_monitor.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/callback/loss_monitor.h)> +\#include <[loss_monitor.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/callback/loss_monitor.h)> `Metrics`MindSpore Lite训练损失函数类。 @@ -2810,7 +2810,7 @@ model_ ## LRScheduler -\#include <[lr_scheduler.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/callback/lr_scheduler.h)> +\#include <[lr_scheduler.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/callback/lr_scheduler.h)> `Metrics`MindSpore Lite训练学习率调度类。 @@ -2823,7 +2823,7 @@ model_ ## StepLRLambda -\#include <[lr_scheduler.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/callback/lr_scheduler.h)> +\#include <[lr_scheduler.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/callback/lr_scheduler.h)> 一个结构体。StepLRLambda定义了训练学习率的一组参数。 @@ -2847,7 +2847,7 @@ gamma ## MultiplicativeLRLambda -\#include <[lr_scheduler.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/callback/lr_scheduler.h)> +\#include <[lr_scheduler.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/callback/lr_scheduler.h)> 每个epoch将学习率乘以一个因子。 @@ -2875,7 +2875,7 @@ int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication) ## TimeMonitor -\#include 
@@ -2875,7 +2875,7 @@ int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication)

## TimeMonitor

-\#include <[time_monitor.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/callback/time_monitor.h)>
+\#include <[time_monitor.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/callback/time_monitor.h)>

`TimeMonitor`是MindSpore Lite训练时间监测类。

@@ -2921,7 +2921,7 @@ int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication)

## TrainAccuracy

-\#include <[train_accuracy.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/callback/train_accuracy.h)>
+\#include <[train_accuracy.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/callback/train_accuracy.h)>

`TrainAccuracy`是MindSpore Lite训练精度监测类。

@@ -2980,7 +2980,7 @@ std::vector<char> CharVersion()

|-----------------------|--------|--------|
| [std::string Version()](#version) | ✕ | √ |

-\#include <[types.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/types.h)>
+\#include <[types.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/types.h)>

```cpp
std::string Version()
```

@@ -2994,7 +2994,7 @@ std::string Version()

## Allocator

-\#include <[allocator.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/allocator.h)>
+\#include <[allocator.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/allocator.h)>

内存管理基类。

@@ -3148,11 +3148,11 @@ inline Status(const StatusCode code, int line_of_code, const char *file_name, co

| [inline std::string GetErrDescription() const](#geterrdescription) | √ | √ |
| [inline std::string SetErrDescription(const std::string &err_description)](#seterrdescription) | √ | √ |
| [inline void SetStatusMsg(const std::string &status_msg)](#setstatusmsg) | √ | √ |
-| [friend std::ostream &operator<<(std::ostream &os, const Status &s)](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#operator< GetDeviceIds() const

## ModelParallelRunner

-\#include <[model_parallel_runner.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/model_parallel_runner.h)>
+\#include <[model_parallel_runner.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/model_parallel_runner.h)>

ModelParallelRunner定义了MindSpore的多个Model以及并发策略,便于多个Model的调度与管理。

@@ -3790,7 +3790,7 @@ std::vector<MSTensor> GetOutputs()

## ModelGroup

-\#include <[model_group.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/model_group.h)>
+\#include <[model_group.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/model_group.h)>

ModelGroup 类定义MindSpore Lite模型分组信息,用于共享工作空间(Workspace)内存或者权重(包括常量和变量)内存。

diff --git a/docs/lite/api/source_zh_cn/api_cpp/mindspore_converter.md b/docs/lite/api/source_zh_cn/api_cpp/mindspore_converter.md
index 6991a358761e7cdcf860115dce6b229641ee9941..5fd57baf93b71d21fd7873a270f097ed7d644686 100644
--- a/docs/lite/api/source_zh_cn/api_cpp/mindspore_converter.md
+++ b/docs/lite/api/source_zh_cn/api_cpp/mindspore_converter.md
@@ -1,6 +1,6 @@
# mindspore::converter

-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_cpp/mindspore_converter.md)
+[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_cpp/mindspore_converter.md)

以下描述了MindSpore Lite转换支持的模型类型及用户扩展所需的必要信息。

@@ -16,7 +16,7 @@

## FmkType
-\#include <[converter_context.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/converter_context.h)>
+\#include <[converter_context.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/converter_context.h)>

**enum**类型变量,定义MindSpore Lite转换支持的框架类型。

@@ -31,7 +31,7 @@

## ConverterParameters

-\#include <[converter_context.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/converter_context.h)>
+\#include <[converter_context.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/converter_context.h)>

**struct**类型结构体,定义模型解析时的转换参数,用于模型解析时的只读参数。

@@ -46,7 +46,7 @@ struct ConverterParameters {

## ConverterContext

-\#include <[converter_context.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/converter_context.h)>
+\#include <[converter_context.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/converter_context.h)>

模型转换过程中,基本信息的设置与获取。

@@ -94,7 +94,7 @@ static std::vector<std::string> GetGraphOutputTensorNames();

## NodeParser

-\#include <[node_parser.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/node_parser.h)>
+\#include <[node_parser.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/node_parser.h)>

op节点的解析基类。

@@ -197,7 +197,7 @@ tflite节点解析接口函数。

## NodeParserPtr

-\#include <[node_parser.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/node_parser.h)>
+\#include <[node_parser.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/node_parser.h)>

NodeParser类的共享智能指针类型。

```c++
using NodeParserPtr = std::shared_ptr<NodeParser>;
```

@@ -207,7 +207,7 @@ using NodeParserPtr = std::shared_ptr<NodeParser>;

## ModelParser

-\#include <[model_parser.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/model_parser.h)>
+\#include <[model_parser.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/model_parser.h)>

解析原始模型的基类。

@@ -239,7 +239,7 @@ api::FuncGraphPtr Parse(const converter::ConverterParameters &flags);

- 参数

- - `flags`: 解析模型时基本信息,具体见[ConverterParameters](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_converter.html#converterparameters)。
+ - `flags`: 解析模型时基本信息,具体见[ConverterParameters](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_converter.html#converterparameters)。

- 返回值

diff --git a/docs/lite/api/source_zh_cn/api_cpp/mindspore_datatype.md b/docs/lite/api/source_zh_cn/api_cpp/mindspore_datatype.md
index 9f0b6d8f3603b2ce5d48881df7f2ee4c9e2a3575..ba254f61a43fa8cc6fbcf6834cc51ab5e89e5d25 100644
--- a/docs/lite/api/source_zh_cn/api_cpp/mindspore_datatype.md
+++ b/docs/lite/api/source_zh_cn/api_cpp/mindspore_datatype.md
@@ -1,6 +1,6 @@
# mindspore::DataType

-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_cpp/mindspore_datatype.md)
+[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_cpp/mindspore_datatype.md)

以下表格描述了MindSpore MSTensor保存的数据支持的类型。

diff --git a/docs/lite/api/source_zh_cn/api_cpp/mindspore_format.md b/docs/lite/api/source_zh_cn/api_cpp/mindspore_format.md
index 480bbfa1b318b3c186bd64df42db30d3b8d400c9..cc34c8ce9bcad615801b809c3d40cd583cdb79e5 100644
--- a/docs/lite/api/source_zh_cn/api_cpp/mindspore_format.md
+++ b/docs/lite/api/source_zh_cn/api_cpp/mindspore_format.md
@@ -1,6 +1,6 @@
# mindspore::Format

-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_cpp/mindspore_format.md)
+[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_cpp/mindspore_format.md)

以下表格描述了MindSpore MSTensor保存的数据支持的排列格式。

diff --git a/docs/lite/api/source_zh_cn/api_cpp/mindspore_kernel.md b/docs/lite/api/source_zh_cn/api_cpp/mindspore_kernel.md
index e5f8303d442e04ccfcdaec2144592a44a009c1a7..5d090cb6e48999561df5587fea1a130689ac16ef 100644
--- a/docs/lite/api/source_zh_cn/api_cpp/mindspore_kernel.md
+++ b/docs/lite/api/source_zh_cn/api_cpp/mindspore_kernel.md
@@ -1,6 +1,6 @@
# mindspore::kernel

-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_cpp/mindspore_kernel.md)
+[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_cpp/mindspore_kernel.md)

## 接口汇总

@@ -11,7 +11,7 @@

## Kernel

-\#include <[kernel.h](https://gitee.com/mindspore/mindspore/blob/master/include/api/kernel.h)>
+\#include <[kernel.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/include/api/kernel.h)>

Kernel是算子实现的基类,定义了几个必须实现的接口。

@@ -30,13 +30,13 @@ Kernel的默认与带参构造函数,构造Kernel实例。

- 参数

- - `inputs`: 算子输入[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)。
+ - `inputs`: 算子输入[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)。

- - `outputs`: 算子输出[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)。
+ - `outputs`: 算子输出[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)。

- `primitive`: 算子经由flatbuffers反序列化为Primitive后的结果。

- - `ctx`: 算子的上下文[Context](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#context)。
+ - `ctx`: 算子的上下文[Context](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#context)。

## 析构函数

@@ -82,7 +82,7 @@ virtual int InferShape()
```

在用户调用`Model::Build`接口时,或是模型推理中需要推理算子形状时,会调用到该接口。
-在自定义算子场景中,用户可以覆写该接口,实现自定义算子的形状推理逻辑。详见[自定义算子章节](https://www.mindspore.cn/lite/docs/zh-CN/master/advanced/third_party/register_kernel.html)。
+在自定义算子场景中,用户可以覆写该接口,实现自定义算子的形状推理逻辑。详见[自定义算子章节](https://www.mindspore.cn/lite/docs/zh-CN/r2.6.0/advanced/third_party/register_kernel.html)。
在`InferShape`函数中,一般需要实现算子的形状、数据类型和数据排布的推理逻辑。

### type

@@ -111,7 +111,7 @@ virtual void set_inputs(const std::vector<mindspore::MSTensor> &in_tensors)

- 参数

- - `in_tensors`: 算子的所有输入[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)列表。
+ - `in_tensors`: 算子的所有输入[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)列表。
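为直观起见,下面给出一个自定义算子的最小骨架(假设性示例:`MyAddKernel`为虚构实现,`Prepare`/`ReSize`/`Execute`的签名与返回值语义请以kernel.h为准,这里假定返回0表示成功):

```cpp
#include "include/api/kernel.h"

// 示意:逐元素加法算子,假设两个输入为同形状的float张量
class MyAddKernel : public mindspore::kernel::Kernel {
 public:
  MyAddKernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
              const mindspore::schema::Primitive *primitive, const mindspore::Context *ctx)
      : Kernel(inputs, outputs, primitive, ctx) {}
  int Prepare() override { return 0; }  // 运行前的资源准备
  int ReSize() override { return 0; }   // 输入形状变化时重新初始化
  int Execute() override {
    auto in0 = static_cast<const float *>(inputs_[0].Data().get());
    auto in1 = static_cast<const float *>(inputs_[1].Data().get());
    auto out = static_cast<float *>(outputs_[0].MutableData());
    for (int64_t i = 0; i < inputs_[0].ElementNum(); ++i) {
      out[i] = in0[i] + in1[i];
    }
    return 0;
  }
};
```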
### set_input

@@ -123,7 +123,7 @@ virtual void set_input(mindspore::MSTensor in_tensor, int index)

- 参数

- - `in_tensor`: 算子的输入[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)。
+ - `in_tensor`: 算子的输入[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)。

- `index`: 算子输入在所有输入中的下标,从0开始计数。

### set_outputs

@@ -137,7 +137,7 @@ virtual void set_outputs(const std::vector<mindspore::MSTensor> &out_tensors)

- 参数

- - `out_tensor`: 算子的所有输出[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)列表。
+ - `out_tensor`: 算子的所有输出[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)列表。

### set_output

@@ -149,7 +149,7 @@ virtual void set_output(mindspore::MSTensor out_tensor, int index)

- 参数

- - `out_tensor`: 算子的输出[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)。
+ - `out_tensor`: 算子的输出[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)。

- `index`: 算子输出在所有输出中的下标,从0开始计数。

@@ -159,7 +159,7 @@ virtual void set_output(mindspore::MSTensor out_tensor, int index)

```cpp
virtual const std::vector<mindspore::MSTensor> &inputs()
```

-返回算子的所有输入[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)列表。
+返回算子的所有输入[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)列表。

### outputs

@@ -167,7 +167,7 @@ virtual const std::vector<mindspore::MSTensor> &inputs()

```cpp
virtual const std::vector<mindspore::MSTensor> &outputs()
```

-返回算子的所有输出[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)列表。
+返回算子的所有输出[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)列表。

### name

@@ -195,7 +195,7 @@ void set_name(const std::string &name)

```cpp
const lite::Context *context() const
```

-返回算子对应的[Context](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#context)。
+返回算子对应的[Context](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#context)。

### primitive

@@ -243,7 +243,7 @@ std::map<std::string, std::string> GetConfig(const std::string &section) const

## KernelInterface

-\#include <[kernel_interface.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/kernel_interface.h)>
+\#include <[kernel_interface.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/kernel_interface.h)>

算子扩展能力基类。

@@ -275,9 +275,9 @@ virtual int Infer(std::vector *inputs, std::vector *inputs, std::vector
+\#include <[node_parser_registry.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/node_parser_registry.h)>

NodeParserRegistry类用于注册及获取NodeParser类型的共享智能指针。

@@ -41,11 +41,11 @@ NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,

- 参数

- - `fmk_type`: 框架类型,具体见[FmkType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_converter.html#fmktype)说明。
+ - `fmk_type`: 框架类型,具体见[FmkType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_converter.html#fmktype)说明。

- `node_type`: 节点的类型。

- - `node_parser`: NodeParser类型的共享智能指针实例, 具体见[NodeParserPtr](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_converter.html#nodeparserptr)说明。
+ - `node_parser`: NodeParser类型的共享智能指针实例, 具体见[NodeParserPtr](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_converter.html#nodeparserptr)说明。

### ~NodeParserRegistry

@@ -67,13 +67,13 @@ static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const

- 参数

- - `fmk_type`: 框架类型,具体见[FmkType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_converter.html#fmktype)说明。
+ - `fmk_type`: 框架类型,具体见[FmkType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_converter.html#fmktype)说明。

- `node_type`: 节点的类型。

## REG_NODE_PARSER

-\#include
<[node_parser_registry.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/node_parser_registry.h)> +\#include <[node_parser_registry.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/node_parser_registry.h)> ```c++ #define REG_NODE_PARSER(fmk_type, node_type, node_parser) @@ -83,25 +83,25 @@ static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const - 参数 - - `fmk_type`: 框架类型,具体见[FmkType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_converter.html#fmktype)说明。 + - `fmk_type`: 框架类型,具体见[FmkType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_converter.html#fmktype)说明。 - `node_type`: 节点的类型。 - - `node_parser`: NodeParser类型的共享智能指针实例, 具体见[NodeParserPtr](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_converter.html#nodeparserptr)说明。 + - `node_parser`: NodeParser类型的共享智能指针实例, 具体见[NodeParserPtr](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_converter.html#nodeparserptr)说明。 ## ModelParserCreator -\#include <[model_parser_registry.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/model_parser_registry.h)> +\#include <[model_parser_registry.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/model_parser_registry.h)> ```c++ typedef converter::ModelParser *(*ModelParserCreator)() ``` -创建[ModelParser](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_converter.html#modelparser)的函数原型声明。 +创建[ModelParser](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_converter.html#modelparser)的函数原型声明。 ## ModelParserRegistry -\#include <[model_parser_registry.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/model_parser_registry.h)> +\#include <[model_parser_registry.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/model_parser_registry.h)> ModelParserRegistry类用于注册及获取ModelParserCreator类型的函数指针。 @@ -115,7 +115,7 @@ ModelParserRegistry(FmkType fmk, ModelParserCreator creator) - 参数 - - `fmk`: 框架类型,具体见[FmkType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_converter.html#fmktype)说明。 + - `fmk`: 框架类型,具体见[FmkType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_converter.html#fmktype)说明。 - `creator`: ModelParserCreator类型的函数指针, 具体见[ModelParserCreator](#modelparsercreator)说明。 @@ -139,11 +139,11 @@ static ModelParser *GetModelParser(FmkType fmk) - 参数 - - `fmk`: 框架类型,具体见[FmkType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_converter.html#fmktype)说明。 + - `fmk`: 框架类型,具体见[FmkType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_converter.html#fmktype)说明。 ## REG_MODEL_PARSER -\#include <[model_parser_registry.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/model_parser_registry.h)> +\#include <[model_parser_registry.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/model_parser_registry.h)> ```c++ #define REG_MODEL_PARSER(fmk, parserCreator) @@ -153,15 +153,15 @@ static ModelParser *GetModelParser(FmkType fmk) - 参数 - - `fmk`: 框架类型,具体见[FmkType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_converter.html#fmktype)说明。 + - `fmk`: 框架类型,具体见[FmkType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_converter.html#fmktype)说明。 - `creator`: ModelParserCreator类型的函数指针, 具体见[ModelParserCreator](#modelparsercreator)说明。 -> 
用户自定义的ModelParser,框架类型必须满足设定支持的框架类型[FmkType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_converter.html#fmktype)。
+> 用户自定义的ModelParser,框架类型必须满足设定支持的框架类型[FmkType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_converter.html#fmktype)。

## PassBase

-\#include <[pass_base.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/pass_base.h)>
+\#include <[pass_base.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/pass_base.h)>

PassBase定义了图优化的基类,以供用户继承并自定义图优化算法。

@@ -201,7 +201,7 @@ virtual bool Execute(const api::FuncGraphPtr &func_graph) = 0;

## PassBasePtr

-\#include <[pass_base.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/pass_base.h)>
+\#include <[pass_base.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/pass_base.h)>

PassBase类的共享智能指针类型。

@@ -211,7 +211,7 @@ using PassBasePtr = std::shared_ptr<PassBase>

## PassPosition

-\#include <[pass_registry.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/pass_registry.h)>
+\#include <[pass_registry.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/pass_registry.h)>

**enum**类型变量,定义扩展Pass的运行位置。

@@ -224,7 +224,7 @@ enum PassPosition {

## PassRegistry

-\#include <[pass_registry.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/pass_registry.h)>
+\#include <[pass_registry.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/pass_registry.h)>

PassRegistry类用于注册及获取Pass类实例。

@@ -290,7 +290,7 @@ static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name)

## REG_PASS

-\#include <[pass_registry.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/pass_registry.h)>
+\#include <[pass_registry.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/pass_registry.h)>

```c++
#define REG_PASS(name, pass)
```

@@ -306,7 +306,7 @@ static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name)

## REG_SCHEDULED_PASS

-\#include <[pass_registry.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/pass_registry.h)>
+\#include <[pass_registry.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/pass_registry.h)>

```c++
#define REG_SCHEDULED_PASS(position, names)
```

@@ -322,7 +322,7 @@ static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name)

> MindSpore Lite开放了部分内置Pass,请见以下说明。用户可以在`names`参数中添加内置Pass的命名标识,以在指定运行处调用内置Pass。
>
-> - `ConstFoldPass`: 将输入均是常量的节点进行离线计算,导出的模型将不含该节点。特别地,针对shape算子,在[inputShape](https://www.mindspore.cn/lite/docs/zh-CN/master/converter/converter_tool.html#参数说明)给定的情形下,也会触发预计算。
+> - `ConstFoldPass`: 将输入均是常量的节点进行离线计算,导出的模型将不含该节点。特别地,针对shape算子,在[inputShape](https://www.mindspore.cn/lite/docs/zh-CN/r2.6.0/converter/converter_tool.html#参数说明)给定的情形下,也会触发预计算。
> - `DumpGraph`: 导出当前状态下的模型。请确保当前模型为NHWC或者NCHW格式的模型,例如卷积算子等。
> - `ToNCHWFormat`: 将当前状态下的模型转换为NCHW的格式,例如,四维的图输入、卷积算子等。
> - `ToNHWCFormat`: 将当前状态下的模型转换为NHWC的格式,例如,四维的图输入、卷积算子等。
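下面给出一个自定义Pass的注册示意(假设性示例:`MyPass`为虚构实现,`PassBase`的构造方式、`POSITION_BEGIN`枚举值与两个注册宏的展开方式请以pass_base.h、pass_registry.h为准):

```cpp
#include "include/registry/pass_base.h"
#include "include/registry/pass_registry.h"

namespace example {
// 示意:自定义图优化Pass,Execute返回true表示执行成功
class MyPass : public mindspore::registry::PassBase {
 public:
  MyPass() : PassBase("MyPass") {}
  bool Execute(const mindspore::api::FuncGraphPtr &func_graph) override {
    // 在此遍历func_graph并改写节点
    return true;
  }
};
}  // namespace example

// 注册该Pass,并将其调度到扩展位置POSITION_BEGIN处运行
REG_PASS(MyPass, example::MyPass)
REG_SCHEDULED_PASS(POSITION_BEGIN, {"MyPass"})
```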
@@ -334,7 +334,7 @@

## KernelDesc

-\#include <[registry/register_kernel.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/register_kernel.h)>
+\#include <[registry/register_kernel.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/register_kernel.h)>

**struct**类型结构体,定义扩展kernel的基本属性。

@@ -349,7 +349,7 @@ struct KernelDesc {

## RegisterKernel

-\#include <[registry/register_kernel.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/register_kernel.h)>
+\#include <[registry/register_kernel.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/register_kernel.h)>

### CreateKernel

@@ -363,13 +363,13 @@ using CreateKernel = std::function<std::shared_ptr<kernel::Kernel>(

- 参数

- - `inputs`: 算子输入[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)。
+ - `inputs`: 算子输入[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)。

- - `outputs`: 算子输出[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#mstensor)。
+ - `outputs`: 算子输出[MSTensor](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#mstensor)。

- `primitive`: 算子经由flatbuffers反序列化为Primitive后的结果。

- - `ctx`: 算子的上下文[Context](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#context)。
+ - `ctx`: 算子的上下文[Context](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#context)。

### 公有成员函数

@@ -387,9 +387,9 @@ static Status RegKernel(const std::string &arch, const std::string &provider, Da

- `provider`: 生产商名,由用户自定义。

- - `data_type`: 算子支持的数据类型,具体见[DataType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_datatype.html)。
+ - `data_type`: 算子支持的数据类型,具体见[DataType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_datatype.html)。

- - `op_type`: 算子类型,定义在[ops.fbs](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/schema/ops.fbs)中,编译时会生成到ops_generated.h,该文件可以在发布件中获取。
+ - `op_type`: 算子类型,定义在[ops.fbs](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/schema/ops.fbs)中,编译时会生成到ops_generated.h,该文件可以在发布件中获取。

- `creator`: 创建算子的函数指针,具体见[CreateKernel](#createkernel)的说明。
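结合上面的`CreateKernel`原型与`RegKernel`接口,下面给出一个注册算子创建函数的示意片段(假设性示例:`MyAddKernel`沿用前文kernel.h小节的草图,`REGISTER_KERNEL`宏、`kNumberTypeFloat32`与`PrimitiveType_AddFusion`等标识请以register_kernel.h和ops_generated.h为准):

```cpp
#include "include/registry/register_kernel.h"

namespace {
// 示意:与CreateKernel原型一致的创建函数,返回前文草图中的MyAddKernel
std::shared_ptr<mindspore::kernel::Kernel> MyAddCreator(const std::vector<mindspore::MSTensor> &inputs,
                                                        const std::vector<mindspore::MSTensor> &outputs,
                                                        const mindspore::schema::Primitive *primitive,
                                                        const mindspore::Context *ctx) {
  return std::make_shared<MyAddKernel>(inputs, outputs, primitive, ctx);
}
}  // namespace

// 示意:为CPU后端、float32数据类型的AddFusion算子注册上述创建函数
REGISTER_KERNEL(CPU, MyVendor, kNumberTypeFloat32, PrimitiveType_AddFusion, MyAddCreator)
```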
@@ -407,7 +407,7 @@ Custom算子注册。

- `provider`: 生产商名,由用户自定义。

- - `data_type`: 算子支持的数据类型,具体见[DataType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_datatype.html)。
+ - `data_type`: 算子支持的数据类型,具体见[DataType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_datatype.html)。

- `type`: 算子类型,由用户自定义,确保唯一即可。

@@ -429,7 +429,7 @@ static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *d

## KernelReg

-\#include <[registry/register_kernel.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/register_kernel.h)>
+\#include <[registry/register_kernel.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/register_kernel.h)>

### ~KernelReg

@@ -453,9 +453,9 @@ KernelReg(const std::string &arch, const std::string &provider, DataType data_ty

- `provider`: 生产商名,由用户自定义。

- - `data_type`: 算子支持的数据类型,具体见[DataType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_datatype.html)。
+ - `data_type`: 算子支持的数据类型,具体见[DataType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_datatype.html)。

- - `op_type`: 算子类型,定义在[ops.fbs](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/schema/ops.fbs)中,编译时会生成到ops_generated.h,该文件可以在发布件中获取。
+ - `op_type`: 算子类型,定义在[ops.fbs](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/schema/ops.fbs)中,编译时会生成到ops_generated.h,该文件可以在发布件中获取。

- `creator`: 创建算子的函数指针,具体见[CreateKernel](#createkernel)的说明。

@@ -471,7 +471,7 @@ KernelReg(const std::string &arch, const std::string &provider, DataType data_ty

- `provider`: 生产商名,由用户自定义。

- - `data_type`: 算子支持的数据类型,具体见[DataType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_datatype.html)。
+ - `data_type`: 算子支持的数据类型,具体见[DataType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_datatype.html)。

- `op_type`: 算子类型,由用户自定义,确保唯一即可。

@@ -491,9 +491,9 @@ KernelReg(const std::string &arch, const std::string &provider, DataType data_ty

- `provider`: 生产商名,由用户自定义。

- - `data_type`: 算子支持的数据类型,具体见[DataType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_datatype.html)。
+ - `data_type`: 算子支持的数据类型,具体见[DataType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_datatype.html)。

- - `op_type`: 算子类型,定义在[ops.fbs](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/schema/ops.fbs)中,编译时会生成到ops_generated.h,该文件可以在发布件中获取。
+ - `op_type`: 算子类型,定义在[ops.fbs](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/schema/ops.fbs)中,编译时会生成到ops_generated.h,该文件可以在发布件中获取。

- `creator`: 创建算子的函数指针,具体见[CreateKernel](#createkernel)的说明。

@@ -511,7 +511,7 @@ KernelReg(const std::string &arch, const std::string &provider, DataType data_ty

- `provider`: 生产商名,由用户自定义。

- - `data_type`: 算子支持的数据类型,具体见[DataType](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_datatype.html)。
+ - `data_type`: 算子支持的数据类型,具体见[DataType](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_datatype.html)。

- `op_type`: 算子类型,由用户自定义,确保唯一即可。

@@ -519,7 +519,7 @@ KernelReg(const std::string &arch, const std::string &provider, DataType data_ty

## KernelInterfaceCreator

-\#include <[registry/register_kernel_interface.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/register_kernel_interface.h)>
+\#include <[registry/register_kernel_interface.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/register_kernel_interface.h)>

定义创建算子的函数指针类型。

@@ -529,7 +529,7 @@ using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>

## RegisterKernelInterface

-\#include <[registry/register_kernel_interface.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/register_kernel_interface.h)>
+\#include <[registry/register_kernel_interface.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/register_kernel_interface.h)>

算子扩展能力注册实现类。

@@ -563,7 +563,7 @@ static Status Reg(const std::string &provider, int op_type, const KernelInterfac

- `provider`: 生产商,由用户自定义。

- - `op_type`: 算子类型,定义在[ops.fbs](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/schema/ops.fbs)中,编译时会生成到ops_generated.h,该文件可以在发布件中获取。
+ - `op_type`: 算子类型,定义在[ops.fbs](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/schema/ops.fbs)中,编译时会生成到ops_generated.h,该文件可以在发布件中获取。

- `creator`: KernelInterface的创建函数,详细见[KernelInterfaceCreator](#kernelinterfacecreator)的说明。

@@ -585,7 +585,7 @@ static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::st

## KernelInterfaceReg

-\#include <[registry/register_kernel_interface.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/register_kernel_interface.h)>
+\#include <[registry/register_kernel_interface.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/register_kernel_interface.h)>

算子扩展能力注册构造类。

@@ -601,7 +601,7 @@ KernelInterfaceReg(const std::string &provider, int op_type, const KernelInterfa

- 参数

- `provider`: 生产商,由用户自定义。

- - `op_type`: 算子类型,定义在[ops.fbs](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/schema/ops.fbs)中,编译时会生成到ops_generated.h,该文件可以在发布件中获取。
+ - `op_type`: 算子类型,定义在[ops.fbs](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/schema/ops.fbs)中,编译时会生成到ops_generated.h,该文件可以在发布件中获取。
- `creator`: KernelInterface的创建函数,详细见[KernelInterfaceCreator](#kernelinterfacecreator)的说明。

@@ -621,7 +621,7 @@ KernelInterfaceReg(const std::string &provider, const std::string &op_type, cons

## REGISTER_KERNEL_INTERFACE

-\#include <[registry/register_kernel_interface.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/register_kernel_interface.h)>
+\#include <[registry/register_kernel_interface.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/register_kernel_interface.h)>

注册KernelInterface的实现。

@@ -633,13 +633,13 @@ KernelInterfaceReg(const std::string &provider, const std::string &op_type, cons

- 参数

- `provider`: 生产商,由用户自定义。

- - `op_type`: 算子类型,定义在[ops.fbs](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/schema/ops.fbs)中,编译时会生成到ops_generated.h,该文件可以在发布件中获取。
+ - `op_type`: 算子类型,定义在[ops.fbs](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/schema/ops.fbs)中,编译时会生成到ops_generated.h,该文件可以在发布件中获取。

- `creator`: 创建KernelInterface的函数指针,具体见[KernelInterfaceCreator](#kernelinterfacecreator)的说明。

## REGISTER_CUSTOM_KERNEL_INTERFACE

-\#include <[registry/register_kernel_interface.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/register_kernel_interface.h)>
+\#include <[registry/register_kernel_interface.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/register_kernel_interface.h)>

注册Custom算子对应的KernelInterface实现。
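与上述注册宏配套,下面给出一个注册`KernelInterface`的示意(假设性示例:`MyAddInterface`与`MyVendor`均为虚构,`Infer`沿用上文文档中的int返回值形式并假定0表示成功,实际签名请以kernel_interface.h为准):

```cpp
#include "include/registry/register_kernel_interface.h"

namespace {
// 示意:为自定义算子提供形状推理能力
class MyAddInterface : public mindspore::kernel::KernelInterface {
 public:
  int Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
            const mindspore::schema::Primitive *primitive) override {
    // 假设:输出形状与第一个输入一致
    (*outputs)[0].SetShape((*inputs)[0].Shape());
    return 0;
  }
};

std::shared_ptr<mindspore::kernel::KernelInterface> MyAddInterfaceCreator() {
  return std::make_shared<MyAddInterface>();
}
}  // namespace

REGISTER_KERNEL_INTERFACE(MyVendor, PrimitiveType_AddFusion, MyAddInterfaceCreator)
```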
diff --git a/docs/lite/api/source_zh_cn/api_cpp/mindspore_registry_opencl.md b/docs/lite/api/source_zh_cn/api_cpp/mindspore_registry_opencl.md
index 21c8d89d0f6b4640e27a3633265c862db679e44d..3bb12937bebd94822843c15bb0f1947808a1664a 100644
--- a/docs/lite/api/source_zh_cn/api_cpp/mindspore_registry_opencl.md
+++ b/docs/lite/api/source_zh_cn/api_cpp/mindspore_registry_opencl.md
@@ -1,6 +1,6 @@
# mindspore::registry::opencl

-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_cpp/mindspore_registry_opencl.md)
+[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_cpp/mindspore_registry_opencl.md)

## 接口汇总

@@ -10,7 +10,7 @@

## OpenCLRuntimeWrapper

-\#include <[include/registry/opencl_runtime_wrapper.h](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/include/registry/opencl_runtime_wrapper.h)>
+\#include <[include/registry/opencl_runtime_wrapper.h](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/include/registry/opencl_runtime_wrapper.h)>

OpenCLRuntimeWrapper类包装了内部OpenCL的相关接口,用于支持南向GPU算子的开发。

@@ -132,7 +132,7 @@ Status SyncCommandQueue();

```cpp
std::shared_ptr<mindspore::Allocator> GetAllocator();
```

-获取GPU内存分配器的智能指针。通过[Allocator接口](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html),可申请GPU内存,用于OpenCL内核的运算。
+获取GPU内存分配器的智能指针。通过[Allocator接口](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html),可申请GPU内存,用于OpenCL内核的运算。

### MapBuffer

diff --git a/docs/lite/api/source_zh_cn/api_java/class_list.md b/docs/lite/api/source_zh_cn/api_java/class_list.md
index 8deb37158a697e1af41126b6d258a7c3020ccb69..cdaecab7c20a1cd9affe020e58f47d73ece04ee4 100644
--- a/docs/lite/api/source_zh_cn/api_java/class_list.md
+++ b/docs/lite/api/source_zh_cn/api_java/class_list.md
@@ -1,17 +1,17 @@
# 类列表

-[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_java/class_list.md)
+[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_java/class_list.md)

| 包 | 类 | 描述 | 云侧推理是否支持 | 端侧推理是否支持 |
| ------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |--------|--------|
-| com.mindspore | [Model](https://www.mindspore.cn/lite/api/zh-CN/master/api_java/model.html) | Model定义了MindSpore中的模型,用于计算图的编译和执行。 | √ | √ |
-| com.mindspore.config | [MSContext](https://www.mindspore.cn/lite/api/zh-CN/master/api_java/mscontext.html) | MSContext用于保存执行期间的上下文。 | √ | √ |
-| com.mindspore | [MSTensor](https://www.mindspore.cn/lite/api/zh-CN/master/api_java/mstensor.html) | MSTensor定义了MindSpore中的张量。 | √ | √ |
-| com.mindspore | [ModelParallelRunner](https://www.mindspore.cn/lite/api/zh-CN/master/api_java/model_parallel_runner.html) | 定义了MindSpore Lite并发推理。 | √ | ✕ |
-| com.mindspore.config | [RunnerConfig](https://www.mindspore.cn/lite/api/zh-CN/master/api_java/runner_config.html) | RunnerConfig 定义并发推理的配置参数。 | √ | ✕ |
-| com.mindspore | [Graph](https://www.mindspore.cn/lite/api/zh-CN/master/api_java/graph.html) | Graph定义了MindSpore中的计算图。 | ✕ | √ |
-| com.mindspore.config | [CpuBindMode](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java) | CpuBindMode定义了CPU绑定模式。 | √ | √ |
-| com.mindspore.config | [DeviceType](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java) | DeviceType定义了后端设备类型。 | √ | √ |
-| com.mindspore.config | [DataType](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/DataType.java) | DataType定义了所支持的数据类型。 | √ | √ |
-| com.mindspore.config | [Version](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/Version.java) | Version用于获取MindSpore的版本信息。 | ✕ | √ |
-| com.mindspore.config | [ModelType](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/ModelType.java) | ModelType 定义了模型文件的类型。 | √ | √ |
+| com.mindspore | [Model](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_java/model.html) | Model定义了MindSpore中的模型,用于计算图的编译和执行。 | √ | √ |
+| com.mindspore.config | [MSContext](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_java/mscontext.html) | MSContext用于保存执行期间的上下文。 | √ | √ |
+| com.mindspore | [MSTensor](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_java/mstensor.html) | MSTensor定义了MindSpore中的张量。 | √ | √ |
+| com.mindspore | [ModelParallelRunner](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_java/model_parallel_runner.html) | 定义了MindSpore Lite并发推理。 | √ | ✕ |
+| com.mindspore.config | [RunnerConfig](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_java/runner_config.html) | RunnerConfig 定义并发推理的配置参数。 | √ | ✕ |
+| com.mindspore | [Graph](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_java/graph.html) | Graph定义了MindSpore中的计算图。 | ✕ | √ |
+| com.mindspore.config | [CpuBindMode](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java) |
CpuBindMode定义了CPU绑定模式。 | √ | √ | +| com.mindspore.config | [DeviceType](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java) | DeviceType定义了后端设备类型。 | √ | √ | +| com.mindspore.config | [DataType](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/DataType.java) | DataType定义了所支持的数据类型。 | √ | √ | +| com.mindspore.config | [Version](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/Version.java) | Version用于获取MindSpore的版本信息。 | ✕ | √ | +| com.mindspore.config | [ModelType](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/ModelType.java) | ModelType 定义了模型文件的类型。 | √ | √ | diff --git a/docs/lite/api/source_zh_cn/api_java/graph.md b/docs/lite/api/source_zh_cn/api_java/graph.md index d3685e2c6afcab0697de2bb40d36f265b4a3870d..9c5d74119217b07b555d830aef0095dd01547001 100644 --- a/docs/lite/api/source_zh_cn/api_java/graph.md +++ b/docs/lite/api/source_zh_cn/api_java/graph.md @@ -1,6 +1,6 @@ # Graph -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_java/graph.md) +[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_java/graph.md) ```java import com.mindspore.Graph; diff --git a/docs/lite/api/source_zh_cn/api_java/lite_java_example.rst b/docs/lite/api/source_zh_cn/api_java/lite_java_example.rst index 680868ab2f7f22a79569496cd6c87be38d663859..c177edfcf1aba9cd39b5e4ad7d0c861af29a05c2 100644 --- a/docs/lite/api/source_zh_cn/api_java/lite_java_example.rst +++ b/docs/lite/api/source_zh_cn/api_java/lite_java_example.rst @@ -4,6 +4,6 @@ .. 
toctree:: :maxdepth: 1 - 极简Demo↗ - 基于Java接口的Android应用开发↗ - 高阶用法↗ \ No newline at end of file + 极简Demo↗ + 基于Java接口的Android应用开发↗ + 高阶用法↗ \ No newline at end of file diff --git a/docs/lite/api/source_zh_cn/api_java/model.md b/docs/lite/api/source_zh_cn/api_java/model.md index 15dd3587f4220a6e2ebc89ee725c507ff56e4167..ad1d3de849d874a834b044d2ae098cb951787a84 100644 --- a/docs/lite/api/source_zh_cn/api_java/model.md +++ b/docs/lite/api/source_zh_cn/api_java/model.md @@ -1,6 +1,6 @@ # Model -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_java/model.md) +[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_java/model.md) ```java import com.mindspore.Model; diff --git a/docs/lite/api/source_zh_cn/api_java/model_parallel_runner.md b/docs/lite/api/source_zh_cn/api_java/model_parallel_runner.md index 44f7a1d734a4325ab169c60b578fe8f376934524..adb5545eb1512dea2ebe8ad6e57a61a41a603d93 100644 --- a/docs/lite/api/source_zh_cn/api_java/model_parallel_runner.md +++ b/docs/lite/api/source_zh_cn/api_java/model_parallel_runner.md @@ -1,6 +1,6 @@ # ModelParallelRunner -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_java/model_parallel_runner.md) +[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_java/model_parallel_runner.md) ```java import com.mindspore.config.RunnerConfig; diff --git a/docs/lite/api/source_zh_cn/api_java/mscontext.md b/docs/lite/api/source_zh_cn/api_java/mscontext.md index bfec9fc92c68ed8d209e76770939f0deaa41a1b6..f54436360469d186278aacc1732af783bc5eba76 100644 --- a/docs/lite/api/source_zh_cn/api_java/mscontext.md +++ b/docs/lite/api/source_zh_cn/api_java/mscontext.md @@ -1,6 +1,6 @@ # MSContext -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_java/mscontext.md) +[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_java/mscontext.md) ```java import com.mindspore.config.MSContext; @@ -53,7 +53,7 @@ public boolean init(int threadNum, int cpuBindMode) - 参数 - `threadNum`: 线程数。 - - `cpuBindMode`: CPU绑定模式,`cpuBindMode`在[com.mindspore.config.CpuBindMode](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java)中定义。 + - `cpuBindMode`: CPU绑定模式,`cpuBindMode`在[com.mindspore.config.CpuBindMode](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java)中定义。 - 返回值 @@ -68,7 +68,7 @@ public boolean init(int threadNum, int cpuBindMode, boolean isEnableParallel) - 参数 - `threadNum`: 线程数。 - - `cpuBindMode`: CPU绑定模式,`cpuBindMode`在[com.mindspore.config.CpuBindMode](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java)中定义。 
+ - `cpuBindMode`: CPU绑定模式,`cpuBindMode`在[com.mindspore.config.CpuBindMode](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/CpuBindMode.java)中定义。 - `isEnableParallel`: 是否开启异构并行。 - 返回值 @@ -85,7 +85,7 @@ boolean addDeviceInfo(int deviceType, boolean isEnableFloat16) - 参数 - - `deviceType`: 设备类型,`deviceType`在[com.mindspore.config.DeviceType](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java)中定义。 + - `deviceType`: 设备类型,`deviceType`在[com.mindspore.config.DeviceType](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java)中定义。 - `isEnableFloat16`: 是否开启fp16。 - 返回值 @@ -100,7 +100,7 @@ boolean addDeviceInfo(int deviceType, boolean isEnableFloat16, int npuFreq) - 参数 - - `deviceType`: 设备类型,`deviceType`在[com.mindspore.config.DeviceType](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java)中定义。 + - `deviceType`: 设备类型,`deviceType`在[com.mindspore.config.DeviceType](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/DeviceType.java)中定义。 - `isEnableFloat16`: 是否开启fp16。 - `npuFreq`: NPU运行频率,仅当deviceType为npu才需要。 diff --git a/docs/lite/api/source_zh_cn/api_java/mstensor.md b/docs/lite/api/source_zh_cn/api_java/mstensor.md index 1061072e5340a71e915ef336640f43a92ff6f7ef..11cc589ecc5126946d709b37fa852cb30711ad47 100644 --- a/docs/lite/api/source_zh_cn/api_java/mstensor.md +++ b/docs/lite/api/source_zh_cn/api_java/mstensor.md @@ -1,6 +1,6 @@ # MSTensor -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_java/mstensor.md) +[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_java/mstensor.md) ```java import com.mindspore.MSTensor; @@ -83,7 +83,7 @@ public int[] getShape() public int getDataType() ``` -DataType在[com.mindspore.DataType](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/java/src/main/java/com/mindspore/config/DataType.java)中定义。 +DataType在[com.mindspore.DataType](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/java/src/main/java/com/mindspore/config/DataType.java)中定义。 - 返回值 diff --git a/docs/lite/api/source_zh_cn/api_java/runner_config.md b/docs/lite/api/source_zh_cn/api_java/runner_config.md index d06cfedcc3eb3ce82534497148621526c9e2462f..0cc47496b439b54197d3d617c497f4f4ee609d5c 100644 --- a/docs/lite/api/source_zh_cn/api_java/runner_config.md +++ b/docs/lite/api/source_zh_cn/api_java/runner_config.md @@ -1,6 +1,6 @@ # RunnerConfig -[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/api/source_zh_cn/api_java/runner_config.md) +[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/api/source_zh_cn/api_java/runner_config.md) RunnerConfig定义了MindSpore Lite并发推理的配置参数。 diff --git a/docs/lite/api/source_zh_cn/conf.py b/docs/lite/api/source_zh_cn/conf.py index 
57ffe8f2c28ac4766e2d1f4ee27001f3582e1ea5..ddfd8bd48d58bcf673c80efeb19005af96467d16 100644 --- a/docs/lite/api/source_zh_cn/conf.py +++ b/docs/lite/api/source_zh_cn/conf.py @@ -30,7 +30,7 @@ copyright = 'MindSpore' author = 'MindSpore Lite' # The full version, including alpha/beta/rc tags -release = 'master' +release = '2.6.0' # -- General configuration --------------------------------------------------- @@ -234,31 +234,34 @@ docs_branch = [version_inf[i]['branch'] for i in range(len(version_inf)) if vers re_view = f"\n.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/{docs_branch}/" + \ f"resource/_static/logo_source.svg\n :target: https://gitee.com/mindspore/{copy_repo}/blob/{branch}/" -for cur, _, files in os.walk(present_path): +# modify urls +re_url = r"(((gitee.com/mindspore/docs)|(github.com/mindspore-ai/(mindspore|docs))|" + \ + r"(mindspore.cn/(docs|tutorials|lite))|(obs.dualstack.cn-north-4.myhuaweicloud)|" + \ + r"(mindspore-website.obs.cn-north-4.myhuaweicloud))[\w\d/_.-]*?)/(master)" + +re_url2 = r"(gitee.com/mindspore/mindspore[\w\d/_.-]*?)/(master)" + +re_url3 = r"(((gitee.com/mindspore/golden-stick)|(mindspore.cn/golden_stick))[\w\d/_.-]*?)/(master)" + +re_url4 = r"(((gitee.com/mindspore/mindformers)|(mindspore.cn/mindformers))[\w\d/_.-]*?)/(dev)" + +for cur, _, files in os.walk('./mindspore_lite'): for i in files: - flag_copy = 0 - if i.endswith('.rst'): - for j in copy_list: - if j in cur: - flag_copy = 1 - break - if os.path.join(cur, i) in copy_list or flag_copy: - try: - with open(os.path.join(cur, i), 'r+', encoding='utf-8') as f: - content = f.read() - new_content = content - if '.. include::' in content and '.. automodule::' in content: - continue - if 'autosummary::' not in content and "\n=====" in content: - re_view_ = re_view + copy_path + cur.split(present_path)[-1] + '/' + i + \ - '\n :alt: 查看源文件\n\n' - new_content = re.sub('([=]{5,})\n', r'\1\n' + re_view_, content, 1) - if new_content != content: - f.seek(0) - f.truncate() - f.write(new_content) - except Exception: - print(f'打开{i}文件失败') + if i.endswith('.rst') or i.endswith('.md') or i.endswith('.ipynb'): + try: + with open(os.path.join(cur, i), 'r+', encoding='utf-8') as f: + content = f.read() + new_content = re.sub(re_url, r'\1/r2.6.0', content) + # new_content = re.sub(re_url3, r'\1/r1.1.0', new_content) + new_content = re.sub(re_url4, r'\1/r1.5.0', new_content) + if i.endswith('.rst'): + new_content = re.sub(re_url2, r'\1/v2.6.0', new_content) + if new_content != content: + f.seek(0) + f.truncate() + f.write(new_content) + except Exception: + print(f'打开{i}文件失败') rst_files = set([i.replace('.rst', '') for i in glob.glob('mindspore_lite/*.rst', recursive=True)]) diff --git a/docs/lite/api/source_zh_cn/index.rst b/docs/lite/api/source_zh_cn/index.rst index 2f8ac789cf42e5661eecbde2ec89f9c97c87e783..05d7c4b98ea8bf61ab101de94d95ce4b9e69e59c 100644 --- a/docs/lite/api/source_zh_cn/index.rst +++ b/docs/lite/api/source_zh_cn/index.rst @@ -12,21 +12,21 @@ MindSpore Lite API 支持情况汇总 
+---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | 类名 | 接口说明 | C++ 接口 | Python 接口 | +=====================+=========================================================================================================+==========================================================================================================================================================================================================================+============================================================================================================================================================================================================================================================================================================================================================================+ -| Context | 设置运行时的线程数 | void SetThreadNum(int32_t thread_num) | `Context.cpu.thread_num `__ | +| Context | 设置运行时的线程数 | void SetThreadNum(int32_t thread_num) | `Context.cpu.thread_num `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | 获取当前线程数设置 | int32_t GetThreadNum() const | `Context.cpu.thread_num `__ | +| Context | 获取当前线程数设置 | int32_t GetThreadNum() const | `Context.cpu.thread_num `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | 设置运行时的算子并行推理数目 | void SetInterOpParallelNum(int32_t parallel_num) | `Context.cpu.inter_op_parallel_num `__ | +| Context | 设置运行时的算子并行推理数目 | void SetInterOpParallelNum(int32_t parallel_num) | `Context.cpu.inter_op_parallel_num `__ | 
+---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | 获取当前算子并行数设置 | int32_t GetInterOpParallelNum() const | `Context.cpu.inter_op_parallel_num `__ | +| Context | 获取当前算子并行数设置 | int32_t GetInterOpParallelNum() const | `Context.cpu.inter_op_parallel_num `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | 设置运行时的CPU绑核策略 | void SetThreadAffinity(int mode) | `Context.cpu.thread_affinity_mode `__ | +| Context | 设置运行时的CPU绑核策略 | void SetThreadAffinity(int mode) | `Context.cpu.thread_affinity_mode `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | 获取当前CPU绑核策略 | int GetThreadAffinityMode() const | `Context.cpu.thread_affinity_mode `__ | +| Context | 获取当前CPU绑核策略 | int GetThreadAffinityMode() const | `Context.cpu.thread_affinity_mode `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | 设置运行时的CPU绑核列表 | void SetThreadAffinity(const 
std::vector &core_list) | `Context.cpu.thread_affinity_core_list `__ | +| Context | 设置运行时的CPU绑核列表 | void SetThreadAffinity(const std::vector &core_list) | `Context.cpu.thread_affinity_core_list `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| Context | 获取当前CPU绑核列表 | std::vector GetThreadAffinityCoreList() const | `Context.cpu.thread_affinity_core_list `__ | +| Context | 获取当前CPU绑核列表 | std::vector GetThreadAffinityCoreList() const | `Context.cpu.thread_affinity_core_list `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | Context | 设置运行时是否支持并行 | void SetEnableParallel(bool is_parallel) | | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ @@ -44,7 +44,7 @@ MindSpore Lite API 支持情况汇总 +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | Context | 获取当前配置中,量化模型的运行模式 | bool GetMultiModalHW() const | | 
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| Context | Modifies the DeviceInfoContext array of this context | std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo() | Wrapped in `Context.target `__ |
+| Context | Modifies the DeviceInfoContext array of this context | std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo() | Wrapped in `Context.target `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | DeviceInfoContext | Gets the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
@@ -62,29 +62,29 @@ MindSpore Lite API support summary
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | DeviceInfoContext | Gets the memory allocator | std::shared_ptr<Allocator> GetAllocator() const | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| CPUDeviceInfo | Gets the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | `Context.cpu `__ |
+| CPUDeviceInfo | Gets the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | `Context.cpu `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| CPUDeviceInfo | Sets whether to perform inference with FP16 precision | void SetEnableFP16(bool is_fp16) | `Context.cpu.precision_mode `__ |
+| CPUDeviceInfo | Sets whether to perform inference with FP16 precision | void SetEnableFP16(bool is_fp16) | `Context.cpu.precision_mode `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| CPUDeviceInfo | Gets whether inference currently uses FP16 precision | bool GetEnableFP16() const | `Context.cpu.precision_mode `__ |
+| CPUDeviceInfo | Gets whether inference currently uses FP16 precision | bool GetEnableFP16() const | `Context.cpu.precision_mode `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| GPUDeviceInfo | Gets the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | `Context.gpu `__ |
+| GPUDeviceInfo | Gets the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | `Context.gpu `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| GPUDeviceInfo | Sets the device ID | void SetDeviceID(uint32_t device_id) | `Context.gpu.device_id `__ |
+| GPUDeviceInfo | Sets the device ID | void SetDeviceID(uint32_t device_id) | `Context.gpu.device_id `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| GPUDeviceInfo | Gets the device ID | uint32_t GetDeviceID() const | `Context.gpu.device_id `__ |
+| GPUDeviceInfo | Gets the device ID | uint32_t GetDeviceID() const | `Context.gpu.device_id `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| GPUDeviceInfo | Gets the RANK ID of the current run | int GetRankID() const | `Context.gpu.rank_id `__ |
+| GPUDeviceInfo | Gets the RANK ID of the current run | int GetRankID() const | `Context.gpu.rank_id `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| GPUDeviceInfo | Gets the GROUP SIZE of the current run | int GetGroupSize() const | `Context.gpu.group_size `__ |
+| GPUDeviceInfo | Gets the GROUP SIZE of the current run | int GetGroupSize() const | `Context.gpu.group_size `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | GPUDeviceInfo | Sets the operator precision for inference | void SetPrecisionMode(const std::string &precision_mode) | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | GPUDeviceInfo | Gets the operator precision for inference | std::string GetPrecisionMode() const | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| GPUDeviceInfo | Sets whether to perform inference with FP16 precision | void SetEnableFP16(bool is_fp16) | `Context.gpu.precision_mode `__ |
+| GPUDeviceInfo | Sets whether to perform inference with FP16 precision | void SetEnableFP16(bool is_fp16) | `Context.gpu.precision_mode `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| GPUDeviceInfo | Gets whether inference currently uses FP16 precision | bool GetEnableFP16() const | `Context.gpu.precision_mode `__ |
+| GPUDeviceInfo | Gets whether inference currently uses FP16 precision | bool GetEnableFP16() const | `Context.gpu.precision_mode `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | GPUDeviceInfo | Sets whether to bind OpenGL texture data | void SetEnableGLTexture(bool is_enable_gl_texture) | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
@@ -98,11 +98,11 @@ MindSpore Lite API support summary
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | GPUDeviceInfo | Gets the current OpenGL EGLDisplay | void \*GetGLDisplay() const | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| AscendDeviceInfo | Gets the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | `Context.ascend `__ |
+| AscendDeviceInfo | Gets the type of this DeviceInfoContext | enum DeviceType GetDeviceType() const | `Context.ascend `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| AscendDeviceInfo | Sets the device ID | void SetDeviceID(uint32_t device_id) | `Context.ascend.device_id `__ |
+| AscendDeviceInfo | Sets the device ID | void SetDeviceID(uint32_t device_id) | `Context.ascend.device_id `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| AscendDeviceInfo | Gets the device ID | uint32_t GetDeviceID() const | `Context.ascend.device_id `__ |
+| AscendDeviceInfo | Gets the device ID | uint32_t GetDeviceID() const | `Context.ascend.device_id `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | AscendDeviceInfo | Sets the AIPP configuration file path | void SetInsertOpConfigPath(const std::string &cfg_path) | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
@@ -132,9 +132,9 @@ MindSpore Lite API support summary
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | AscendDeviceInfo | Gets the model output type | enum DataType GetOutputType() const | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| AscendDeviceInfo | Sets the model precision mode | void SetPrecisionMode(const std::string &precision_mode) | `Context.ascend.precision_mode `__ |
+| AscendDeviceInfo | Sets the model precision mode | void SetPrecisionMode(const std::string &precision_mode) | `Context.ascend.precision_mode `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| AscendDeviceInfo | Gets the model precision mode | std::string GetPrecisionMode() const | `Context.ascend.precision_mode `__ |
+| AscendDeviceInfo | Gets the model precision mode | std::string GetPrecisionMode() const | `Context.ascend.precision_mode `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | AscendDeviceInfo | Sets the operator implementation mode | void SetOpSelectImplMode(const std::string &op_select_impl_mode) | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
@@ -160,7 +160,7 @@ MindSpore Lite API support summary
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | Model | Loads the model from a memory buffer and compiles it to a state runnable on the device | Status Build(const void \*model_data, size_t data_size, ModelType model_type, const std::shared_ptr<Context> &model_context = nullptr) | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| Model | Loads the model from a file path and compiles it to a state runnable on the device | Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context = nullptr) | `Model.build_from_file `__ |
+| Model | Loads the model from a file path and compiles it to a state runnable on the device | Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context = nullptr) | `Model.build_from_file `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | Model | Loads the model from a memory buffer with decryption and compiles it to a state runnable on the device | Status Build(const void \*model_data, size_t data_size, ModelType model_type, const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode, const std::string &cropto_lib_path) | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
@@ -172,11 +172,11 @@ MindSpore Lite API support summary
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | Model | Builds a transfer-learning model, where the backbone weights are fixed and the head weights are trainable | Status BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context, const std::shared_ptr<TrainCfg> &train_cfg = nullptr) | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| Model | Resizes the input tensor shapes of a compiled model | Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) | `Model.resize `__ |
+| Model | Resizes the input tensor shapes of a compiled model | Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) | `Model.resize `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | Model | Updates the size and content of the model's weight tensors | Status UpdateWeights(const std::vector<MSTensor> &new_weights) | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| Model | Runs model inference | Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> \*outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) | `Model.predict `__ |
+| Model | Runs model inference | Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> \*outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) | `Model.predict `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | Model | Runs model inference with callbacks only | Status Predict(const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
@@ -188,11 +188,11 @@ MindSpore Lite API support summary
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | Model | Checks whether the model is configured with data preprocessing | bool HasPreprocess() | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| Model | Loads a configuration file from a path | Status LoadConfig(const std::string &config_path) | Wrapped in the `config_path` parameter of the `Model.build_from_file `__ method |
+| Model | Loads a configuration file from a path | Status LoadConfig(const std::string &config_path) | Wrapped in the `config_path` parameter of the `Model.build_from_file `__ method |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | Model | Updates the configuration | Status UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config) | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| Model | Gets all input tensors of the model | std::vector<MSTensor> GetInputs() | `Model.get_inputs `__ |
+| Model | Gets all input tensors of the model | std::vector<MSTensor> GetInputs() | `Model.get_inputs `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | Model | Gets the model input tensor with the specified name | MSTensor GetInputByTensorName(const std::string &tensor_name) | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
@@ -220,7 +220,7 @@ MindSpore Lite API support summary
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | Model | Gets the training metrics | std::vector<Metrics \*> GetMetrics() | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| Model | Gets all output tensors of the model | std::vector<MSTensor> GetOutputs() | Wrapped in the return value of `Model.predict `__ |
+| Model | Gets all output tensors of the model | std::vector<MSTensor> GetOutputs() | Wrapped in the return value of `Model.predict `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | Model | Gets the names of all output tensors of the model | std::vector<std::string> GetOutputTensorNames() | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
@@ -240,33 +240,33 @@ MindSpore Lite API support summary
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | Model | Checks whether the device supports the model | static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type) | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| RunnerConfig | Sets the number of workers in RunnerConfig | void SetWorkersNum(int32_t workers_num) | `Context.parallel.workers_num `__ |
+| RunnerConfig | Sets the number of workers in RunnerConfig | void SetWorkersNum(int32_t workers_num) | `Context.parallel.workers_num `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| RunnerConfig | Gets the number of workers in RunnerConfig | int32_t GetWorkersNum() const | `Context.parallel.workers_num `__ |
+| RunnerConfig | Gets the number of workers in RunnerConfig | int32_t GetWorkersNum() const | `Context.parallel.workers_num `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| RunnerConfig | Sets the context parameter of RunnerConfig | void SetContext(const std::shared_ptr<Context> &context) | Wrapped in `Context.parallel `__ |
+| RunnerConfig | Sets the context parameter of RunnerConfig | void SetContext(const std::shared_ptr<Context> &context) | Wrapped in `Context.parallel `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| RunnerConfig | Gets the context parameter configured in RunnerConfig | std::shared_ptr<Context> GetContext() const | Wrapped in `Context.parallel `__ |
+| RunnerConfig | Gets the context parameter configured in RunnerConfig | std::shared_ptr<Context> GetContext() const | Wrapped in `Context.parallel `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| RunnerConfig | Sets the configuration parameters of RunnerConfig | void SetConfigInfo(const std::string &section, const std::map<std::string, std::string> &config) | `Context.parallel.config_info `__ |
+| RunnerConfig | Sets the configuration parameters of RunnerConfig | void SetConfigInfo(const std::string &section, const std::map<std::string, std::string> &config) | `Context.parallel.config_info `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| RunnerConfig | Gets the configuration parameter information of RunnerConfig | std::map<std::string, std::map<std::string, std::string>> GetConfigInfo() const | `Context.parallel.config_info `__ |
+| RunnerConfig | Gets the configuration parameter information of RunnerConfig | std::map<std::string, std::map<std::string, std::string>> GetConfigInfo() const | `Context.parallel.config_info `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| RunnerConfig | Sets the configuration file path in RunnerConfig | void SetConfigPath(const std::string &config_path) | `Context.parallel.config_path `__ |
+| RunnerConfig | Sets the configuration file path in RunnerConfig | void SetConfigPath(const std::string &config_path) | `Context.parallel.config_path `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| RunnerConfig | Gets the configuration file path in RunnerConfig | std::string GetConfigPath() const | `Context.parallel.config_path `__ |
+| RunnerConfig | Gets the configuration file path in RunnerConfig | std::string GetConfigPath() const | `Context.parallel.config_path `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| ModelParallelRunner | Loads the model from a path, creates one or more models, and compiles them all to a state runnable on the device | Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr) | `Model.parallel_runner.build_from_file `__ |
+| ModelParallelRunner | Loads the model from a path, creates one or more models, and compiles them all to a state runnable on the device | Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr) | `Model.parallel_runner.build_from_file `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | ModelParallelRunner | Creates one or more models from model file data and compiles them all to a state runnable on the device | Status Init(const void \*model_data, const size_t data_size, const std::shared_ptr<RunnerConfig> &runner_config = nullptr) | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| ModelParallelRunner | Gets all input tensors of the model | std::vector<MSTensor> GetInputs() | `Model.parallel_runner.get_inputs `__ |
+| ModelParallelRunner | Gets all input tensors of the model | std::vector<MSTensor> GetInputs() | `Model.parallel_runner.get_inputs `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| ModelParallelRunner | Gets all output tensors of the model | std::vector<MSTensor> GetOutputs() | Wrapped in the return value of `Model.parallel_runner.predict `__ |
+| ModelParallelRunner | Gets all output tensors of the model | std::vector<MSTensor> GetOutputs() | Wrapped in the return value of `Model.parallel_runner.predict `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| ModelParallelRunner | Runs model inference concurrently | Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> \*outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) | `Model.parallel_runner.predict `__ |
+| ModelParallelRunner | Runs model inference concurrently | Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> \*outputs, const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr) | `Model.parallel_runner.predict `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| MSTensor | Creates an MSTensor object whose data must be copied before Model can access it | static inline MSTensor \*CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void \*data, size_t data_len) noexcept | `Tensor `__ |
+| MSTensor | Creates an MSTensor object whose data must be copied before Model can access it | static inline MSTensor \*CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void \*data, size_t data_len) noexcept | `Tensor `__ |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | MSTensor | Creates an MSTensor object whose data can be accessed by Model directly | static inline MSTensor \*CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void \*data, size_t data_len, bool own_data = true) noexcept | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
@@ -280,19 +280,19 @@ MindSpore Lite API support summary
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
 | MSTensor | Destroys an object created by `Clone`, `StringsToTensor`, `CreateRefTensor`, or `CreateTensor` | static void DestroyTensorPtr(MSTensor \*tensor) noexcept | |
 +---------------------+----------------------------------------------------+------------------------------------------------------------+------------------------------------------------------------+
-| MSTensor | Gets the name of the MSTensor | std::string Name() const | `Tensor.name `__ |
+| MSTensor | Gets the name of the MSTensor | std::string Name() const | `Tensor.name `__ |
+---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | 获取MSTensor的数据类型 | enum DataType DataType() const | `Tensor.dtype `__ | +| MSTensor | 获取MSTensor的数据类型 | enum DataType DataType() const | `Tensor.dtype `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | 获取MSTensor的Shape | const std::vector &Shape() const | `Tensor.shape `__ | +| MSTensor | 获取MSTensor的Shape | const std::vector &Shape() const | `Tensor.shape `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | 获取MSTensor的元素个数 | int64_t ElementNum() const | `Tensor.element_num `__ | +| MSTensor | 获取MSTensor的元素个数 | int64_t ElementNum() const | `Tensor.element_num `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | MSTensor | 获取指向MSTensor中的数据拷贝的智能指针 | std::shared_ptr Data() const | | 
+---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | 获取MSTensor中的数据的指针 | void \*MutableData() | 封装在 `Tensor.get_data_to_numpy `__ 和 `Tensor.set_data_from_numpy `__ | +| MSTensor | 获取MSTensor中的数据的指针 | void \*MutableData() | 封装在 `Tensor.get_data_to_numpy `__ 和 `Tensor.set_data_from_numpy `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | 获取MSTensor中的数据的以字节为单位的内存长度 | size_t DataSize() const | `Tensor.data_size `__ | +| MSTensor | 获取MSTensor中的数据的以字节为单位的内存长度 | size_t DataSize() const | `Tensor.data_size `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | MSTensor | 判断MSTensor中的数据是否是常量数据 | bool IsConst() const | | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ @@ -308,19 +308,19 @@ MindSpore Lite API 支持情况汇总 
+---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | MSTensor | 判断MSTensor是否与另一个MSTensor不相等 | bool operator!=(const MSTensor &tensor) const | | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | 设置MSTensor的Shape | void SetShape(const std::vector &shape) | `Tensor.shape `__ | +| MSTensor | 设置MSTensor的Shape | void SetShape(const std::vector &shape) | `Tensor.shape `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | 设置MSTensor的DataType | void SetDataType(enum DataType data_type) | `Tensor.dtype `__ | +| MSTensor | 设置MSTensor的DataType | void SetDataType(enum DataType data_type) | `Tensor.dtype `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | 设置MSTensor的名字 | void SetTensorName(const std::string &name) | `Tensor.name `__ | +| MSTensor | 设置MSTensor的名字 | void SetTensorName(const std::string &name) | `Tensor.name `__ | 
+---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | MSTensor | 设置MSTensor数据所属的内存池 | void SetAllocator(std::shared_ptr allocator) | | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | MSTensor | 获取MSTensor数据所属的内存池 | std::shared_ptr allocator() const | | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | 设置MSTensor数据的format | void SetFormat(mindspore::Format format) | `Tensor.format `__ | +| MSTensor | 设置MSTensor数据的format | void SetFormat(mindspore::Format format) | `Tensor.format `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| MSTensor | 获取MSTensor数据的format | mindspore::Format format() const | `Tensor.format `__ | +| MSTensor | 获取MSTensor数据的format | mindspore::Format format() const | `Tensor.format `__ | 
+---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | MSTensor | 设置指向MSTensor数据的指针 | void SetData(void \*data, bool own_data = true) | | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ @@ -332,15 +332,15 @@ MindSpore Lite API 支持情况汇总 +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | MSTensor | 设置MSTensor的量化参数 | void SetQuantParams(std::vector quant_params) | | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| ModelGroup | 构造ModelGroup对象,指示共享工作空间内存或共享权重内存,默认共享工作空间内存 | ModelGroup(ModelGroupFlag flags = ModelGroupFlag::kShareWorkspace) | `ModelGroup `__ | +| ModelGroup | 构造ModelGroup对象,指示共享工作空间内存或共享权重内存,默认共享工作空间内存 | ModelGroup(ModelGroupFlag flags = ModelGroupFlag::kShareWorkspace) | `ModelGroup `__ | 
+---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| ModelGroup | 共享权重内存时,添加需要共享权重内存的模型对象 | Status AddModel(const std::vector &model_list) | `ModelGroup.add_model `__ | +| ModelGroup | 共享权重内存时,添加需要共享权重内存的模型对象 | Status AddModel(const std::vector &model_list) | `ModelGroup.add_model `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| ModelGroup | 共享工作空间内存时,添加需要共享工作空间内存的模型路径 | Status AddModel(const std::vector &model_path_list) | `ModelGroup.add_model `__ | +| ModelGroup | 共享工作空间内存时,添加需要共享工作空间内存的模型路径 | Status AddModel(const std::vector &model_path_list) | `ModelGroup.add_model `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | ModelGroup | 共享工作空间内存时,添加需要共享工作空间内存的模型缓存 | Status AddModel(const std::vector> &model_buff_list) | | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| ModelGroup | 共享工作空间内存时,计算最大的工作空间内存大小 | Status CalMaxSizeOfWorkspace(ModelType model_type, const 
std::shared_ptr &ms_context) | `ModelGroup.cal_max_size_of_workspace `__ | +| ModelGroup | 共享工作空间内存时,计算最大的工作空间内存大小 | Status CalMaxSizeOfWorkspace(ModelType model_type, const std::shared_ptr &ms_context) | `ModelGroup.cal_max_size_of_workspace `__ | +---------------------+---------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/docs/lite/docs/source_en/advanced/image_processing.md b/docs/lite/docs/source_en/advanced/image_processing.md index 18e6d253c35b6874ec9ac12ef2bf771b2ec25653..d395667cfee3e99d088d2d085efd617716020c3c 100644 --- a/docs/lite/docs/source_en/advanced/image_processing.md +++ b/docs/lite/docs/source_en/advanced/image_processing.md @@ -1,13 +1,11 @@ # Data Preprocessing -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/advanced/image_processing.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/advanced/image_processing.md) ## Overview The main purpose of image preprocessing is to eliminate irrelevant information in the image, restore useful real information, enhance the detectability of related information and simplify data to the greatest extent, thereby improving the reliability of feature extraction, image segmentation, matching and recognition. Here, by creating a LiteMat object, the image data is processed before inference to meet the data format requirements for model inference. -The process is as follows: - ## Importing Image Preprocessing Function Library ```cpp @@ -17,7 +15,7 @@ The process is as follows: ## Initializing the Image -Here, the [InitFromPixel](https://www.mindspore.cn/lite/api/en/master/generate/function_mindspore_dataset_InitFromPixel-1.html) function in the `image_process.h` file is used to initialize the image. +Here, the [InitFromPixel](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/function_mindspore_dataset_InitFromPixel-1.html) function in the `image_process.h` file is used to initialize the image. ```cpp bool InitFromPixel(const unsigned char *data, LPixelType pixel_type, LDataType data_type, int w, int h, LiteMat &m) @@ -40,7 +38,7 @@ The image processing operations here can be used in any combination according to ### Resizing Image -Here we use the [ResizeBilinear](https://www.mindspore.cn/lite/api/en/master/generate/function_mindspore_dataset_ResizeBilinear-1.html) function in `image_process.h` to resize the image through a bilinear algorithm. Currently, the supported data type is unit8, and the supported channels are 3 and 1. 
+Here we use the [ResizeBilinear](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/function_mindspore_dataset_ResizeBilinear-1.html) function in `image_process.h` to resize the image through a bilinear algorithm. Currently, the supported data type is uint8, and the supported channels are 3 and 1. ```cpp bool ResizeBilinear(const LiteMat &src, LiteMat &dst, int dst_w, int dst_h) @@ -62,7 +60,7 @@ ResizeBilinear(lite_mat_bgr, lite_mat_resize, 256, 256); ### Converting the Image Data Type -Here we use the [ConvertTo](https://www.mindspore.cn/lite/api/en/master/generate/function_mindspore_dataset_ConvertTo-1.html) function in `image_process.h` to convert the image data type. Currently, the conversion from uint8 to float is supported. +Here we use the [ConvertTo](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/function_mindspore_dataset_ConvertTo-1.html) function in `image_process.h` to convert the image data type. Currently, the conversion from uint8 to float is supported. ```cpp bool ConvertTo(const LiteMat &src, LiteMat &dst, double scale = 1.0) @@ -84,7 +82,7 @@ ConvertTo(lite_mat_bgr, lite_mat_convert_float); ### Cropping Image Data -Here we use the [Crop](https://www.mindspore.cn/lite/api/en/master/generate/function_mindspore_dataset_Crop-1.html) function in `image_process.h` to crop the image. Currently, channels 3 and 1 are supported. +Here we use the [Crop](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/function_mindspore_dataset_Crop-1.html) function in `image_process.h` to crop the image. Currently, channels 3 and 1 are supported. ```cpp bool Crop(const LiteMat &src, LiteMat &dst, int x, int y, int w, int h) @@ -106,7 +104,7 @@ Crop(lite_mat_bgr, lite_mat_cut, 16, 16, 224, 224); ### Normalizing Image Data -In order to eliminate the dimensional influence among the data indicators and solve the comparability problem among the data indicators through standardization processing is adopted, here is the use of the [SubStractMeanNormalize](https://www.mindspore.cn/lite/api/en/master/generate/function_mindspore_dataset_SubStractMeanNormalize-1.html) function in `image_process.h` to normalize the image data. +To eliminate the dimensional influence among the data indicators and make them comparable, standardization is applied. Here, the [SubStractMeanNormalize](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/function_mindspore_dataset_SubStractMeanNormalize-1.html) function in `image_process.h` is used to normalize the image data.
```cpp bool SubStractMeanNormalize(const LiteMat &src, LiteMat &dst, const std::vector<float> &mean, const std::vector<float> &std) diff --git a/docs/lite/docs/source_en/advanced/micro.md b/docs/lite/docs/source_en/advanced/micro.md index 2ca72c0e072fb93d2f139f740ad80117a704878c..c32516947bfcabd79d4fd489915f92312cbbf921 100644 --- a/docs/lite/docs/source_en/advanced/micro.md +++ b/docs/lite/docs/source_en/advanced/micro.md @@ -1,6 +1,6 @@ # Performing Inference or Training on MCU or Small Systems -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/advanced/micro.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/advanced/micro.md) ## Overview @@ -17,8 +17,8 @@ Deploying a model for inference or training via the Micro involves the following ### Overview -The Micro configuration item in the parameter configuration file is configured via the MindSpore Lite conversion tool `convert_lite`. -This chapter describes the functions related to code generation in the conversion tool. For details about how to use the conversion tool, see [Converting Models for Inference](https://www.mindspore.cn/lite/docs/en/master/converter/converter_tool.html). +The Micro configuration item in the parameter configuration file is configured via the MindSpore Lite conversion tool `converter_lite`. +This chapter describes the functions related to code generation in the conversion tool. For details about how to use the conversion tool, see [Converting Models for Inference](https://www.mindspore.cn/lite/docs/en/r2.6.0/converter/converter_tool.html). ### Preparing Environment @@ -32,11 +32,11 @@ The following describes how to prepare the environment for using the conversion You can obtain the conversion tool in either of the following ways: - - Download [Release Version](https://www.mindspore.cn/lite/docs/en/master/use/downloads.html) from the MindSpore official website. + - Download [Release Version](https://www.mindspore.cn/lite/docs/en/r2.6.0/use/downloads.html) from the MindSpore official website. Download the release package whose OS is Linux-x86_64 and hardware platform is CPU. - - Start from the source code for [Building MindSpore Lite](https://www.mindspore.cn/lite/docs/en/master/build/build.html). + - Start from the source code for [Building MindSpore Lite](https://www.mindspore.cn/lite/docs/en/r2.6.0/build/build.html). 3. Decompress the downloaded package. @@ -103,7 +103,7 @@ The following describes how to prepare the environment for using the conversion CONVERT RESULT SUCCESS:0 ``` - For details about the parameters related to converter_lite, see [Converter Parameter Description](https://www.mindspore.cn/lite/docs/en/master/converter/converter_tool.html#parameter-description). + For details about the parameters related to converter_lite, see [Converter Parameter Description](https://www.mindspore.cn/lite/docs/en/r2.6.0/converter/converter_tool.html#parameter-description). After the conversion tool is successfully executed, the generated code is saved in the specified `outputFile` directory. In this example, the mnist folder is in the current conversion directory.
The content is as follows: @@ -208,7 +208,7 @@ Table 1: micro_param Parameter Definition ``` - In the configuration file, `[micro_param]` in the first line indicates that the subsequent variable parameters belong to the micro configuration item `micro_param`. These parameters are used to control code generation, and the meaning of each parameter is shown in Table 1. `[model_param]` indicates that the subsequent variable parameters belong to the specify model configuration item`model_param`. These parameters are used to control the conversion of different models. The range of parameters includes the necessary parameters supported by `converter_lite`. + In the configuration file, `[micro_param]` in the first line indicates that the subsequent variable parameters belong to the micro configuration item `micro_param`. These parameters are used to control code generation, and the meaning of each parameter is shown in Table 1. `[model_param]` indicates that the subsequent variable parameters belong to the specified model configuration item `model_param`. These parameters are used to control the conversion of different models. The range of parameters includes the necessary parameters supported by `converter_lite`. In this example, we will generate single model inference code for Linux systems with the underlying architecture x86_64, so set `target=x86` to declare that the generated inference code will be used for Linux systems with the underlying architecture x86_64. 3. Prepare the model to generate inference code @@ -228,7 +228,7 @@ Table 1: micro_param Parameter Definition CONVERT RESULT SUCCESS:0 ``` - For details about the parameters related to converter_lite, see [Converter Parameter Description](https://www.mindspore.cn/lite/docs/en/master/converter/converter_tool.html#parameter-description). + For details about the parameters related to converter_lite, see [Converter Parameter Description](https://www.mindspore.cn/lite/docs/en/r2.6.0/converter/converter_tool.html#parameter-description). After the conversion tool is successfully executed, the generated code is saved in the specified `save_path` + `project_name` directory. In this example, the mnist folder is in the current conversion directory. The content is as follows: @@ -277,7 +277,7 @@ Table 1: micro_param Parameter Definition Usually, when generating code, you can reduce the probability of errors in the deployment process by configuring the model input shape as the input shape for actual inference. When the model contains a `Shape` operator or the original model has a non-fixed input shape value, the input shape value of the model must be configured to support the relevant shape optimization and code generation. -The `--inputShape=` command of the conversion tool can be used to configure the input shape of the generated code. For specific parameter meanings, please refer to [Conversion Tool Instructions](https://www.mindspore.cn/lite/docs/en/master/converter/converter_tool.html). +The `--inputShape=` command of the conversion tool can be used to configure the input shape of the generated code. For specific parameter meanings, please refer to [Conversion Tool Instructions](https://www.mindspore.cn/lite/docs/en/r2.6.0/converter/converter_tool.html). ### (Optional) Dynamic Shape Configuration @@ -325,7 +325,7 @@ support_parallel=true #### Involved Calling Interfaces By integrating the code and calling the following interfaces, the user can configure the multi-threaded inference of the model.
-For specific interface parameters, refer to [API Document](https://www.mindspore.cn/lite/api/en/master/index.html). +For specific interface parameters, refer to [API Document](https://www.mindspore.cn/lite/api/en/r2.6.0/index.html). Table 2: API Interface for Multi-threaded Configuration @@ -349,12 +349,12 @@ At present, this function is only enabled when the `target` is configured as x86 In MCU scenarios such as Cortex-M, limited by the memory size and computing power of the device, Int8 quantization operators are usually used for deployment inference to reduce the runtime memory size and speed up operations. -If the user already has an Int8 full quantitative model, you can refer to the section on [Generating Inference Code by Running converter_lite](https://www.mindspore.cn/lite/docs/en/master/advanced/micro.html#generating-inference-code-by-running-converter-lite) to try to generate Int8 quantitative inference code directly without reading this chapter. +If the user already has a fully quantized Int8 model, you can refer to the section on [Generating Inference Code by Running converter_lite](https://www.mindspore.cn/lite/docs/en/r2.6.0/advanced/micro.html#generating-inference-code-by-running-converter-lite) and try to generate the Int8 quantized inference code directly without reading this chapter. In general, the user has only one trained float32 model. To generate Int8 quantitative inference code at this time, it is necessary to cooperate with the post quantization function of the conversion tool to generate code. See the following for specific steps. #### Configuration -Int8 quantization inference code can be generated by configuring quantization control parameters in the configuration file. For the description of quantization control parameters (`universal quantization parameters` and `full quantization parameters`), please refer to the [Quantization](https://www.mindspore.cn/lite/docs/en/master/advanced/quantization.html). +Int8 quantization inference code can be generated by configuring quantization control parameters in the configuration file. For the description of quantization control parameters (universal quantization parameter `common_quant_param` and full quantization parameter `full_quant_param`), please refer to the [Quantization](https://www.mindspore.cn/lite/docs/en/r2.6.0/advanced/quantization.html). An example of Int8 quantitative inference code generation configuration file for a `Cortex-M` platform is as follows: @@ -402,16 +402,16 @@ target_device=DSP - Currently, it only supports full quantitative inference code generation. -- The `target_device` of the `full quantization parameter` in the configuration file usually needs to be set to DSP to support more operators for post quantization. +- The `target_device` of the full quantization parameter `full_quant_param` in the configuration file usually needs to be set to DSP to support more operators for post quantization. -- At present, Micro has supported 8 Int8 quantization operators(add, batchnorm, concat, conv, convolution, matmul, resize, slice). If a related quantization operator does not support it when generating code, you can circumvent the operator through the `skip_quant_node` of the `universal quantization parameter`. The circumvented operator node still uses float32 inference. +- At present, Micro supports 8 Int8 quantization operators (add, batchnorm, concat, conv, convolution, matmul, resize, slice). If one of these quantization operators is not supported when generating code, you can circumvent the operator through the `skip_quant_node` of the universal quantization parameter `common_quant_param`. The circumvented operator node still uses float32 inference. ## Generating Model Training Code ### Overview The training code can be generated for the input model by using the MindSpore Lite conversion tool `converter_lite` and configuring the Micro configuration item in the parameter configuration file of the conversion tool. -This chapter describes the functions related to code generation in the conversion tool. For details about how to use the conversion tool, see [Converting Models for Training](https://www.mindspore.cn/lite/docs/en/master/train/converter_train.html). +This chapter describes the functions related to code generation in the conversion tool. For details about how to use the conversion tool, see [Converting Models for Training](https://www.mindspore.cn/lite/docs/en/r2.6.0/train/converter_train.html). ### Preparing Environment @@ -491,7 +491,7 @@ For preparing environment section, refer to the [above](#preparing-environment), After generating model inference code, you need to obtain the `Micro` lib on which the generated inference code depends before performing integrated development on the code. The inference code of different platforms depends on the `Micro` lib of the corresponding platform. You need to specify the platform via the micro configuration item `target` based on the platform in use when generating code, and obtain the `Micro` lib of the platform when obtaining the inference package. -You can download the [Release Version](https://www.mindspore.cn/lite/docs/en/master/use/downloads.html) of the corresponding platform from the MindSpore official website. +You can download the [Release Version](https://www.mindspore.cn/lite/docs/en/r2.6.0/use/downloads.html) of the corresponding platform from the MindSpore official website. In chapter [Generating Model Inference Code](#generating-model-inference-code), we obtain the model inference code of the Linux platform with the x86_64 architecture. The `Micro` lib on which the code depends is the release package used by the conversion tool. In the release package, the following content is depended on by the inference code: @@ -523,7 +523,7 @@ Users can refer to the benchmark routine to integrate and develop the `src` infe ### Calling Interface of Inference Code -The following is the general calling interface of the inference code. For a detailed description of the interface, please refer to the [API documentation](https://www.mindspore.cn/lite/api/en/master/index.html). +The following is the general calling interface of the inference code. For a detailed description of the interface, please refer to the [API documentation](https://www.mindspore.cn/lite/api/en/r2.6.0/index.html).
Table 3: Inference Common API Interface @@ -559,9 +559,9 @@ Different platforms have differences in code integration and compilation deploym - For the MCU of the cortex-M architecture, see [Performing Inference on the MCU](#performing-inference-on-the-mcu) -- For the Linux platform with the x86_64 architecture, see [Compilation and Deployment on Linux_x86_64 Platform](https://gitee.com/mindspore/mindspore/tree/master/mindspore/lite/examples/quick_start_micro/mnist_x86) +- For the Linux platform with the x86_64 architecture, see [Compilation and Deployment on Linux_x86_64 Platform](https://gitee.com/mindspore/mindspore/tree/v2.6.0/mindspore/lite/examples/quick_start_micro/mnist_x86) -- For details about how to compile and deploy arm32 or arm64 on the Android platform, see [Compilation and Deployment on Android Platform](https://gitee.com/mindspore/mindspore/tree/master/mindspore/lite/examples/quick_start_micro/mobilenetv2_arm64) +- For details about how to compile and deploy arm32 or arm64 on the Android platform, see [Compilation and Deployment on Android Platform](https://gitee.com/mindspore/mindspore/tree/v2.6.0/mindspore/lite/examples/quick_start_micro/mobilenetv2_arm64) - For compilation and deployment on the OpenHarmony platform, see [Executing Inference on Light Harmony Devices](#executing-inference-on-light-harmony-devices) @@ -619,11 +619,11 @@ mnist # Specified name of generated code root directory The STM32F767 uses the Cortex-M7 architecture. You can obtain the `Micro` lib of the architecture in either of the following ways: -- Download [Release Version](https://www.mindspore.cn/lite/docs/en/master/use/downloads.html) from the MindSpore official website. +- Download [Release Version](https://www.mindspore.cn/lite/docs/en/r2.6.0/use/downloads.html) from the MindSpore official website. You need to download the release package whose OS is None and hardware platform is Cortex-M7. -- Start from the source code for [Building MindSpore Lite](https://www.mindspore.cn/lite/docs/en/master/build/build.html). +- Start from the source code for [Building MindSpore Lite](https://www.mindspore.cn/lite/docs/en/r2.6.0/build/build.html). You can run the `MSLITE_MICRO_PLATFORM=cortex-m7 bash build.sh -I x86_64` command to compile the Cortex-M7 release package. @@ -796,7 +796,7 @@ This chapter uses the STM32F767 startup project as an example to describe how to - In the `MCU/MPU Selector` window, search for and select `STM32F767IGT6`, and click `Start Project` to create a project for the chip -- On the `Project Manager` page, configure the project name and the path of the generated project, and select `EWARM` in `Toolchain / IDE` to generate the IAR project +- On the `Project Manager` page, configure the project name and the path of the generated project, and select `Makefile` for the `Toolchain / IDE` option to generate a `Makefile` project. - Click `GENERATE CODE` above to generate code @@ -898,7 +898,7 @@ This chapter uses the STM32F767 startup project as an example to describe how to In this example, to facilitate reading the inference result by using the burner, variables are defined in a customized section (`myram`). You can set the customized section in the following way or ignore the declaration: obtaining the inference result through serial ports or other interactive modes. To set a customized section, perform the following steps: - Modify the `MEMORY` section in the `STM32F767IGTx_FLASH.ld` file, and add a customized memory segment `MYRAM`. (In this example, add 4 to the `RAM` memory start address to free up memory for `MYRAM`). Then add a customized `myram` segment declaration to the `SectionS` segment. + Modify the `MEMORY` section in the `STM32F767IGTx_FLASH.ld` file, and add a customized memory segment `MYRAM`. (In this example, add 4 to the `RAM` memory start address to free up memory for `MYRAM`). Then add a customized `myram` segment declaration to the `SECTIONS` segment. ```text MEMORY @@ -956,7 +956,7 @@ This chapter uses the STM32F767 startup project as an example to describe how to bash ${STMSTM32CubePrg_PATH}/bin/STM32_Programmer.sh -c port=SWD -w build/test_stm767.bin 0x08000000 -s 0x08000000 ``` - ${STMSTM32CubePrg_PATH is}: installation path of `STMSTM32CubePrg`. For details about the parameters in the command, see the `STMSTM32CubePrg` user manual. + ${STMSTM32CubePrg_PATH} is the installation path of `STMSTM32CubePrg`. For details about the parameters in the command, see the `STMSTM32CubePrg` user manual. #### Inference Result Verification @@ -967,7 +967,7 @@ On the PC, use `STLink` to connect to a development board where programs have be bash ${STMSTM32CubePrg_PATH}/bin/STM32_Programmer.sh -c port=SWD model=HOTPLUG --upload 0x20000000 0x1 ret.bin ``` -${STMSTM32CubePrg_PATH is}: installation path of `STMSTM32CubePrg`. For details about the parameters in the command, see the `STMSTM32CubePrg` user manual. +${STMSTM32CubePrg_PATH} is the installation path of `STMSTM32CubePrg`. For details about the parameters in the command, see the `STMSTM32CubePrg` user manual. The read data is saved in the `ret.bin` file. Run `cat ret.bin`. If the board inference is successful and `ret.bin` stores `1`, the following information is displayed: @@ -1004,7 +1004,7 @@ For details about how to develop light Harmony applications, see [Running Hello └── src ``` -Download the [precompiled inference runtime package](https://www.mindspore.cn/lite/docs/en/master/use/downloads.html) for OpenHarmony and decompress it to any Harmony source code path. Compile Build.gn file: +Download the [precompiled inference runtime package](https://www.mindspore.cn/lite/docs/en/r2.6.0/use/downloads.html) for OpenHarmony and decompress it to any Harmony source code path. Compile the BUILD.gn file: ```text import("//build/lite/config/component/lite_component.gni")
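As a concrete illustration, the following is a minimal sketch of what a custom kernel implementation can look like; the exact declaration it implements appears in the hunk that follows. This is not MindSpore's reference implementation: the header paths and the `TensorC` field names (`data_`, `shape_`, `shape_size_`, `data_type_`) are assumptions based on the nnacl headers shipped with the generated code, and the return-value convention (0 for success) is likewise an assumption.

```cpp
#include <string.h>  // memcpy

#include "src/registered_kernel.h"   // declaration generated by converter_lite
#include "nnacl/custom_parameter.h"  // CustomParameter (header path assumed)
#include "nnacl/tensor_c.h"          // TensorC (header path and field names assumed)

// Minimal pass-through custom kernel: copies input 0 into output 0.
int CustomKernel(TensorC *inputs, int input_num, TensorC *outputs, int output_num,
                 CustomParameter *param) {
  (void)param;  // attributes of the custom op, unused in this sketch
  if (input_num < 1 || output_num < 1 || inputs[0].data_ == NULL || outputs[0].data_ == NULL) {
    return -1;  // non-zero is assumed to signal failure to the generated caller
  }
  // The element count is the product of the input's shape dimensions.
  size_t element_num = 1;
  for (size_t i = 0; i < inputs[0].shape_size_; ++i) {
    element_num *= (size_t)inputs[0].shape_[i];
  }
  // Assumes float32 data; real code should check inputs[0].data_type_ first.
  memcpy(outputs[0].data_, inputs[0].data_, element_num * sizeof(float));
  return 0;
}
```

The source file holding such an implementation is then added to the generated CMake project, as the hunk below describes.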
@@ -1143,7 +1143,7 @@ The previous step generates the source code directory under the specified path w int CustomKernel(TensorC *inputs, int input_num, TensorC *outputs, int output_num, CustomParameter *param); ``` -Users need to implement this function and add their source files to the cmake project. For example, we provide the custom kernel example dynamic library libmicro_nnie.so that supports NNIE from Hysis, which is included in the [official download page](https://www.mindspore.cn/lite/docs/en/master/use/downloads.html) "NNIE inference runtime lib, benchmark tool" component. Users need to modify the CMakeLists.txt of the generated code, add the name and path of the linked library. +Users need to implement this function and add their source files to the CMake project. For example, we provide libmicro_nnie.so, a custom kernel example dynamic library that supports NNIE from HiSilicon; it is included in the "NNIE inference runtime lib, benchmark tool" component on the [official download page](https://www.mindspore.cn/lite/docs/en/r2.6.0/use/downloads.html). Users need to modify the CMakeLists.txt of the generated code and add the name and path of the linked library. ``` shell @@ -1155,7 +1155,7 @@ target_link_libraries(benchmark net micro_nnie nnie mpi VoiceEngine upvqe dnvqe ``` -In the generated `benchmark/benchmark.c` file, add the [NNIE device related initialization code](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/test/config_level0/micro/svp_sys_init.c) before and after calling the main function. +In the generated `benchmark/benchmark.c` file, add the [NNIE device related initialization code](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/test/config_level0/micro/svp_sys_init.c) before and after calling the main function. Finally, we compile the source code: ``` shell @@ -1188,7 +1188,7 @@ Except for MCU, micro inference is a inference model that separates model struct ### Exporting Inference Model -Users can directly refer to [Device-side training](https://www.mindspore.cn/lite/docs/en/master/train/runtime_train_cpp.html). +Users can directly refer to [Device-side training](https://www.mindspore.cn/lite/docs/en/r2.6.0/train/runtime_train_cpp.html). ### Generating Inference Code @@ -1202,7 +1202,7 @@ keep_original_weight=false # the names of those weight-tensors whose shape is changeable, only embedding-table supports change now. # the parameter is used to collaborate with lite-train. If set, `keep_original_weight` must be true. -changeable_weights_name=name0, +changeable_weights_name=name0,name1 ``` `keep_original_weight` is a key attribute that ensures consistency in weight, and when combined with training, the attribute must be set to `true`. `changeable_weights_name` is used for special scenarios, such as changes in the shape of certain weights. Of course, currently only the number of embedding-table can be changeable. Generally, users do not need to set the attribute.
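To round off the micro workflow before the quantization guide below, here is a hedged sketch of how the generated inference code is typically driven through the MindSpore Lite C API that Table 2 and Table 3 above refer to. It is illustrative only: the header paths, the use of the packed weight file as the build buffer, and the `read_weight_file` helper are assumptions rather than the exact generated `benchmark` routine.

```cpp
#include <stdio.h>

#include "c_api/context_c.h"  // header path assumed from the release package layout
#include "c_api/model_c.h"

// Hypothetical helper that loads a file into memory; not part of the generated code.
extern void *read_weight_file(const char *path, size_t *size);

int main(int argc, char **argv) {
  if (argc < 2) {
    return -1;  // expects the weight file path as the first argument
  }

  // Thread configuration (takes effect when the code was generated with support_parallel=true).
  MSContextHandle context = MSContextCreate();
  MSContextSetThreadNum(context, 2);           // run inference with 2 threads
  MSContextSetThreadAffinityMode(context, 1);  // 1: prefer big cores

  // For micro, the buffer passed to build is assumed to be the packed weight
  // file (e.g. net.bin) produced alongside the generated source code.
  size_t weight_size = 0;
  void *weights = read_weight_file(argv[1], &weight_size);

  MSModelHandle model = MSModelCreate();
  if (MSModelBuild(model, weights, weight_size, kMSModelTypeMindIR, context) != kMSStatusSuccess) {
    printf("build failed\n");
    return -1;
  }

  MSTensorHandleArray inputs = MSModelGetInputs(model);
  // ... fill each input via MSTensorGetMutableData(inputs.handle_list[i]) ...

  MSTensorHandleArray outputs;
  if (MSModelPredict(model, inputs, &outputs, NULL, NULL) != kMSStatusSuccess) {
    printf("predict failed\n");
    return -1;
  }

  MSModelDestroy(&model);
  return 0;
}
```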
diff --git a/docs/lite/docs/source_en/advanced/quantization.md b/docs/lite/docs/source_en/advanced/quantization.md index 02d48fa4eab69fd7760fe973d990662f7fc6caed..828e53e21ffa903aefe86fa90f232121f202d1aa 100644 --- a/docs/lite/docs/source_en/advanced/quantization.md +++ b/docs/lite/docs/source_en/advanced/quantization.md @@ -1,6 +1,6 @@ # Quantization -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/advanced/quantization.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/advanced/quantization.md) ## Overview @@ -114,7 +114,7 @@ For the scenarios where the CV model needs to improve the model running speed an To fully quantize the quantization parameters for calculating the activation values, the user needs to provide a calibration dataset. The calibration dataset should preferably come from real inference scenarios that characterize the actual inputs to the model, in the order of 100 - 500, **and the calibration dataset needs to be processed into `NHWC` format**. -For image data, it currently supports the functions of channel adjustment, normalization, scaling, cropping and other preprocessing. The user can set the appropriate [Data Preprocessing](https://www.mindspore.cn/lite/docs/en/master/advanced/quantization.html#data-preprocessing) according to the preprocessing operation required for inference. +For image data, it currently supports the functions of channel adjustment, normalization, scaling, cropping and other preprocessing. The user can set the appropriate [Data Preprocessing Parameters](https://www.mindspore.cn/lite/docs/en/r2.6.0/advanced/quantization.html#data-preprocessing-parameters) according to the preprocessing operation required for inference. User configuration of full quantization requires at least `[common_quant_param]`, `[data_preprocess_param]`, and `[full_quant_param]`. @@ -223,7 +223,7 @@ target_device=DSP #### Ascend -Ascend quantization needs to configure Ascend-related configuration at [offline conversion](https://www.mindspore.cn/lite/docs/en/master/mindir/converter_tool.html#description-of-parameters) first, i.e. `optimize` needs to be set to `ascend_oriented`, and then configure Ascend related environment variables during conversion. +Ascend quantization first requires the Ascend-related configuration at [offline conversion](https://www.mindspore.cn/lite/docs/en/r2.6.0/mindir/converter_tool.html#description-of-parameters), i.e., `optimize` needs to be set to `ascend_oriented`, and the Ascend-related environment variables then need to be configured during conversion. **Ascend Fully Quantized Static Shape Parameter Configuration** @@ -245,7 +245,7 @@ Ascend quantization needs to configure Ascend-related configuration at [offline target_device=ASCEND ``` -**Ascend full quantization supports dynamic Shape parameters**. The conversion command needs to set the same inputShape of the calibration dataset, which can be found in [Conversion Tool Parameter Description](https://www.mindspore.cn/lite/docs/en/master/mindir/converter_tool.html#description-of-parameters). +**Ascend full quantization supports dynamic Shape parameters**. The conversion command needs to set the same inputShape as that of the calibration dataset; see the [Conversion Tool Parameter Description](https://www.mindspore.cn/lite/docs/en/r2.6.0/mindir/converter_tool.html#description-of-parameters). - The general form of the conversion command in the Ascend fully quantized static shape scenario is: @@ -301,7 +301,7 @@ quant_strategy=ACWL ## Configuration Parameter -Post training quantization can be enabled by configuring `configFile` through [Conversion Tool](https://www.mindspore.cn/lite/docs/en/master/converter/converter_tool.html). The configuration file adopts the style of [`INI`](https://en.wikipedia.org/wiki/INI_file), For quantization, configurable parameters include: +Post training quantization can be enabled by configuring `configFile` through [Conversion Tool](https://www.mindspore.cn/lite/docs/en/r2.6.0/converter/converter_tool.html). The configuration file adopts the style of [`INI`](https://en.wikipedia.org/wiki/INI_file). For quantization, configurable parameters include: - `[common_quant_param]: Public quantization parameters` - `[weight_quant_param]: Fixed bit quantization parameters` diff --git a/docs/lite/docs/source_en/advanced/third_party.rst b/docs/lite/docs/source_en/advanced/third_party.rst index 3e2e3e3b55ef10187e16148fdb3af256b9996a47..d5e73b3a2a71d8d1ee45b088191828950a3a2ebf 100644 --- a/docs/lite/docs/source_en/advanced/third_party.rst +++ b/docs/lite/docs/source_en/advanced/third_party.rst @@ -1,9 +1,9 @@ Third-party Access ================================= -.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg - :target: https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/advanced/third_party.rst - :alt: View Source on Gitee +.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg + :target: https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/advanced/third_party.rst + :alt: View Source On Gitee .. toctree:: :maxdepth: 1 diff --git a/docs/lite/docs/source_en/advanced/third_party/ascend_info.md b/docs/lite/docs/source_en/advanced/third_party/ascend_info.md index 6ff8adf0aec9a5d70719612483b06e43d518576b..d254ced40e1bc06ad981ffa655e77a4eeabc7581 100644 --- a/docs/lite/docs/source_en/advanced/third_party/ascend_info.md +++ b/docs/lite/docs/source_en/advanced/third_party/ascend_info.md @@ -1,11 +1,11 @@ # Integrated Ascend -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/advanced/third_party/ascend_info.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/advanced/third_party/ascend_info.md) > - The Ascend backend support in the device-side version will be deprecated later. For related usage of the Ascend backend, please refer to the cloud-side inference version documentation.
-> - [Build Cloud-side MindSpore Lite](https://mindspore.cn/lite/docs/en/master/mindir/build.html) -> - [Cloud-side Model Converter](https://mindspore.cn/lite/docs/en/master/mindir/converter.html) -> - [Cloud-side Benchmark Tool](https://mindspore.cn/lite/docs/en/master/mindir/benchmark.html) +> - [Build Cloud-side MindSpore Lite](https://mindspore.cn/lite/docs/en/r2.6.0/mindir/build.html) +> - [Cloud-side Model Converter](https://mindspore.cn/lite/docs/en/r2.6.0/mindir/converter.html) +> - [Cloud-side Benchmark Tool](https://mindspore.cn/lite/docs/en/r2.6.0/mindir/benchmark.html) This document describes how to use MindSpore Lite to perform inference and use the dynamic shape function on Linux in the Ascend environment. Currently, MindSpore Lite supports the Atlas 200/300/500 inference product and Atlas inference series. @@ -75,7 +75,7 @@ export PYTHONPATH=${TBE_IMPL_PATH}:${PYTHONPATH} MindSpore Lite provides an offline model converter to convert various models (Caffe, ONNX, TensorFlow, and MindIR) into models that can be inferred on the Ascend hardware. First, use the converter to convert a model into an `ms` model. Then, use the runtime inference framework matching the converter to perform inference. The process is as follows: -1. [Download](https://www.mindspore.cn/lite/docs/en/master/use/downloads.html) the converter dedicated for Ascend. Currently, only Linux is supported. +1. [Download](https://www.mindspore.cn/lite/docs/en/r2.6.0/use/downloads.html) the converter dedicated for Ascend. Currently, only Linux is supported. 2. Decompress the downloaded package. @@ -115,7 +115,7 @@ First, use the converter to convert a model into an `ms` model. Then, use the ru CONVERT RESULT SUCCESS:0 ``` - For details about parameters of the converter_lite converter, see ["Parameter Description" in Converting Models for Inference](https://www.mindspore.cn/lite/docs/en/master/converter/converter_tool.html#parameter-description). + For details about parameters of the converter_lite converter, see ["Parameter Description" in Converting Models for Inference](https://www.mindspore.cn/lite/docs/en/r2.6.0/converter/converter_tool.html#parameter-description). Note: If the input shape of the original model is uncertain, specify inputShape when using the converter to convert a model. In addition, set configFile to the value of input_shape_vector parameter in acl_option_cfg_param. 
The command is as follows:

@@ -141,16 +141,16 @@ Table 1 [acl_option_cfg_param] parameter configuration

| `dynamic_batch_size` | Optional| Specifies the [dynamic batch size](#dynamic-batch-size) parameter.| String | `"2,4"`|
| `dynamic_image_size` | Optional| Specifies the [dynamic image size](#dynamic-image-size) parameter.| String | `"96,96;32,32"` |
| `fusion_switch_config_file_path` | Optional| Configure the path and name of the [fusion pattern switch](https://www.hiascend.com/document/detail/zh/canncommercial/700/devtools/auxiliarydevtool/aoepar_16_034.html) file.| String | - |
-| `insert_op_config_file_path` | Optional| Inserts the [AIPP](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/devaids/auxiliarydevtool/atlasatc_16_0025.html) operator into a model.| String | [AIPP](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/devaids/auxiliarydevtool/atlasatc_16_0025.html) configuration file path|
+| `insert_op_config_file_path` | Optional| Inserts the [AIPP](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1alpha001/devaids/devtools/atc/atlasatc_16_0016.html) operator into a model.| String | [AIPP](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1alpha001/devaids/devtools/atc/atlasatc_16_0016.html) configuration file path|

## Runtime

-After obtaining the converted model, use the matching runtime inference framework to perform inference. For details about how to use runtime to perform inference, see [Using C++ Interface to Perform Inference](https://www.mindspore.cn/lite/docs/en/master/infer/runtime_cpp.html).
+After obtaining the converted model, use the matching runtime inference framework to perform inference. For details about how to use runtime to perform inference, see [Using C++ Interface to Perform Inference](https://www.mindspore.cn/lite/docs/en/r2.6.0/infer/runtime_cpp.html).

## Executing the Benchmark

MindSpore Lite provides a benchmark test tool, which can be used to perform quantitative (performance) analysis on the execution time consumed by forward inference of the MindSpore Lite model. In addition, you can perform comparative error (accuracy) analysis based on the output of a specified model.

-For details about the inference tool, see [benchmark](https://www.mindspore.cn/lite/docs/en/master/tools/benchmark_tool.html).
+For details about the inference tool, see [benchmark](https://www.mindspore.cn/lite/docs/en/r2.6.0/tools/benchmark_tool.html).

- Performance analysis

@@ -170,7 +170,7 @@ For details about the inference tool, see [benchmark](https://www.mindspore.cn/l

### Dynamic Shape

-The batch size is not fixed in certain scenarios. For example, in the target detection+facial recognition cascade scenario, the number of detected targets is subject to change, which means that the batch size of the targeted recognition input is dynamic. It would be a great waste of compute resources to perform inferences using the maximum batch size or image size. Thanks to Lite's support for dynamic batch size and dynamic image size on the Atlas 200/300/500 inference product, you can configure the [acl_option_cfg_param] dynamic parameter through configFile to convert a model into an `ms` model, and then use the [resize](https://www.mindspore.cn/lite/docs/en/master/infer/runtime_cpp.html#resizing-the-input-dimension) function of the model to change the input shape during inference.
+The batch size is not fixed in certain scenarios. For example, in the target detection+facial recognition cascade scenario, the number of detected targets is subject to change, which means that the batch size of the targeted recognition input is dynamic. It would be a great waste of compute resources to perform inferences using the maximum batch size or image size. Thanks to Lite's support for dynamic batch size and dynamic image size on the Atlas 200/300/500 inference product, you can configure the [acl_option_cfg_param] dynamic parameter through configFile to convert a model into an `ms` model, and then use the [resize](https://www.mindspore.cn/lite/docs/en/r2.6.0/infer/runtime_cpp.html#resizing-the-input-dimension) function of the model to change the input shape during inference.
#### Dynamic Batch Size

@@ -204,7 +204,7 @@ The batch size is not fixed in certain scenarios. For example, in the target det

- Inference

-  After the dynamic batch size is enabled, during model inference, the input shape is corresponding to the size configured in converter. To change the input shape, use the model [resize](https://www.mindspore.cn/lite/docs/en/master/infer/runtime_cpp.html#resizing-the-input-dimension) function.
+  After the dynamic batch size is enabled, during model inference, the input shape corresponds to the size configured in the converter. To change the input shape, use the model [resize](https://www.mindspore.cn/lite/docs/en/r2.6.0/infer/runtime_cpp.html#resizing-the-input-dimension) function.

- Precautions

@@ -245,7 +245,7 @@ The batch size is not fixed in certain scenarios. For example, in the target det

- Inference

-  After the dynamic image size is enabled, during model inference, the input shape is corresponding to the size configured in converter. To change the input shape, use the model [resize](https://www.mindspore.cn/lite/docs/en/master/infer/runtime_cpp.html#resizing-the-input-dimension) function.
+  After the dynamic image size is enabled, during model inference, the input shape corresponds to the size configured in the converter. To change the input shape, use the model [resize](https://www.mindspore.cn/lite/docs/en/r2.6.0/infer/runtime_cpp.html#resizing-the-input-dimension) function.

- Precautions

@@ -255,4 +255,4 @@ The batch size is not fixed in certain scenarios. For example, in the target det

## Supported Operators

-For details about the supported operators, see [Lite Operator List](https://www.mindspore.cn/lite/docs/en/master/reference/operator_list_lite.html).
+For details about the supported operators, see [Lite Operator List](https://www.mindspore.cn/lite/docs/en/r2.6.0/reference/operator_list_lite.html).

diff --git a/docs/lite/docs/source_en/advanced/third_party/asic.rst b/docs/lite/docs/source_en/advanced/third_party/asic.rst
index 1ee00b2ef4f3d261dfa5811228ed5ba155470843..a6c0857d044eb27e356784c2b2a400de855494c5 100644
--- a/docs/lite/docs/source_en/advanced/third_party/asic.rst
+++ b/docs/lite/docs/source_en/advanced/third_party/asic.rst
@@ -1,9 +1,9 @@
Application Specific Integrated Circuit Integration Instructions
================================================================

-.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg
-   :target: https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/advanced/third_party/asic.rst
-   :alt: View Source on Gitee
+.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg
+   :target: https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/advanced/third_party/asic.rst
+   :alt: View Source On Gitee

.. toctree::
  :maxdepth: 1

diff --git a/docs/lite/docs/source_en/advanced/third_party/converter_register.md b/docs/lite/docs/source_en/advanced/third_party/converter_register.md
index 9067f4fe3cd62d7256e3011105912bf270e1505e..3e15a7c15397b4bd1ba2ed218715b92c22f2e6f5 100644
--- a/docs/lite/docs/source_en/advanced/third_party/converter_register.md
+++ b/docs/lite/docs/source_en/advanced/third_party/converter_register.md
@@ -1,34 +1,36 @@
# Building Custom Operators Offline

-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/advanced/third_party/converter_register.md)
+[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/advanced/third_party/converter_register.md)

## Overview

-Our [Conversion Tool](https://www.mindspore.cn/lite/docs/en/master/converter/converter_tool.html) is a highly flexible tool. In addition to the basic ability of model converter, we have designed a set of registration mechanism, which allows users to expand, including node-parse extension, model-parse extension and graph-optimization extension. The users can combined them as needed to achieve their own intention.
+In addition to the basic model conversion function, the MindSpore Lite [Conversion Tool](https://www.mindspore.cn/lite/docs/en/r2.6.0/converter/converter_tool.html) also supports user-defined model optimization and construction to generate models with user-defined operators.

-node-parse extension: The users can define the process to parse a certain node of a model by themselves, which only support ONNX, CAFFE, TF and TFLITE. The related interface is [NodeParser](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_converter_NodeParser.html), [NodeParserRegistry](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_registry_NodeParserRegistry.html).
-model-parse extension: The users can define the process to parse a model by themselves, which only support ONNX, CAFFE, TF and TFLITE. The related interface is [ModelParser](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_converter_ModelParser.html), [ModelParserRegistry](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_registry_ModelParserRegistry.html).
-graph-optimization extension: After parsing a model, a graph structure defined by MindSpore will show up and then, the users can define the process to optimize the parsed graph. The related interfaces are [PassBase](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_registry_PassBase.html), [PassPosition](https://mindspore.cn/lite/api/en/master/generate/enum_mindspore_registry_PassPosition-1.html), [PassRegistry](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_registry_PassRegistry.html).
+We have designed a set of registration mechanisms that allow users to extend the converter, including node-parse extension, model-parse extension and graph-optimization extension. Users can combine them as needed to achieve their own goals.
-> The node-parse extension needs to rely on the flatbuffers, protobuf and the serialization files of third-party frameworks, at the same time, the version of flatbuffers and the protobuf needs to be consistent with that of the released package, the serialized files must be compatible with that used by the released package. Note that the flatbuffers, protobuf and the serialization files are not provided in the released package, users need to compile and generate the serialized files by themselves. The users can obtain the basic information about [flabuffers](https://gitee.com/mindspore/mindspore/blob/master/cmake/external_libs/flatbuffers.cmake), [probobuf](https://gitee.com/mindspore/mindspore/blob/master/cmake/external_libs/protobuf.cmake), [ONNX prototype file](https://gitee.com/mindspore/mindspore/tree/master/third_party/proto/onnx), [CAFFE prototype file](https://gitee.com/mindspore/mindspore/tree/master/third_party/proto/caffe), [TF prototype file](https://gitee.com/mindspore/mindspore/tree/master/third_party/proto/tensorflow) and [TFLITE prototype file](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/tools/converter/parser/tflite/schema.fbs) from the [MindSpore WareHouse](https://gitee.com/mindspore/mindspore/tree/master).
+node-parse extension: The users can define the process to parse a certain node of a model by themselves, which only supports ONNX, CAFFE, TF and TFLITE. The related interfaces are [NodeParser](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_converter_NodeParser.html) and [NodeParserRegistry](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_registry_NodeParserRegistry.html).
+model-parse extension: The users can define the process to parse a model by themselves, which only supports ONNX, CAFFE, TF and TFLITE. The related interfaces are [ModelParser](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_converter_ModelParser.html) and [ModelParserRegistry](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_registry_ModelParserRegistry.html).
+graph-optimization extension: After parsing a model, a graph structure defined by MindSpore will show up, and then the users can define the process to optimize the parsed graph. The related interfaces are [PassBase](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_registry_PassBase.html), [PassPosition](https://mindspore.cn/lite/api/en/r2.6.0/generate/enum_mindspore_registry_PassPosition-1.html), [PassRegistry](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_registry_PassRegistry.html).
+
+> The node-parse extension needs to rely on the flatbuffers, protobuf and the serialization files of third-party frameworks. At the same time, the version of flatbuffers and protobuf needs to be consistent with that of the released package, and the serialized files must be compatible with those used by the released package. Note that the flatbuffers, protobuf and the serialization files are not provided in the released package, so users need to compile and generate the serialized files by themselves.
The users can obtain the basic information about [flatbuffers](https://gitee.com/mindspore/mindspore/blob/v2.6.0/cmake/external_libs/flatbuffers.cmake), [protobuf](https://gitee.com/mindspore/mindspore/blob/v2.6.0/cmake/external_libs/protobuf.cmake), [ONNX prototype file](https://gitee.com/mindspore/mindspore/tree/v2.6.0/third_party/proto/onnx), [CAFFE prototype file](https://gitee.com/mindspore/mindspore/tree/v2.6.0/third_party/proto/caffe), [TF prototype file](https://gitee.com/mindspore/mindspore/tree/v2.6.0/third_party/proto/tensorflow) and [TFLITE prototype file](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/tools/converter/parser/tflite/schema.fbs) from the [MindSpore repository](https://gitee.com/mindspore/mindspore/tree/v2.6.0).
>
-> MindSpore Lite alse providers a series of registration macros to facilitate user access. These macros include node-parse registration [REG_NODE_PARSER](https://www.mindspore.cn/lite/api/en/master/generate/define_node_parser_registry.h_REG_NODE_PARSER-1.html), model-parse registration [REG_MODEL_PARSER](https://www.mindspore.cn/lite/api/en/master/generate/define_model_parser_registry.h_REG_MODEL_PARSER-1.html), graph-optimization registration [REG_PASS](https://www.mindspore.cn/lite/api/en/master/generate/define_pass_registry.h_REG_PASS-1.html) and graph-optimization scheduled registration [REG_SCHEDULED_PASS](https://www.mindspore.cn/lite/api/en/master/generate/define_pass_registry.h_REG_SCHEDULED_PASS-1.html)
+> MindSpore Lite also provides a series of registration macros to facilitate user access. These macros include node-parse registration [REG_NODE_PARSER](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/define_node_parser_registry.h_REG_NODE_PARSER-1.html), model-parse registration [REG_MODEL_PARSER](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/define_model_parser_registry.h_REG_MODEL_PARSER-1.html), graph-optimization registration [REG_PASS](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/define_pass_registry.h_REG_PASS-1.html) and graph-optimization scheduled registration [REG_SCHEDULED_PASS](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/define_pass_registry.h_REG_SCHEDULED_PASS-1.html)

-The expansion capability of MindSpore Lite conversion tool only support on Linux system currently.
+The expansion capability of the MindSpore Lite conversion tool is currently only supported on Linux.

In this chapter, we will show the users a sample of extending the MindSpore Lite converter tool, covering examples of expanding a node and optimizing a graph, as well as compiling and linking. The examples will help the users understand the extension ability as soon as possible.

> Because the model-parse extension is a modular extension capability, this chapter will not introduce it in detail. However, we still provide the users with a simplified unit case for reference.

-The chapter takes a [add.tflite](https://download.mindspore.cn/model_zoo/official/lite/quick_start/add.tflite), which only includes an opreator of adding, as an example.
+The chapter takes [add.tflite](https://download.mindspore.cn/model_zoo/official/lite/quick_start/add.tflite), which only includes an operator of adding, as an example.
We will show the users how to convert the single operator of adding to that of [Custom](https://www.mindspore.cn/lite/docs/en/r2.6.0/advanced/third_party/register_kernel.html#custom-operators) and finally obtain a model which only includes a single Custom operator.

-The code related to the example can be obtained from the path [mindspore/lite/examples/converter_extend](https://gitee.com/mindspore/mindspore/tree/master/mindspore/lite/examples/converter_extend).
+The code related to the example can be obtained from the path [mindspore/lite/examples/converter_extend](https://gitee.com/mindspore/mindspore/tree/v2.6.0/mindspore/lite/examples/converter_extend).

## Node Extension

-1. Self-defined node-parse: The users need to inherit the base class [NodeParser](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_converter_NodeParser.html), and then, choose a interface to override according to model frameworks.
+1. Self-defined node-parse: The users need to inherit the base class [NodeParser](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_converter_NodeParser.html), and then choose an interface to override according to the model framework.

-2. Node-parse Registration: The users can directly call the registration interface [REG_NODE_PARSER](https://www.mindspore.cn/lite/api/en/master/generate/define_node_parser_registry.h_REG_NODE_PARSER-1.html), so that the self-defined node-parse will be registered in the converter tool of MindSpore Lite.
+2. Node-parse Registration: The users can directly call the registration interface [REG_NODE_PARSER](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/define_node_parser_registry.h_REG_NODE_PARSER-1.html), so that the self-defined node-parse will be registered in the converter tool of MindSpore Lite.

```c++
class AddParserTutorial : public NodeParser {  // inherit the base class
@@ -40,20 +42,20 @@ class AddParserTutorial : public NodeParser {  // inherit the base class
                          const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};

-REG_NODE_PARSER(kFmkTypeTflite, ADD, std::make_shared<AddParserTutorial>());  // call the registration macro
+REG_NODE_PARSER(kFmkTypeTflite, ADD, std::make_shared<AddParserTutorial>());  // call the registration interface
```

-For the sample code, please refer to [node_parser](https://gitee.com/mindspore/mindspore/tree/master/mindspore/lite/examples/converter_extend/node_parser).
+For the sample code, please refer to [node_parser](https://gitee.com/mindspore/mindspore/tree/v2.6.0/mindspore/lite/examples/converter_extend/node_parser).

## Model Extension

-For the sample code, please refer to the unit case [ModelParserRegistryTest](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/test/ut/tools/converter/registry/model_parser_registry_test.cc).
+For the sample code, please refer to the unit case [ModelParserRegistryTest](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/test/ut/tools/converter/registry/model_parser_registry_test.cc).

### Optimization Extension

-1. Self-defined Pass: The users need to inherit the base class [PassBase](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_registry_PassBase.html), and override the interface function [Execute](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_dataset_Execute.html).
+1. 
Self-defined Pass: The users need to inherit the base class [PassBase](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_registry_PassBase.html), and override the interface function [Execute](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_dataset_Execute.html).

-2. Pass Registration: The users can directly call the registration interface [REG_PASS](https://www.mindspore.cn/lite/api/en/master/generate/define_pass_registry.h_REG_PASS-1.html), so that the self-defined pass can be registered in the converter tool of MindSpore Lite.
+2. Pass Registration: The users can directly call the registration interface [REG_PASS](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/define_pass_registry.h_REG_PASS-1.html), so that the self-defined pass can be registered in the converter tool of MindSpore Lite.

```c++
class PassTutorial : public registry::PassBase {  // inherit the base class
@@ -73,13 +75,13 @@ REG_PASS(PassTutorial, opt::PassTutorial)   // register PassBase's sub

REG_SCHEDULED_PASS(POSITION_BEGIN, {"PassTutorial"})  // register scheduling logic
```

-For the sample code, please refer to [pass](https://gitee.com/mindspore/mindspore/tree/master/mindspore/lite/examples/converter_extend/pass).
+For the sample code, please refer to [pass](https://gitee.com/mindspore/mindspore/tree/v2.6.0/mindspore/lite/examples/converter_extend/pass).

-> In the offline phase of conversion, we will infer the basic information of output tensors of each node of the model, including the format, data type and shape. So, in this phase, users need to provide the inferring process of self-defined operator. Here, users can refer to [Operator Infershape Extension](https://www.mindspore.cn/lite/docs/en/master/infer/runtime_cpp.html#operator-infershape-extension).
+> In the offline phase of conversion, we will infer the basic information of the output tensors of each node of the model, including the format, data type and shape. So, in this phase, users need to provide the shape-inference process of the self-defined operator. Here, users can refer to [Operator Infershape Extension](https://www.mindspore.cn/lite/docs/en/r2.6.0/infer/runtime_cpp.html#operator-infershape-extension), and the sample code can be found in [infer](https://gitee.com/mindspore/mindspore/tree/r2.6.0/mindspore/lite/examples/converter_extend/infer).

## Example

-### Compile
+### Compilation

- Environment Requirements

@@ -90,21 +92,21 @@ For the sample code, please refer to [pass](https://gitee.com/mindspor

- Compilation preparation

-  The release package of MindSpore Lite doesn't provide serialized files of other frameworks, therefore, users need to compile and obtain by yourselves. Here, please refer to [Overview](https://www.mindspore.cn/lite/docs/en/master/advanced/third_party/converter_register.html#overview).
+  The release package of MindSpore Lite doesn't provide serialized files of other frameworks; therefore, users need to compile and obtain them by themselves. Here, please refer to [Overview](https://www.mindspore.cn/lite/docs/en/r2.6.0/advanced/third_party/converter_register.html#overview).

-  The case is a tflite model, users need to compile [flatbuffers](https://gitee.com/mindspore/mindspore/blob/master/cmake/external_libs/flatbuffers.cmake) and combine the [TFLITE Proto File](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/tools/converter/parser/tflite/schema.fbs) to generate the serialized file.
+  Since the case uses a tflite model, users need to compile [flatbuffers](https://gitee.com/mindspore/mindspore/blob/v2.6.0/cmake/external_libs/flatbuffers.cmake) and combine it with the [TFLITE Proto File](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/tools/converter/parser/tflite/schema.fbs) to generate the serialized file.

  After generating, users need to create a directory `schema` under the directory of `mindspore/lite/examples/converter_extend` and then place the serialized file in it.

- Compilation and Build

-  Execute the script [build.sh](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/examples/converter_extend/build.sh) in the directory of `mindspore/lite/examples/converter_extend`. And then, the released package of MindSpore Lite will be downloaded and the demo will be compiled automatically.
+  Execute the script [build.sh](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/examples/converter_extend/build.sh) in the directory of `mindspore/lite/examples/converter_extend`. The released package of MindSpore Lite will then be downloaded and the demo will be compiled automatically.

  ```bash
  bash build.sh
  ```

-  > If the automatic download is failed, users can download the specified package manually, of which the hardware platform is CPU and the system is Ubuntu-x64 [mindspore-lite-{version}-linux-x64.tar.gz](https://www.mindspore.cn/lite/docs/en/master/use/downloads.html), After unzipping, please copy the directory of `tools/converter/lib` and `tools/converter/include` to the directory of `mindspore/lite/examples/converter_extend`.
+  > If the automatic download fails, users can manually download the specified package [mindspore-lite-{version}-linux-x64.tar.gz](https://www.mindspore.cn/lite/docs/en/r2.6.0/use/downloads.html), whose hardware platform is CPU and system is Ubuntu-x64. After unzipping, please copy the directories `tools/converter/lib` and `tools/converter/include` to the directory `mindspore/lite/examples/converter_extend`.
  >
  > After manually downloading and storing the specified file, users need to execute the `build.sh` script to complete the compilation and build process.

@@ -166,7 +168,7 @@ If the user needs to turn off the specified operator fusions, the fusion configu

```ini
[registry]
-# When parameter `disable_fusion` is configured as `off`, the user can turn off the specified operator fusions by configuring parameter `fusion_blacklists`. While parameter `disable_fusion` is configured as `on`, the parameter `fusion_blacklists` does not work.
+# When parameter `disable_fusion` is configured as `off`, the user can turn off the specified operator fusions by configuring parameter `fusion_blacklists`. When parameter `disable_fusion` is configured as `on`, all operator fusions are turned off and the parameter `fusion_blacklists` does not work.
disable_fusion=off
fusion_blacklists=ConvActivationFusion,MatMulActivationFusion
```

diff --git a/docs/lite/docs/source_en/advanced/third_party/delegate.md b/docs/lite/docs/source_en/advanced/third_party/delegate.md
index 2e861ccd7cc21463ee548bb8ca44f70cb9aa7d71..81e3849b8bcfc007fea3967c0cf29133ebab672c 100644
--- a/docs/lite/docs/source_en/advanced/third_party/delegate.md
+++ b/docs/lite/docs/source_en/advanced/third_party/delegate.md
@@ -1,6 +1,6 @@
# Using Delegate to Support Third-party AI Framework (Device)

-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/advanced/third_party/delegate.md)
+[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/advanced/third_party/delegate.md)

## Overview

@@ -10,14 +10,14 @@ Delegate of MindSpore Lite is used to support third-party AI frameworks (such as

Using Delegate to support a third-party AI framework mainly includes the following steps:

-1. Add a custom delegate class: Inherit the [Delegate](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Delegate.html) class to implement XXXDelegate.
-2. Implementing the Init Function: The [Init](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Delegate.html) function needs to check whether the device supports the delegate framework and to apply for resources related to delegate.
-3. Implementing the Build Function: The [Build](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Delegate.html) function will implement the kernel support judgment, the sub-graph construction, and the online graph building.
-4. Implementing the sub-graph Kernel: Inherit the [Kernel](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_kernel_Kernel.html#class-kernel) to implement delegate sub-graph Kernel.
+1. Add a custom delegate class: Inherit the [Delegate](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Delegate.html) class to implement XXXDelegate.
+2. Implement the Init function: The [Init](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Delegate.html) function needs to check whether the device supports the delegate framework and apply for resources related to the delegate.
+3. Implement the Build function: The [Build](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Delegate.html) function will implement the kernel support judgment, the sub-graph construction, and the online graph building.
+4. Implement the sub-graph kernel: Inherit the [Kernel](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_kernel_Kernel.html#class-kernel) to implement the delegate sub-graph kernel.

### Adding a Custom Delegate Class

-XXXDelegate should inherit from [Delegate](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Delegate.html). In the constructor of XXXDelegate, configure settings for third-party AI framework to build and execute the model, such as NPU frequency, CPU thread number, etc.
+XXXDelegate should inherit from [Delegate](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Delegate.html). In the constructor of XXXDelegate, configure settings for the third-party AI framework to build and execute the model, such as the NPU frequency and CPU thread number.
```cpp
class XXXDelegate : public Delegate {
@@ -34,7 +34,7 @@ class XXXDelegate : public Delegate {

### Implementing the Init

-[Init](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Delegate.html) will be called during the [Build](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Model.html) process of [Model](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Model.html#class-model). The specific location is in the [LiteSession::Init](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/src/litert/lite_session.cc#L696) function of MindSpore Lite internal process.
+[Init](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Delegate.html) will be called during the [Build](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Model.html) process of [Model](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Model.html#class-model). The specific location is in the [LiteSession::Init](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/src/litert/lite_session.cc#L696) function of the MindSpore Lite internal process.

```cpp
Status XXXDelegate::Init() {
@@ -45,16 +45,16 @@ Status XXXDelegate::Init() {

### Implementing the Build

-The input parameter of the [Build(DelegateModel<schema::Primitive> *model)](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Delegate.html) interface is [DelegateModel](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_DelegateModel.html#template-class-delegatemodel).
+The input parameter of the [Build(DelegateModel<schema::Primitive> *model)](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Delegate.html) interface is [DelegateModel](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_DelegateModel.html#template-class-delegatemodel).

-> [std::vector<kernel::Kernel *> *kernels_](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_kernel_Kernel.html): A list of kernels that have been selected by MindSpore Lite and topologically sorted.
+> [std::vector<kernel::Kernel *> *kernels_](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_kernel_Kernel.html): A list of kernels that have been selected by MindSpore Lite and topologically sorted.
>
-> [const std::map<kernel::Kernel *, const schema::Primitive *> primitives_](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_DelegateModel.html): A map of kernel and its attribute `schema::Primitive`, which is used to analyze the original attribute information.
+> [const std::map<kernel::Kernel *, const schema::Primitive *> primitives_](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_DelegateModel.html): A map of each kernel and its attribute `schema::Primitive`, which is used to analyze the original attribute information.

-[Build](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Delegate.html) will be called during the [Build](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Model.html) process of [Model](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Model.html#class-model). The specific location is in the [Schedule::Schedule](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/src/litert/scheduler.cc#L132) function of MindSpore Lite internal process. At this time, the inner kernels have been selected by MindSpore Lite. The following steps should be implemented in Build function:
+[Build](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Delegate.html) will be called during the [Build](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Model.html) process of [Model](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Model.html#class-model). The specific location is in the [Schedule::Schedule](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/src/litert/scheduler.cc#L132) function of the MindSpore Lite internal process. At this time, the inner kernels have been selected by MindSpore Lite. The following steps should be implemented in the Build function:

-1. Traverse the kernel list, use [GetPrimitive](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_DelegateModel.html) to get the attribute of kernel. Analyze the attribute to judge whether the delegate framework supports it.
-2. For a continuous supported kernel list, construct a delegate sub-graph kernel and [Replace](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_DelegateModel.html) the continuous supported kernels with it.
+1. Traverse the kernel list and use [GetPrimitive](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_DelegateModel.html) to get the attribute of each kernel. Analyze the attribute to judge whether the delegate framework supports it.
+2. For a continuous list of supported kernels, construct a delegate sub-graph kernel and [Replace](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_DelegateModel.html) the continuous supported kernels with it.

```cpp
Status XXXDelegate::Build(DelegateModel<schema::Primitive> *model) {
@@ -95,10 +95,10 @@ kernel::Kernel *XXXDelegate::CreateXXXGraph(KernelIter from, KernelIter end, Del
}
```

-The delegate sub-graph kernel `XXXGraph` should inherit from [Kernel](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_kernel_Kernel.html#class-kernel). The realization of `XXXGraph` should focus on:
+The delegate sub-graph kernel `XXXGraph` should inherit from [Kernel](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_kernel_Kernel.html#class-kernel). The realization of `XXXGraph` should focus on:

1. Find the correct in_tensors and out_tensors for `XXXGraph` according to the original kernels list.
-2. Rewrite the Prepare, Resize, and Execute interfaces. [Prepare](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_kernel.html#prepare) will be called in [Build](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Model.html) of [Model](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Model.html#class-model). [Execute](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_kernel.html#execute) will be called in [Predict](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Model.html) of Model. [ReSize](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore_kernel.html#resize) will be called in [Resize](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Model.html) of Model.
+2. Rewrite the Prepare, Resize, and Execute interfaces. [Prepare](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_kernel.html#prepare) will be called in [Build](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Model.html) of [Model](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Model.html#class-model). 
[Execute](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_kernel.html#execute) will be called in [Predict](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Model.html) of Model. [ReSize](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore_kernel.html#resize) will be called in [Resize](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Model.html) of Model.

```cpp
class XXXGraph : public kernel::Kernel {
@@ -127,7 +127,7 @@ class XXXGraph : public kernel::Kernel {
};
```

## Calling Delegate by Lite Framework

-MindSpore Lite schedules user-defined delegate by [Context](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Context.html#class-context). Use [SetDelegate](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#setdelegate) to set a custom delegate for Context. Delegate will be passed by [Build](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Model.html) to MindSpore Lite. If the Delegate in the Context is a null pointer, the process will call the inner inference of MindSpore Lite.
+MindSpore Lite schedules the user-defined delegate by [Context](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Context.html#class-context). Use [SetDelegate](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#setdelegate) to set a custom delegate for the Context. The Delegate will be passed by [Build](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Model.html) to MindSpore Lite. If the Delegate in the Context is a null pointer, the process will call the inner inference of MindSpore Lite.

```cpp
auto context = std::make_shared<mindspore::Context>();
@@ -156,7 +156,7 @@ if (build_ret != mindspore::kSuccess) {

## Example of NPUDelegate

-Currently, MindSpore Lite uses the [NPUDelegate](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/src/litert/delegate/npu/npu_delegate.h#L29) for the NPU backend. This tutorial gives a brief description of NPUDelegate, so that users can quickly understand the usage of Delegate APIs.
+Currently, MindSpore Lite uses the [NPUDelegate](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/src/litert/delegate/npu/npu_delegate.h#L29) for the NPU backend. This tutorial gives a brief description of NPUDelegate, so that users can quickly understand the usage of Delegate APIs.

### Adding the NPUDelegate Class

@@ -190,7 +190,7 @@ class NPUDelegate : public Delegate {

### Implementing the Init of NPUDelegate

-[Init](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/src/litert/delegate/npu/npu_delegate.cc#L75) function is used to apply resource for NPU and determine whether the hardware supports NPU.
+The [Init](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/src/litert/delegate/npu/npu_delegate.cc#L75) function is used to apply for NPU resources and determine whether the hardware supports NPU.

```cpp
Status NPUDelegate::Init() {
@@ -217,7 +217,7 @@ Status NPUDelegate::Init() {

### Implementing the Build of NPUDelegate

-The [Build](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/src/litert/delegate/npu/npu_delegate.cc#L163) interface parses the DelegateModel and mainly implements the kernel support judgment, the sub-graph construction, and the online graph building.
+The [Build](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/src/litert/delegate/npu/npu_delegate.cc#L163) interface parses the DelegateModel and mainly implements the kernel support judgment, the sub-graph construction, and the online graph building.

```cpp
Status NPUDelegate::Build(DelegateModel<schema::Primitive> *model) {
@@ -257,7 +257,7 @@ Status NPUDelegate::Build(DelegateModel<schema::Primitive> *model) {

### Creating NPUGraph

-The following [Sample Code](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/src/litert/delegate/npu/npu_delegate.cc#L273) is the CreateNPUGraph interface of NPUDelegate, used to generate an NPU sub-graph kernel.
+The following [Sample Code](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/src/litert/delegate/npu/npu_delegate.cc#L273) is the CreateNPUGraph interface of NPUDelegate, used to generate an NPU sub-graph kernel.

```cpp
kernel::Kernel *NPUDelegate::CreateNPUGraph(const std::vector<NPUOp *> &ops) {
@@ -279,7 +279,7 @@ kernel::Kernel *NPUDelegate::CreateNPUGraph(const std::vector<NPUOp *> &ops) {

### Adding the NPUGraph Class

-[NPUGraph](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/src/litert/delegate/npu/npu_graph.h#L29) inherits from [Kernel](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_kernel_Kernel.html#class-kernel). And we need to rewrite the Prepare, Execute, and ReSize interfaces.
+[NPUGraph](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/src/litert/delegate/npu/npu_graph.h#L29) inherits from [Kernel](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_kernel_Kernel.html#class-kernel), and we need to rewrite the Prepare, Execute, and ReSize interfaces.

```cpp
class NPUGraph : public kernel::Kernel {
@@ -306,7 +306,7 @@ class NPUGraph : public kernel::Kernel {
};
```

-[NPUGraph::Prepare](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/src/litert/delegate/npu/npu_graph.cc#L306) mainly implements:
+[NPUGraph::Prepare](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/src/litert/delegate/npu/npu_graph.cc#L306) mainly implements:

```cpp
int NPUGraph::Prepare() {
@@ -314,7 +314,7 @@ int NPUGraph::Prepare() {
}
```

-[NPUGraph::Execute](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/src/litert/delegate/npu/npu_graph.cc#L322) mainly implements:
+[NPUGraph::Execute](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/src/litert/delegate/npu/npu_graph.cc#L322) mainly implements:

```cpp
int NPUGraph::Execute() {
@@ -325,4 +325,4 @@ int NPUGraph::Execute() {
}
```

-> [NPU](https://www.mindspore.cn/lite/docs/en/master/advanced/third_party/npu_info.html) is a third-party AI framework that added by MindSpore Lite internal developers. The usage of NPU is slightly different. You can set the [Context](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Context.html#class-context) through [SetDelegate](https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html#setdelegate), or you can add the description of the NPU device [KirinNPUDeviceInfo](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_KirinNPUDeviceInfo.html#class-kirinnpudeviceinfo) to [MutableDeviceInfo](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Context.html) of the Context.
+> [NPU](https://www.mindspore.cn/lite/docs/en/r2.6.0/advanced/third_party/npu_info.html) is a third-party AI framework that was added by MindSpore Lite internal developers. The usage of NPU is slightly different. You can set the [Context](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Context.html#class-context) through [SetDelegate](https://www.mindspore.cn/lite/api/zh-CN/r2.6.0/api_cpp/mindspore.html#setdelegate), or you can add the description of the NPU device [KirinNPUDeviceInfo](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_KirinNPUDeviceInfo.html#class-kirinnpudeviceinfo) to [MutableDeviceInfo](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Context.html) of the Context.
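+>
+> For reference, a minimal sketch of the second approach (describing the NPU device in the context's device list) is shown below. This is an illustration under assumed settings rather than part of the sample code; the helper name and the frequency value are only examples.
+
+```cpp
+#include <memory>
+#include "include/api/context.h"
+
+// Build a context that enables the built-in NPU backend through the device list.
+std::shared_ptr<mindspore::Context> BuildNpuContext() {
+  auto context = std::make_shared<mindspore::Context>();
+  auto &device_list = context->MutableDeviceInfo();
+  auto npu_info = std::make_shared<mindspore::KirinNPUDeviceInfo>();
+  npu_info->SetFrequency(3);  // illustrative value; 3 usually stands for high performance
+  device_list.push_back(npu_info);
+  // Append a CPU description as a fallback for kernels the NPU does not support.
+  device_list.push_back(std::make_shared<mindspore::CPUDeviceInfo>());
+  return context;
+}
+```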
diff --git a/docs/lite/docs/source_en/advanced/third_party/npu_info.md b/docs/lite/docs/source_en/advanced/third_party/npu_info.md
index 4fb8d1d33d6940383237295f6e8c022cc46ed8ae..b449e4df77e2b749226914bba402790727d619b8 100644
--- a/docs/lite/docs/source_en/advanced/third_party/npu_info.md
+++ b/docs/lite/docs/source_en/advanced/third_party/npu_info.md
@@ -1,29 +1,26 @@
# NPU Integration Information

-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/advanced/third_party/npu_info.md)
+[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/advanced/third_party/npu_info.md)

## Steps

### Environment Preparation

-Besides basic [Environment Preparation](https://www.mindspore.cn/lite/docs/en/master/build/build.html), HUAWEI HiAI DDK, which contains
-APIs (including building, loading models and calculation processes) and interfaces implemented to encapsulate dynamic libraries (namely libhiai*.so),
-is required for the use of NPU. Download [DDK 100.510.010.010](https://developer.huawei.com/consumer/en/doc/development/hiai-Library/ddk-download-0000001053590180),
-and set the directory of extracted files as `${HWHIAI_DDK}`. Our build script uses this environment viriable to seek DDK.
+Besides basic [Environment Preparation](https://www.mindspore.cn/lite/docs/en/r2.6.0/build/build.html), using the NPU requires the integration of the HUAWEI HiAI DDK.
+HUAWEI HiAI DDK, which contains APIs (including building, loading models and calculation processes) and interfaces implemented to encapsulate dynamic libraries (namely libhiai*.so), is required for the use of NPU.
+Download [DDK 100.510.010.010](https://developer.huawei.com/consumer/en/doc/development/hiai-Library/ddk-download-0000001053590180), and set the directory of extracted files as `${HWHIAI_DDK}`. Our build script uses this environment variable to find the DDK.

### Build

-Under the Linux operating system, one can easily build MindSpore Lite Package integrating NPU interfaces and libraries using build.sh under
-the root directory of MindSpore [Source Code](https://gitee.com/mindspore/mindspore). The command is as follows.
-It will build MindSpore Lite's package under the output directory under the MindSpore source code root directory,
-which contains the NPU's dynamic library, the libmindspore-lite dynamic library, and the test tool Benchmark.
+Under the Linux operating system, one can easily build a MindSpore Lite package integrating the NPU interfaces and libraries using build.sh under the root directory of the MindSpore [Source Code](https://gitee.com/mindspore/mindspore). The command is as follows.
+It will build the MindSpore Lite package in the output directory under the MindSpore source code root directory, which contains the NPU dynamic library, the libmindspore-lite dynamic library, and the test tool Benchmark.

```bash
export MSLITE_ENABLE_NPU=ON
bash build.sh -I arm64 -j8
```

-For more information about compilation, see [Linux Environment Compilation](https://www.mindspore.cn/lite/docs/en/master/build/build.html#linux-environment-compilation).
+For more information about compilation, see [Linux Environment Compilation](https://www.mindspore.cn/lite/docs/en/r2.6.0/build/build.html#linux-environment-compilation).

### Integration

@@ -31,11 +28,11 @@ For more information about compilation, see [Linux Environment Compilation](http

When developers need to integrate the use of NPU features, it is important to note:

-  - [Configure the NPU backend](https://www.mindspore.cn/lite/docs/en/master/infer/runtime_cpp.html#configuring-the-npu-backend).
-    For more information about using Runtime to perform inference, see [Using Runtime to Perform Inference (C++)](https://www.mindspore.cn/lite/docs/en/master/infer/runtime_cpp.html).
+  - [Configure the NPU backend](https://www.mindspore.cn/lite/docs/en/r2.6.0/infer/runtime_cpp.html#configuring-the-npu-backend).
+    For more information about using Runtime to perform inference, see [Using Runtime to Perform Inference (C++)](https://www.mindspore.cn/lite/docs/en/r2.6.0/infer/runtime_cpp.html).

-  - Compile and execute the binary. If you use dynamic linking, refer to [compile output](https://www.mindspore.cn/lite/docs/en/master/build/build.html) when the compile option is `-I arm64` or `-I arm32`.
-    Configured environment variables will dynamically load libhiai.so, libhiai_ir.so, libhiai_ir_build.so, libhiai_hcl_model_runtime.so. For example,
+  - Compile and execute the binary. If you use dynamic linking, refer to [compile output](https://www.mindspore.cn/lite/docs/en/r2.6.0/build/build.html) when the compile option is `-I arm64` or `-I arm32`.
+    Configured environment variables will dynamically load libhiai.so, libhiai_ir.so, libhiai_ir_build.so, libhiai_hcl_model_runtime.so. For example,

    ```bash
    export LD_LIBRARY_PATH=mindspore-lite-{version}-android-{arch}/runtime/third_party/hiai_ddk/lib/:$LD_LIBRARY_PATH

@@ -57,11 +54,9 @@ For more information about compilation, see [Linux Environment Compilation](http

    ./benchmark --device=NPU --modelFile=./models/test_benchmark.ms --inDataFile=./input/test_benchmark.bin --inputShapes=1,32,32,1 --accuracyThreshold=3 --benchmarkDataFile=./output/test_benchmark.out
    ```

-For more information about the use of Benchmark, see [Benchmark Use](https://www.mindspore.cn/lite/docs/en/master/tools/benchmark_tool.html).
+For more information about the use of Benchmark, see [Benchmark Use](https://www.mindspore.cn/lite/docs/en/r2.6.0/tools/benchmark_tool.html).

-For environment variable settings, you need to set the directory where the libmindspore-lite.so
-(under the directory `mindspore-lite-{version}-android-{arch}/runtime/lib`) and NPU libraries
-(under the directory `mindspore-lite-{version}-android-{arch}/runtime/third_party/hiai_ddk/lib/`) are located, to `${LD_LIBRARY_PATH}`.
+For environment variable settings, you need to add the directories where libmindspore-lite.so (under `mindspore-lite-{version}-android-{arch}/runtime/lib`) and the NPU libraries (under `mindspore-lite-{version}-android-{arch}/runtime/third_party/hiai_ddk/lib/`) are located to `${LD_LIBRARY_PATH}`.
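+
+For example, a minimal sketch of the combined setting (assuming the package was unpacked in the current directory; `{version}` and `{arch}` are placeholders to fill in):
+
+```bash
+# Assumed unpack location of mindspore-lite-{version}-android-{arch}.tar.gz
+LITE_HOME=$(pwd)/mindspore-lite-{version}-android-{arch}
+# Both libmindspore-lite.so and the HiAI DDK libraries must be resolvable at runtime
+export LD_LIBRARY_PATH=${LITE_HOME}/runtime/lib:${LITE_HOME}/runtime/third_party/hiai_ddk/lib:${LD_LIBRARY_PATH}
+```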
## Supported Chips @@ -69,4 +64,4 @@ For supported NPU chips, see [Chipset Platforms and Supported HUAWEI HiAI Versio ## Supported Operators -For supported NPU operators, see [Lite Operator List](https://www.mindspore.cn/lite/docs/en/master/reference/operator_list_lite.html). \ No newline at end of file +For supported NPU operators, see [Lite Operator List](https://www.mindspore.cn/lite/docs/en/r2.6.0/reference/operator_list_lite.html). \ No newline at end of file diff --git a/docs/lite/docs/source_en/advanced/third_party/register.rst b/docs/lite/docs/source_en/advanced/third_party/register.rst index 3d9cf5ce4d149645b05faeb77a3eb98cc3ca7bdc..7e9d80863b365682306ebf75aae435d84750c800 100644 --- a/docs/lite/docs/source_en/advanced/third_party/register.rst +++ b/docs/lite/docs/source_en/advanced/third_party/register.rst @@ -1,9 +1,9 @@ Custom Kernel =============== -.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg - :target: https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/advanced/third_party/register.rst - :alt: View Source on Gitee +.. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg + :target: https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/advanced/third_party/register.rst + :alt: View Source On Gitee .. toctree:: :maxdepth: 1 diff --git a/docs/lite/docs/source_en/advanced/third_party/register_kernel.md b/docs/lite/docs/source_en/advanced/third_party/register_kernel.md index df9c49a858bc1bcdadf88210c23415b4a9c13480..b3c1aa8ceca5e0c9df27007d43ad1d60fb06c4bb 100644 --- a/docs/lite/docs/source_en/advanced/third_party/register_kernel.md +++ b/docs/lite/docs/source_en/advanced/third_party/register_kernel.md @@ -1,6 +1,6 @@ # Building Custom Operators Online -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/advanced/third_party/register_kernel.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/advanced/third_party/register_kernel.md) ## Implementing Custom Operators @@ -18,11 +18,11 @@ View the operator prototype definition in mindspore/lite/schema/ops.fbs. Check w ### Common Operators -For details about code related to implementation, registration, and InferShape of an operator, see [the code repository](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/test/ut/src/registry/registry_test.cc). +For details about code related to implementation, registration, and InferShape of an operator, see [the code repository](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/test/ut/src/registry/registry_test.cc). #### Implementing Common Operators -Inherit [mindspore::kernel::Kernel](https://www.mindspore.cn/lite/api/en/master/api_cpp/mindspore_kernel.html) and overload necessary APIs. The following describes how to customize an Add operator: +Inherit [mindspore::kernel::Kernel](https://www.mindspore.cn/lite/api/en/r2.6.0/api_cpp/mindspore_kernel.html) and overload necessary APIs. The following describes how to customize an Add operator: 1. An operator inherits a kernel. 2. PreProcess() pre-allocates memory. 
@@ -74,7 +74,7 @@ int TestCustomAdd::Execute() { #### Registering Common Operators -Currently, the generated macro [REGISTER_KERNEL](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_registry_RegisterKernel.html) is provided for operator registration. The implementation procedure is as follows: +Currently, the generated macro [REGISTER_KERNEL](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_registry_RegisterKernel.html) is provided for operator registration. The implementation procedure is as follows: 1. The TestCustomAddCreator function is used to create a kernel. 2. Use the macro REGISTER_KERNEL to register the kernel. Assume that the vendor is BuiltInTest. @@ -96,7 +96,7 @@ REGISTER_KERNEL(CPU, BuiltInTest, kFloat32, PrimitiveType_AddFusion, TestCustomA Reload the Infer function after inheriting KernelInterface to implement the InferShape capability. The implementation procedure is as follows: -1. Inherit [KernelInterface](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_kernel_KernelInterface.html). +1. Inherit [KernelInterface](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_kernel_KernelInterface.html). 2. Overload the Infer function to derive the shape, format, and data_type of the output tensor. The following uses the custom Add operator as an example: @@ -120,7 +120,7 @@ class TestCustomAddInfer : public KernelInterface { #### Registering the Common Operator InferShape -Currently, the generated macro [REGISTER_KERNEL_INTERFACE](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_registry_RegisterKernelInterface.html) is provided for registering the operator InferShape. The procedure is as follows: +Currently, the generated macro [REGISTER_KERNEL_INTERFACE](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_registry_RegisterKernelInterface.html) is provided for registering the operator InferShape. The procedure is as follows: 1. Use the CustomAddInferCreator function to create a KernelInterface instance. 2. Call the REGISTER_KERNEL_INTERFACE macro to register the common operator InferShape. Assume that the vendor is BuiltInTest. @@ -133,7 +133,7 @@ REGISTER_KERNEL_INTERFACE(BuiltInTest, PrimitiveType_AddFusion, CustomAddInferCr ### Custom Operators -For details about code related to parsing, creating, and operating custom operators, see [the code repository](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc). +For details about code related to parsing, creating, and operating custom operators, see [the code repository](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/test/ut/tools/converter/registry/pass_registry_test.cc). #### Defining Custom Operators @@ -220,11 +220,11 @@ REG_SCHEDULED_PASS(POSITION_BEGIN, schedule) // Set the external Pass sche } // namespace mindspore::opt ``` -For details about code related to implementation, registration, and InferShape of a custom operator, see [the code repository](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/test/ut/src/registry/registry_custom_op_test.cc). +For details about code related to implementation, registration, and InferShape of a custom operator, see [the code repository](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/test/ut/src/registry/registry_custom_op_test.cc). 
#### Implementing Custom Operators -The implementation procedure of a custom operator is the same as that of a common operator, because they are specific subclasses of [Kernel](https://www.mindspore.cn/lite/api/en/master/api_cpp/mindspore_kernel.html). +The implementation procedure of a custom operator is the same as that of a common operator, because they are specific subclasses of [Kernel](https://www.mindspore.cn/lite/api/en/r2.6.0/api_cpp/mindspore_kernel.html). If the custom operator does not run on the CPU platform, the result needs to be copied back to the output tensor after the running is complete. The following describes how to create a custom operator with the Add capability: 1. An operator inherits a kernel. @@ -277,7 +277,7 @@ int TestCustomOp::Execute() { #### Custom Operator Attribute Decoding Example -In the example, the byte stream in the attribute is copied to the buffer. +In the example, the byte stream in the attribute is copied to the buf. ```cpp auto prim = primitive_->value_as_Custom(); @@ -295,7 +295,7 @@ In the example, the byte stream in the attribute is copied to the buffer. #### Registering Custom Operators -Currently, the generated macro [REGISTER_CUSTOM_KERNEL](https://www.mindspore.cn/lite/api/en/master/generate/define_register_kernel.h_REGISTER_CUSTOM_KERNEL-1.html) is provided for operator registration. The procedure is as follows: +Currently, the generated macro [REGISTER_CUSTOM_KERNEL](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/define_register_kernel.h_REGISTER_CUSTOM_KERNEL-1.html) is provided for operator registration. The procedure is as follows: 1. The TestCustomAddCreator function is used to create a kernel. 2. Use the macro REGISTER_CUSTOM_KERNEL to register an operator. Assume that the vendor is BuiltInTest and the operator type is Add. @@ -316,7 +316,7 @@ REGISTER_CUSTOM_KERNEL(CPU, BuiltInTest, kFloat32, Add, TestCustomAddCreator) The overall implementation is the same as that of the common operator InferShape. The procedure is as follows: -1. Inherit [KernelInterface](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_kernel_KernelInterface.html). +1. Inherit [KernelInterface](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_kernel_KernelInterface.html). 2. Overload the Infer function to derive the shape, format, and data_type of the output tensor. ```cpp @@ -336,10 +336,10 @@ class TestCustomOpInfer : public KernelInterface { #### Registering the Custom Operator InferShape -Currently, the generated macro [REGISTER_CUSTOM_KERNEL_INTERFACE](https://www.mindspore.cn/lite/api/en/master/generate/define_register_kernel_interface.h_REGISTER_CUSTOM_KERNEL_INTERFACE-1.html) is provided for registering the custom operator InferShape. The procedure is as follows: +Currently, the generated macro [REGISTER_CUSTOM_KERNEL_INTERFACE](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/define_register_kernel_interface.h_REGISTER_CUSTOM_KERNEL_INTERFACE-1.html) is provided for registering the custom operator InferShape. The procedure is as follows: 1. Use the CustomAddInferCreator function to create a custom KernelInterface. -2. The macro [REGISTER_CUSTOM_KERNEL_INTERFACE](https://www.mindspore.cn/lite/api/en/master/generate/define_register_kernel_interface.h_REGISTER_CUSTOM_KERNEL_INTERFACE-1.html) is provided for registering the InferShape capability. The operator type Add must be the same as that in REGISTER_CUSTOM_KERNEL_INTERFACE. +2. 
The macro [REGISTER_CUSTOM_KERNEL_INTERFACE](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/define_register_kernel_interface.h_REGISTER_CUSTOM_KERNEL_INTERFACE-1.html) is provided for registering the InferShape capability. The operator type Add must be the same as that in REGISTER_CUSTOM_KERNEL_INTERFACE. ```cpp std::shared_ptr CustomAddInferCreator() { return std::make_shared(); } @@ -349,9 +349,9 @@ REGISTER_CUSTOM_KERNEL_INTERFACE(BuiltInTest, Add, CustomAddInferCreator) ## Custom GPU Operators -A set of GPU-related functional APIs are provided to facilitate the development of the GPU-based custom operator and enable the GPU-based custom operator to share the same resources with the internal GPU-based operators to improve the scheduling efficiency. For details about the APIs, see [mindspore::registry::opencl](https://www.mindspore.cn/lite/api/en/master/api_cpp/mindspore_registry_opencl.html). +A set of GPU-related functional APIs are provided to facilitate the development of the GPU-based custom operator and enable the GPU-based custom operator to share the same resources with the internal GPU-based operators to improve the scheduling efficiency. For details about the APIs, see [mindspore::registry::opencl](https://www.mindspore.cn/lite/api/en/r2.6.0/api_cpp/mindspore_registry_opencl.html). This document describes how to develop a custom GPU operator by parsing sample code. Before reading this document, you need to understand [Implement Custom Operators](#implementing-custom-operators). -The [code repository](https://gitee.com/mindspore/mindspore/blob/master/mindspore/lite/test/ut/src/registry/registry_gpu_custom_op_test.cc) contains implementation and registration of custom GPU operators. +The [code repository](https://gitee.com/mindspore/mindspore/blob/v2.6.0/mindspore/lite/test/ut/src/registry/registry_gpu_custom_op_test.cc) contains implementation and registration of custom GPU operators. ### Registering Operators @@ -376,7 +376,7 @@ using CreateKernel = std::function( ``` In this example, the operator instance creation function is implemented as follows. The function returns a `CustomAddKernel` class instance. This class is the user-defined operator class that inherits the `kernel::Kernel` class. For details about the implementation of this class, see [Implementing Operators](#implementing-operators). -In the function, in addition to transferring the function parameters to the constructor function of the `CustomAddKernel` class, a Boolean variable is also transferred. The variable is used to control whether the data type processed by the created `CustomAddKernel` instance is FLOAT32 or FLOAT16. +In the function, in addition to transferring the function parameters to the constructor function of the `CustomAddKernel` class, a Boolean variable is also transferred. The variable is used to control whether the data type processed by the created `CustomAddKernel` instance is float32 or float16. ```cpp namespace custom_gpu_demo { @@ -394,7 +394,7 @@ std::shared_ptr CustomAddCreator(const std::vector &in #### Registering Operators When registering GPU operators, you must declare the device type as GPU and transfer the operator instance creation function `CustomAddCreator` implemented in the previous step. -In this example, the Float32 implementation of the Custom_Add operator is registered. The registration code is as follows. For details about other parameters in the registration macro, see the [API](https://www.mindspore.cn/lite/api/en/master/api_cpp/mindspore_registry.html). 
+In this example, the Float32 implementation of the Custom_Add operator is registered. The registration code is as follows. For details about other parameters in the registration macro, see the [API](https://www.mindspore.cn/lite/api/en/r2.6.0/api_cpp/mindspore_registry.html). ```cpp const auto kFloat32 = DataType::kNumberTypeFloat32; @@ -404,7 +404,7 @@ REGISTER_CUSTOM_KERNEL(GPU, BuiltInTest, kFloat32, Custom_Add, CustomAddCreator) ### Implementing Operators -In this example, the operator is implemented as the `CustomAddKernel` class. This class inherits [mindspore::kernel::Kernel](https://www.mindspore.cn/lite/api/en/master/api_cpp/mindspore_kernel.html) and reloads necessary APIs to implement the custom operator computation. +In this example, the operator is implemented as the `CustomAddKernel` class. This class inherits [mindspore::kernel::Kernel](https://www.mindspore.cn/lite/api/en/r2.6.0/api_cpp/mindspore_kernel.html) and reloads necessary APIs to implement the custom operator computation. #### Constructor and Destructor Functions @@ -428,7 +428,7 @@ class CustomAddKernel : public kernel::Kernel { - opencl_runtime_ - An instance of the OpenCLRuntimeWrapper class. In an operator, this object can be used to call the OpenCL-related API [mindspore::registry::opencl](https://www.mindspore.cn/lite/api/en/master/api_cpp/mindspore_registry_opencl.html) provided by MindSpore Lite. + An instance of the OpenCLRuntimeWrapper class. In an operator, this object can be used to call the OpenCL-related API [mindspore::registry::opencl](https://www.mindspore.cn/lite/api/en/r2.6.0/api_cpp/mindspore_registry_opencl.html) provided by MindSpore Lite. - fp16_enable_ @@ -440,7 +440,7 @@ class CustomAddKernel : public kernel::Kernel { - Other variables - Other variables are required for OpenCL operations. For details, see [mindspore::registry::opencl](https://www.mindspore.cn/lite/api/en/master/api_cpp/mindspore_registry_opencl.html). + Other variables are required for OpenCL operations. For details, see [mindspore::registry::opencl](https://www.mindspore.cn/lite/api/en/r2.6.0/api_cpp/mindspore_registry_opencl.html). ```c++ class CustomAddKernel : public kernel::Kernel { @@ -558,7 +558,7 @@ In this example, the Prepare API is overloaded to load and build the custom Open 3. Build the OpenCL code. - Use `fp16_enable_` to specify different build options to generate the code for processing FLOAT16 or FLOAT32 data. + Use `fp16_enable_` to specify different build options to generate the code for processing float16 or float32 data. Use `opencl_runtime_` to call the `OpenCLRuntimeWrapper::BuildKernel` API, obtain the built `cl::Kernel` variable, and save it in `kernel_`. ```cpp @@ -712,7 +712,7 @@ In this example, the Prepare API is overloaded to load and build the custom Open ... ``` - The `PackNHWCToNHWC4` function is implemented as follows, including the conversion between the FLOAT16 and FLOAT32 types. + The `PackNHWCToNHWC4` function is implemented as follows, including the conversion between the float16 and float32 types. 
```cpp
 void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, const GpuTensorInfo &tensor,
diff --git a/docs/lite/docs/source_en/advanced/third_party/tensorrt_info.md b/docs/lite/docs/source_en/advanced/third_party/tensorrt_info.md
index 5a07742dc338971e23d8875a1fd79da023a4ed59..a0140f0ac2aafa360c7bc73f385c4c2993d3831e 100644
--- a/docs/lite/docs/source_en/advanced/third_party/tensorrt_info.md
+++ b/docs/lite/docs/source_en/advanced/third_party/tensorrt_info.md
@@ -1,12 +1,12 @@
 # TensorRT Integration Information

-[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/advanced/third_party/tensorrt_info.md)
+[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/advanced/third_party/tensorrt_info.md)

 ## Steps

 ### Environment Preparation

-Besides basic [Environment Preparation](https://www.mindspore.cn/lite/docs/en/master/build/build.html), CUDA and TensorRT is required as well. Current version supports [CUDA 10.1](https://developer.nvidia.com/cuda-10.1-download-archive-base) and [TensorRT 6.0.1.5](https://developer.nvidia.com/nvidia-tensorrt-6x-download), and [CUDA 11.1](https://developer.nvidia.com/cuda-11.1.1-download-archive) and [TensorRT 8.5.1](https://developer.nvidia.com/nvidia-tensorrt-8x-download).
+Besides basic [Environment Preparation](https://www.mindspore.cn/lite/docs/en/r2.6.0/build/build.html), CUDA and TensorRT are required as well. The current version supports [CUDA 10.1](https://developer.nvidia.com/cuda-10.1-download-archive-base) and [TensorRT 6.0.1.5](https://developer.nvidia.com/nvidia-tensorrt-6x-download), and [CUDA 11.1](https://developer.nvidia.com/cuda-11.1.1-download-archive) and [TensorRT 8.5.1](https://developer.nvidia.com/nvidia-tensorrt-8x-download).

 Install the appropriate version of CUDA and set the installed directory as environment variable `${CUDA_HOME}`. Our build script uses this environment variable to seek CUDA.

@@ -14,24 +14,24 @@ Install TensorRT of the corresponding CUDA version, and set the installed direct
 ### Build

-In the Linux environment, use the build.sh script in the root directory of MindSpore [Source Code](https://gitee.com/mindspore/mindspore) to build the MindSpore Lite package integrated with TensorRT. First configure the environment variable `MSLITE_GPU_BACKEND=tensorrt`, and then execute the compilation command as follows.
+In the Linux environment, use the build.sh script in the root directory of MindSpore [Source Code](https://gitee.com/mindspore/mindspore) to build the MindSpore Lite package integrated with TensorRT. First configure the environment variable `MSLITE_GPU_BACKEND=tensorrt`, and then execute the compilation command as follows. It will build a package for MindSpore Lite in the output directory under the root of the MindSpore source code, containing `libmindspore-lite.so` and the test tool Benchmark.

 ```bash
 bash build.sh -I x86_64
 ```

-For more information about compilation, see [Linux Environment Compilation](https://www.mindspore.cn/lite/docs/en/master/build/build.html#linux-environment-compilation).
+For more information about compilation, see [Linux Environment Compilation](https://www.mindspore.cn/lite/docs/en/r2.6.0/build/build.html#linux-environment-compilation).
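Concretely, a build session might look like the following sketch. The install paths are assumptions for illustration, and `TENSORRT_PATH` is an assumed name for the TensorRT directory variable; only `CUDA_HOME` and `MSLITE_GPU_BACKEND` are taken from the text above.

```bash
# Illustrative only: adjust the install paths to your machine.
export CUDA_HOME=/usr/local/cuda-11.1      # CUDA install directory, used by build.sh to find CUDA
export TENSORRT_PATH=/opt/TensorRT-8.5.1   # assumed variable name for the TensorRT install directory
export MSLITE_GPU_BACKEND=tensorrt         # select the TensorRT GPU backend
bash build.sh -I x86_64                    # package is produced under output/ in the source root
```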
### Integration

 - Integration instructions

 When developers need to integrate the use of TensorRT features, it is important to note:

- - [Configure the TensorRT backend](https://www.mindspore.cn/lite/docs/en/master/infer/runtime_cpp.html#configuring-the-gpu-backend),
- For more information about using Runtime to perform inference, see [Using Runtime to Perform Inference (C++)](https://www.mindspore.cn/lite/docs/en/master/infer/runtime_cpp.html).
+ - [Configure the TensorRT backend](https://www.mindspore.cn/lite/docs/en/r2.6.0/infer/runtime_cpp.html#configuring-the-gpu-backend) in the code.
+ For more information about using Runtime to perform inference, see [Using Runtime to Perform Inference (C++)](https://www.mindspore.cn/lite/docs/en/r2.6.0/infer/runtime_cpp.html).

- - Compile and execute the binary. If you use dynamic linking, please refer to [Compilation Output](https://www.mindspore.cn/lite/docs/en/master/build/build.html#directory-structure) with compilation option `-I x86_64`.
- Please set environment variables to dynamically link related libs.
+ - Compile and execute the binary. If you use dynamic linking, please refer to [Compilation Output](https://www.mindspore.cn/lite/docs/en/r2.6.0/build/build.html#directory-structure) with compilation option `-I x86_64`.
+ Please set environment variables to dynamically link related libs.

 ```bash
 export LD_LIBRARY_PATH=mindspore-lite-{version}-{os}-{arch}/runtime/lib/:$LD_LIBRARY_PATH
@@ -41,7 +41,7 @@ For more information about compilation, see [Linux Environment Compilation](http

 - Using Benchmark testing TensorRT inference

- Pass the build package to a device with a TensorRT environment(TensorRT 6.0.1.5) and use the Benchmark tool to test TensorRT inference. Examples are as follows:
+ Users can also test TensorRT inference using the MindSpore Lite Benchmark tool. The location of the compiled Benchmark is shown in [Compiled Output](https://www.mindspore.cn/lite/docs/en/r2.6.0/build/build.html). Pass the build package to a device with a TensorRT environment (TensorRT 6.0.1.5) and use the Benchmark tool to test TensorRT inference. Examples are as follows:

 - Test performance

@@ -55,14 +55,14 @@ For more information about compilation, see [Linux Environment Compilation](http
 ./benchmark --device=GPU --modelFile=./models/test_benchmark.ms --inDataFile=./input/test_benchmark.bin --inputShapes=1,32,32,1 --accuracyThreshold=3 --benchmarkDataFile=./output/test_benchmark.out
 ```

- For more information about the use of Benchmark, see [Benchmark Use](https://www.mindspore.cn/lite/docs/en/master/tools/benchmark.html).
+ For more information about the use of Benchmark, see [Benchmark Use](https://www.mindspore.cn/lite/docs/en/r2.6.0/tools/benchmark.html).

 For environment variable settings, you need to set the directory where the `libmindspore-lite.so` (under the directory `mindspore-lite-{version}-{os}-{arch}/runtime/lib`), TensorRT and CUDA `so` libraries are located, to `${LD_LIBRARY_PATH}`.

 - Using TensorRT engine serialization

- TensorRT backend inference supports serializing the built TensorRT model (Engine) into a binary file and saves it locally. When it is used the next time, the model can be deserialized and loaded from the local, avoiding rebuilding and reducing overhead.
To support this function, users need to use the [LoadConfig](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Model.html) interface to load the configuration file in the code, you need to specify the saving path of serialization file in the configuration file: + TensorRT backend inference supports serializing the built TensorRT model (Engine) into a binary file and saves it locally. When it is used the next time, the model can be deserialized and loaded from the local, avoiding rebuilding and reducing overhead. To support this function, users need to use the [LoadConfig](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Model.html) interface to load the configuration file in the code, you need to specify the saving path of serialization file in the configuration file: ``` [ms_cache] @@ -73,7 +73,7 @@ For more information about compilation, see [Linux Environment Compilation](http By default, TensorRT optimizes the model based on the input shapes (batch size, image size, and so on) at which it was defined. However, the input dimension can be adjusted at runtime by configuring the profile. In the profile, the minimum, dynamic and optimal shape of each input can be set. - TensorRT creates an optimized engine for each profile, choosing CUDA kernels that work for all shapes within the [minimum ~ maximum] range. And in the profile, multiple input dimensions can be configured for a single input. To support this function, users need to use the [LoadConfig](https://www.mindspore.cn/lite/api/en/master/generate/classmindspore_Model.html) interface to load the configuration file in the code. + TensorRT creates an optimized engine for each profile, choosing CUDA kernels that work for all shapes within the [minimum ~ maximum] range. And in the profile, multiple input dimensions can be configured for a single input. To support this function, users need to use the [LoadConfig](https://www.mindspore.cn/lite/api/en/r2.6.0/generate/classmindspore_Model.html) interface to load the configuration file in the code. If min, opt, and Max are the minimum, optimal, and maximum dimensions, and real_shape is the shape of the input tensor, the following conditions must hold: @@ -102,4 +102,4 @@ For more information about compilation, see [Linux Environment Compilation](http ## Supported Operators -For supported TensorRT operators, see [Lite Operator List](https://www.mindspore.cn/lite/docs/en/master/reference/operator_list_lite.html). +For supported TensorRT operators, see [Lite Operator List](https://www.mindspore.cn/lite/docs/en/r2.6.0/reference/operator_list_lite.html). 
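Both the engine-serialization cache and the dynamic-shape profile described above are supplied through the same `LoadConfig` call. A minimal usage sketch follows; the file names and error handling are illustrative, and `LoadConfig` has to run before `Build` for the settings to take effect.

```cpp
// Minimal sketch (illustrative paths): load a config file containing the
// [ms_cache] / profile sections shown above, then build the model as usual.
#include <memory>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/status.h"

bool BuildWithTensorRTConfig(const std::shared_ptr<mindspore::Context> &context) {
  mindspore::Model model;
  // Load the configuration before Build so the TensorRT engine is built
  // (or deserialized) according to the file's settings.
  if (model.LoadConfig("tensorrt.cfg") != mindspore::kSuccess) {
    return false;
  }
  return model.Build("model.mindir", mindspore::ModelType::kMindIR, context) == mindspore::kSuccess;
}
```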
diff --git a/docs/lite/docs/source_en/build/build.md b/docs/lite/docs/source_en/build/build.md index 7779c02b7fd56be468c194500c0250c75fbb693f..a7b3be8f916dcc640d00198a63107e7f22a41b2e 100644 --- a/docs/lite/docs/source_en/build/build.md +++ b/docs/lite/docs/source_en/build/build.md @@ -1,6 +1,6 @@ # Building Device-side -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/build/build.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/build/build.md) This chapter introduces how to quickly compile MindSpore Lite, which includes the following modules: @@ -22,7 +22,7 @@ Modules in MindSpore Lite: ### Environment Requirements - The compilation environment supports Linux x86_64 only. Ubuntu 18.04.02 LTS is recommended. -- Compilation dependencies of cpp: +- Compilation dependencies of c++: - [GCC](https://gcc.gnu.org/releases.html) >= 7.3.0 - [CMake](https://cmake.org/download/) >= 3.18.3 - [Git](https://git-scm.com/downloads) >= 2.28.0 @@ -83,50 +83,50 @@ The construction of modules is controlled by environment variables. Users can co - General module compilation options -| Option | Parameter Description | Value Range | Defaults | -| -------- | ----- | ---- | ---- | -| MSLITE_GPU_BACKEND | Set the GPU backend, only opencl is valid when the target OS is not OpenHarmony and `-I arm64`, and only tensorrt is valid when `-I x86_64` | opencl, tensorrt, off | opencl when `-I arm64`, off when `-I x86_64` | -| MSLITE_ENABLE_NPU | Whether to compile NPU operator, only valid when the target OS is not OpenHarmony `-I arm64` or `-I arm32` | on, off | off | -| MSLITE_ENABLE_TRAIN | Whether to compile the training version | on, off | on | -| MSLITE_ENABLE_SSE | Whether to enable SSE instruction set, only valid when `-I x86_64` | on, off | off | -| MSLITE_ENABLE_AVX | Whether to enable AVX instruction set, only valid when `-I x86_64` | on, off | off | -| MSLITE_ENABLE_AVX512 | Whether to enable AVX512 instruction set, only valid when `-I x86_64` | on, off | off | -| MSLITE_ENABLE_CONVERTER | Whether to compile the model conversion tool, only valid when `-I x86_64` | on, off | on | -| MSLITE_ENABLE_TOOLS | Whether to compile supporting tools | on, off | on | -| MSLITE_ENABLE_TESTCASES | Whether to compile test cases | on, off | off | -| MSLITE_ENABLE_MODEL_ENCRYPTION | Whether to support model encryption and decryption | on, off | off | -| MSLITE_ENABLE_MODEL_PRE_INFERENCE | Whether to enable pre-inference during model compilation | on, off | off | -| MSLITE_ENABLE_GITEE_MIRROR | Whether to enable download third_party from gitee mirror | on, off | off | - -> - For TensorRT and NPU compilation environment configuration, refer to [Application Specific Integrated Circuit Integration Instructions](https://www.mindspore.cn/lite/docs/en/master/advanced/third_party/asic.html). -> - When the AVX instruction set is enabled, the CPU of the running environment needs to support both AVX and FMA features. -> - The compilation time of the model conversion tool is long. If it is not necessary, it is recommended to use `MSLITE_ENABLE_CONVERTER` to turn off the compilation of the conversion tool to speed up the compilation. 
-> - The version supported by the OpenSSL encryption library is 1.1.1k, which needs to be downloaded and compiled by the user. For the compilation, please refer to: . In addition, the path of libcrypto.so.1.1 should be added to LD_LIBRARY_PATH.
-> - When pre-inference during model compilation is enabled, for the non-encrypted model, the inference framework will create a child process for pre-inference when Build interface is called. After the child process returns successfully, the main precess will formally execute the process of graph compilation.
-> - At present, OpenHarmony only supports CPU reasoning, not GPU reasoning.
+ | Option | Parameter Description | Value Range | Defaults |
+ | -------- | ----- | ---- | ---- |
+ | MSLITE_GPU_BACKEND | Set the GPU backend, only opencl is valid when the target OS is not OpenHarmony and `-I arm64`, and only tensorrt is valid when `-I x86_64` | opencl, tensorrt, off | opencl when `-I arm64`, off when `-I x86_64` |
+ | MSLITE_ENABLE_NPU | Whether to compile NPU operator, only valid when the target OS is not OpenHarmony `-I arm64` or `-I arm32` | on, off | off |
+ | MSLITE_ENABLE_TRAIN | Whether to compile the training version | on, off | on |
+ | MSLITE_ENABLE_SSE | Whether to enable SSE instruction set, only valid when `-I x86_64` | on, off | off |
+ | MSLITE_ENABLE_AVX | Whether to enable AVX instruction set, only valid when `-I x86_64` | on, off | off |
+ | MSLITE_ENABLE_AVX512 | Whether to enable AVX512 instruction set, only valid when `-I x86_64` | on, off | off |
+ | MSLITE_ENABLE_CONVERTER | Whether to compile the model conversion tool, only valid when `-I x86_64` | on, off | on |
+ | MSLITE_ENABLE_TOOLS | Whether to compile supporting tools | on, off | on |
+ | MSLITE_ENABLE_TESTCASES | Whether to compile test cases | on, off | off |
+ | MSLITE_ENABLE_MODEL_ENCRYPTION | Whether to support model encryption and decryption | on, off | off |
+ | MSLITE_ENABLE_MODEL_PRE_INFERENCE | Whether to enable pre-inference during model compilation | on, off | off |
+ | MSLITE_ENABLE_GITEE_MIRROR | Whether to download third_party dependencies from the Gitee mirror | on, off | off |
+
+ > - For TensorRT and NPU compilation environment configuration, refer to [Application Specific Integrated Circuit Integration Instructions](https://www.mindspore.cn/lite/docs/en/r2.6.0/advanced/third_party/asic.html).
+ > - When the AVX instruction set is enabled, the CPU of the running environment needs to support both AVX and FMA features.
+ > - The compilation time of the model conversion tool is long. If it is not necessary, it is recommended to use `MSLITE_ENABLE_CONVERTER` to turn off the compilation of the conversion tool to speed up the compilation.
+ > - The version supported by the OpenSSL encryption library is 1.1.1k, which needs to be downloaded and compiled by the user. For the compilation, please refer to: . In addition, the path of libcrypto.so.1.1 should be added to LD_LIBRARY_PATH.
+ > - When pre-inference during model compilation is enabled, for the non-encrypted model, the inference framework will create a child process for pre-inference when the Build interface is called. After the child process returns successfully, the main process will formally execute the process of graph compilation.
+ > - At present, OpenHarmony only supports CPU inference, not GPU inference.
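For example, the module options above are plain environment variables that can be combined freely before invoking the build script (the selection below is arbitrary and only for illustration):

```bash
# Arbitrary illustrative selection of the options tabled above.
export MSLITE_ENABLE_CONVERTER=off   # skip the converter to shorten the build
export MSLITE_ENABLE_SSE=on          # enable the SSE instruction set for x86_64
bash build.sh -I x86_64
```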
- Runtime feature compilation options -If the user is sensitive to the package size of the framework, the following options can be configured to reduce the package size by reducing the function of the runtime model reasoning framework. Then, the user can further reduce the package size by operator reduction through the [cropper tool](https://www.mindspore.cn/lite/docs/en/master/tools/cropper_tool.html). + If the user is sensitive to the package size of the framework, the following options can be configured to reduce the package size by reducing the function of the runtime model reasoning framework. Then, the user can further reduce the package size by operator reduction through the [cropper tool](https://www.mindspore.cn/lite/docs/en/r2.6.0/tools/cropper_tool.html). -| Option | Parameter Description | Value Range | Defaults | -| -------- | ----- | ---- | ---- | -| MSLITE_STRING_KERNEL | Whether to support string data reasoning model, such as smart_reply.tflite | on,off | on | -| MSLITE_ENABLE_CONTROLFLOW | Whether to support control flow model | on,off | on | -| MSLITE_ENABLE_WEIGHT_DECODE | Whether to support weight quantitative model | on,off | on | -| MSLITE_ENABLE_CUSTOM_KERNEL | Whether to support southbound operator registration | on,off | on | -| MSLITE_ENABLE_DELEGATE | Whether to support Delegate mechanism | on,off | on | -| MSLITE_ENABLE_FP16 | Whether to support FP16 operator | on,off | off when `-I x86_64`, on when `-I arm64`, when `-I arm32`, if the Android_NDK version is greater than r21e, it is on, otherwise it is off | -| MSLITE_ENABLE_INT8 | Whether to support INT8 operator | on,off | on | + | Option | Parameter Description | Value Range | Defaults | + | -------- | ----- | ---- | ---- | + | MSLITE_STRING_KERNEL | Whether to support string data reasoning model, such as smart_reply.tflite | on,off | on | + | MSLITE_ENABLE_CONTROLFLOW | Whether to support control flow model | on,off | on | + | MSLITE_ENABLE_WEIGHT_DECODE | Whether to support weight quantitative model | on,off | on | + | MSLITE_ENABLE_CUSTOM_KERNEL | Whether to support southbound operator registration | on,off | on | + | MSLITE_ENABLE_DELEGATE | Whether to support Delegate mechanism | on,off | on | + | MSLITE_ENABLE_FP16 | Whether to support FP16 operator | on,off | off when `-I x86_64`, on when `-I arm64`, when `-I arm32`, if the Android_NDK version is greater than r21e, it is on, otherwise it is off | + | MSLITE_ENABLE_INT8 | Whether to support INT8 operator | on,off | on | -> - Since the implementation of NPU and TensorRT depends on the Delegate mechanism, the Delegate mechanism cannot be turned off when using NPU or TensorRT. If the Delegate mechanism is turned off, the related functions must also be turned off. + > - Since the implementation of NPU and TensorRT depends on the Delegate mechanism, the Delegate mechanism cannot be turned off when using NPU or TensorRT. If the Delegate mechanism is turned off, the related functions must also be turned off. ### Compilation Example First, download source code from the MindSpore code repository. ```bash -git clone https://gitee.com/mindspore/mindspore.git +git clone -b v2.6.0 https://gitee.com/mindspore/mindspore.git ``` Then, run the following commands in the root directory of the source code to compile MindSpore Lite of different versions: @@ -323,7 +323,7 @@ The script `build.bat` in the root directory of MindSpore can be used to compile First, use the git tool to download the source code from the MindSpore code repository. 
```bat -git clone https://gitee.com/mindspore/mindspore.git +git clone -b v2.6.0 https://gitee.com/mindspore/mindspore.git ``` Then, use the cmd tool to compile MindSpore Lite in the root directory of the source code and execute the following commands. @@ -416,7 +416,7 @@ The script `build.sh` in the root directory of MindSpore can be used to compile First, use the git tool to download the source code from the MindSpore code repository. ```bash -git clone https://gitee.com/mindspore/mindspore.git +git clone -b v2.6.0 https://gitee.com/mindspore/mindspore.git ``` Then, use the cmd tool to compile MindSpore Lite in the root directory of the source code and execute the following commands. diff --git a/docs/lite/docs/source_en/conf.py b/docs/lite/docs/source_en/conf.py index cd3e2e46c38df1f0da351bb4a3a43d8737b6eed9..caa7d70db0e14db4d17940e80666dfcec5e34dc6 100644 --- a/docs/lite/docs/source_en/conf.py +++ b/docs/lite/docs/source_en/conf.py @@ -22,7 +22,7 @@ copyright = 'MindSpore' author = 'MindSpore Lite' # The full version, including alpha/beta/rc tags -release = 'master' +release = '2.6.0' # -- General configuration --------------------------------------------------- diff --git a/docs/lite/docs/source_en/converter/converter_tool.md b/docs/lite/docs/source_en/converter/converter_tool.md index 0e104e88211e96d57579ff43900a2b3b4b71d997..1e78d45687b44cfecb3085e2a4ba7b6485951677 100644 --- a/docs/lite/docs/source_en/converter/converter_tool.md +++ b/docs/lite/docs/source_en/converter/converter_tool.md @@ -1,6 +1,6 @@ # Device-side Models Conversion -[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/master/docs/lite/docs/source_en/converter/converter_tool.md) +[![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.6.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.6.0/docs/lite/docs/source_en/converter/converter_tool.md) ## Overview @@ -16,7 +16,7 @@ The `ms` model converted by the conversion tool supports the conversion tool and To use the MindSpore Lite model conversion tool, you need to prepare the environment as follows: -- [Compile](https://www.mindspore.cn/lite/docs/en/master/build/build.html) or [download](https://www.mindspore.cn/lite/docs/en/master/use/downloads.html) model transfer tool. +- [Compile](https://www.mindspore.cn/lite/docs/en/r2.6.0/build/build.html) or [download](https://www.mindspore.cn/lite/docs/en/r2.6.0/use/downloads.html) model transfer tool. - Add the path of dynamic library required by the conversion tool to the environment variables LD_LIBRARY_PATH. @@ -85,9 +85,9 @@ The following describes the parameters in detail. > - The Caffe model is divided into two files: model structure `*.prototxt`, corresponding to the `--modelFile` parameter; model weight `*.caffemodel`, corresponding to the `--weightFile` parameter. > - The priority of `--fp16` option is very low. For example, if quantization is enabled, `--fp16` will no longer take effect on const tensors that have been quantized. All in all, this option only takes effect on const tensors of float32 when serializing model. > - `inputDataFormat`: generally, in the scenario of integrating third-party hardware of NCHW specification, designated as NCHW will have a significant performance improvement over NHWC. In other scenarios, users can also set as needed. 
-> - The `configFile` configuration files uses the `key=value` mode to define related parameters. For the configuration parameters related to quantization, please refer to [quantization](https://www.mindspore.cn/lite/docs/en/master/advanced/quantization.html). For the configuration parameters related to extension, please refer to [Extension Configuration](https://www.mindspore.cn/lite/docs/en/master/advanced/third_party/converter_register.html#extension-configuration).
+> - The `configFile` configuration file uses the `key=value` mode to define related parameters. For the configuration parameters related to quantization, please refer to [quantization](https://www.mindspore.cn/lite/docs/en/r2.6.0/advanced/quantization.html). For the configuration parameters related to extension, please refer to [Extension Configuration](https://www.mindspore.cn/lite/docs/en/r2.6.0/advanced/third_party/converter_register.html#extension-configuration).
> - `--optimize` parameter is used to set the mode of optimization during the offline conversion. If this parameter is set to none, no relevant graph optimization operations will be performed during the offline conversion phase of the model, and the relevant graph optimization operations will be done during the execution of the inference phase. The advantage of this parameter is that the converted model can be deployed directly to any CPU/GPU/Ascend hardware backend since it is not optimized in a specific way, while the disadvantage is that the initialization time of the model increases during inference execution. If this parameter is set to general, general optimization will be performed, such as constant folding and operator fusion (the converted model only supports CPU/GPU hardware backend, not Ascend backend). If this parameter is set to gpu_oriented, the general optimization and extra optimization for GPU hardware will be performed (the converted model only supports GPU hardware backend). If this parameter is set to ascend_oriented, the optimization for Ascend hardware will be performed (the converted model only supports Ascend hardware backend).
-> - The encryption and decryption function only takes effect when `MSLITE_ENABLE_MODEL_ENCRYPTION=on` is set at [compile](https://www.mindspore.cn/lite/docs/en/master/build/build.html) time and only supports Linux x86 platforms, and the key is a string represented by hexadecimal. For example, if the key is defined as `b'0123456789ABCDEF'`, the corresponding hexadecimal representation is `30313233343536373839414243444546`. Users on the Linux platform can use the `xxd` tool to convert the key represented by the bytes to a hexadecimal representation.
+> - The encryption and decryption function only takes effect when `MSLITE_ENABLE_MODEL_ENCRYPTION=on` is set at [compile](https://www.mindspore.cn/lite/docs/en/r2.6.0/build/build.html) time and only supports Linux x86 platforms, and the key is a string represented by hexadecimal. For example, if the key is defined as `b'0123456789ABCDEF'`, the corresponding hexadecimal representation is `30313233343536373839414243444546`. Users on the Linux platform can use the `xxd` tool to convert the key represented by the bytes to a hexadecimal representation. It should be noted that the encryption and decryption algorithm has been updated in version 1.7. As a result, the new version of the converter tool does not support the conversion of the encrypted model exported by MindSpore in version 1.6 and earlier.
> - Parameters `--input_shape` and dynamicDims are stored in the model during conversion. Call model.get_model_info("input_shape") and model.get_model_info("dynamic_dims") to get it when using the model.
@@ -178,7 +178,7 @@ The following describes how to use the conversion command by using several commo

 To use the MindSpore Lite model conversion tool, the following environment preparations are required.

-- [Compile](https://www.mindspore.cn/lite/docs/en/master/build/build.html) or [download](https://www.mindspore.cn/lite/docs/en/master/use/downloads.html) model transfer tool.
+- [Compile](https://www.mindspore.cn/lite/docs/en/r2.6.0/build/build.html) or [download](https://www.mindspore.cn/lite/docs/en/r2.6.0/use/downloads.html) model transfer tool.

 - Add the path of dynamic library required by the conversion tool to the environment variables PATH.

@@ -186,7 +186,7 @@ To use the MindSpore Lite model conversion tool, the following environment prepa
 set PATH=%PACKAGE_ROOT_PATH%\tools\converter\lib;%PATH%
 ````

- %PACKAGE_ROOT_PATH% is the decompressed package path obtained by compiling or downloading.
+ ${PACKAGE_ROOT_PATH} is the decompressed package path obtained by compiling or downloading.

 ### Directory Structure

@@ -208,7 +208,7 @@ mindspore-lite-{version}-win-x64

 ### Parameter Description

-Refer to the Linux environment model conversion tool [parameter description](https://www.mindspore.cn/lite/docs/en/master/converter/converter_tool.html#parameter-description).
+Refer to the Linux environment model conversion tool [parameter description](https://www.mindspore.cn/lite/docs/en/r2.6.0/converter/converter_tool.html#parameter-description).

 ### Example

@@ -246,7 +246,7 @@ Several common examples are selected below to illustrate the use of conversion c
 call converter_lite --fmk=MINDIR --modelFile=model.mindir --outputFile=model
 ```

- > The `MindIR` model exported by MindSpore v1.1.1 or earlier is recommended to be converted to the `ms` model using the converter tool of the corresponding version. MindSpore v1.1.1 and later versions, the converter tool will be forward compatible.
+ > The `MindIR` model exported by versions earlier than MindSpore v1.1.1 is recommended to be converted to the `ms` model using the converter tool of the corresponding version. In MindSpore v1.1.1 and later versions, the converter tool will be forward compatible.

 - TensorFlow Lite model`model.tflite`

diff --git a/docs/lite/docs/source_en/index.rst b/docs/lite/docs/source_en/index.rst
index 87bd4564f10a0c38524bddb8c93e5d27df7ec7da..62aa538a360f74d6b926e49eec6ba8a2f1084851 100644
--- a/docs/lite/docs/source_en/index.rst
+++ b/docs/lite/docs/source_en/index.rst
@@ -215,7 +215,7 @@ MindSpore Lite Documentation