diff --git a/docs/xai/docs/Makefile b/docs/xai/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..1eff8952707bdfa503c8d60c1e9a903053170ba2 --- /dev/null +++ b/docs/xai/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source_zh_cn +BUILDDIR = build_zh_cn + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/xai/docs/_ext/overwriteobjectiondirective.txt b/docs/xai/docs/_ext/overwriteobjectiondirective.txt new file mode 100644 index 0000000000000000000000000000000000000000..745f53083d102d2a458c5f7d8bd7daaabd63ce1d --- /dev/null +++ b/docs/xai/docs/_ext/overwriteobjectiondirective.txt @@ -0,0 +1,392 @@ +import re +import inspect +import importlib +from typing import Any, Dict, List, Tuple +from typing import cast + +from docutils import nodes +from docutils.nodes import Node +from docutils.parsers.rst import directives, roles + +from sphinx import addnodes +from sphinx.addnodes import desc_signature +from sphinx.deprecation import RemovedInSphinx40Warning, deprecated_alias +from sphinx.util import docutils, logging +from sphinx.util.docfields import DocFieldTransformer, Field, TypedField +from sphinx.util.docutils import SphinxDirective +from sphinx.util.typing import DirectiveOption + +if False: + # For type annotation + from sphinx.application import Sphinx + + +# RE to strip backslash escapes +nl_escape_re = re.compile(r'\\\n') +strip_backslash_re = re.compile(r'\\(.)') + +logger = logging.getLogger(__name__) + +def optional_int(argument): + """ + Check for an integer argument or None value; raise ``ValueError`` if not. 
+ """ + if argument is None: + return None + else: + value = int(argument) + if value < 0: + raise ValueError('negative value; must be positive or zero') + return value + +def get_api(fullname): + try: + module_name, api_name= ".".join(fullname.split('.')[:-1]), fullname.split('.')[-1] + module_import = importlib.import_module(module_name) + except ModuleNotFoundError: + module_name, api_name = ".".join(fullname.split('.')[:-2]), ".".join(fullname.split('.')[-2:]) + module_import = importlib.import_module(module_name) + api = eval(f"module_import.{api_name}") + return api + +def get_example(name: str): + try: + api_doc = inspect.getdoc(get_api(name)) + example_str = re.findall(r'Examples:\n([\w\W]*?)(\n\n|$)', api_doc) + if not example_str: + return [] + example_str = re.sub(r'\n\s+', r'\n', example_str[0][0]) + example_str = example_str.strip() + example_list = example_str.split('\n') + return ["", "**样例:**", ""] + example_list + [""] + except: + return [] + +def get_platforms(name: str): + try: + api_doc = inspect.getdoc(get_api(name)) + example_str = re.findall(r'Supported Platforms:\n\s+(.*?)\n\n', api_doc) + if not example_str: + example_str_leak = re.findall(r'Supported Platforms:\n\s+(.*)', api_doc) + if example_str_leak: + example_str = example_str_leak[0].strip() + example_list = example_str.split('\n') + example_list = [' ' + example_list[0]] + return ["", "支持平台:"] + example_list + [""] + return [] + example_str = example_str[0].strip() + example_list = example_str.split('\n') + example_list = [' ' + example_list[0]] + return ["", "支持平台:"] + example_list + [""] + except: + return [] + +def get_side_effect(name: str): + try: + api_doc = inspect.getdoc(get_api(name)) + side_effect_str = re.findall(r'Side Effects:\n\s+(.*?)\n\n', api_doc) + side_effect_str = side_effect_str[0].strip() + side_effect_list = side_effect_str.split('\n') + return ["", "**副作用:**", ""] + side_effect_list + [""] + except: + return [] + +class ObjectDescription(SphinxDirective): + """ + Directive to describe a class, function or similar object. Not used + directly, but subclassed (in domain-specific directives) to add custom + behavior. + """ + + has_content = True + required_arguments = 1 + optional_arguments = 0 + final_argument_whitespace = True + option_spec = { + 'noindex': directives.flag, + } # type: Dict[str, DirectiveOption] + + # types of doc fields that this directive handles, see sphinx.util.docfields + doc_field_types = [] # type: List[Field] + domain = None # type: str + objtype = None # type: str + indexnode = None # type: addnodes.index + + # Warning: this might be removed in future version. Don't touch this from extensions. + _doc_field_type_map = {} # type: Dict[str, Tuple[Field, bool]] + + def get_field_type_map(self) -> Dict[str, Tuple[Field, bool]]: + if self._doc_field_type_map == {}: + self._doc_field_type_map = {} + for field in self.doc_field_types: + for name in field.names: + self._doc_field_type_map[name] = (field, False) + + if field.is_typed: + typed_field = cast(TypedField, field) + for name in typed_field.typenames: + self._doc_field_type_map[name] = (field, True) + + return self._doc_field_type_map + + def get_signatures(self) -> List[str]: + """ + Retrieve the signatures to document from the directive arguments. By + default, signatures are given as arguments, one per line. + + Backslash-escaping of newlines is supported. 
+ """ + lines = nl_escape_re.sub('', self.arguments[0]).split('\n') + # remove backslashes to support (dummy) escapes; helps Vim highlighting + return [strip_backslash_re.sub(r'\1', line.strip()) for line in lines] + + def handle_signature(self, sig: str, signode: desc_signature) -> Any: + """ + Parse the signature *sig* into individual nodes and append them to + *signode*. If ValueError is raised, parsing is aborted and the whole + *sig* is put into a single desc_name node. + + The return value should be a value that identifies the object. It is + passed to :meth:`add_target_and_index()` unchanged, and otherwise only + used to skip duplicates. + """ + raise ValueError + + def add_target_and_index(self, name: Any, sig: str, signode: desc_signature) -> None: + """ + Add cross-reference IDs and entries to self.indexnode, if applicable. + + *name* is whatever :meth:`handle_signature()` returned. + """ + return # do nothing by default + + def before_content(self) -> None: + """ + Called before parsing content. Used to set information about the current + directive context on the build environment. + """ + pass + + def after_content(self) -> None: + """ + Called after parsing content. Used to reset information about the + current directive context on the build environment. + """ + pass + + def check_class_end(self, content): + for i in content: + if not i.startswith('.. include::') and i != "\n" and i != "": + return False + return True + + def extend_items(self, rst_file, start_num, num): + ls = [] + for i in range(1, num+1): + ls.append((rst_file, start_num+i)) + return ls + + def run(self) -> List[Node]: + """ + Main directive entry function, called by docutils upon encountering the + directive. + + This directive is meant to be quite easily subclassable, so it delegates + to several additional methods. What it does: + + * find out if called as a domain-specific directive, set self.domain + * create a `desc` node to fit all description inside + * parse standard options, currently `noindex` + * create an index node if needed as self.indexnode + * parse all given signatures (as returned by self.get_signatures()) + using self.handle_signature(), which should either return a name + or raise ValueError + * add index entries using self.add_target_and_index() + * parse the content and handle doc fields in it + """ + if ':' in self.name: + self.domain, self.objtype = self.name.split(':', 1) + else: + self.domain, self.objtype = '', self.name + self.indexnode = addnodes.index(entries=[]) + + node = addnodes.desc() + node.document = self.state.document + node['domain'] = self.domain + # 'desctype' is a backwards compatible attribute + node['objtype'] = node['desctype'] = self.objtype + node['noindex'] = noindex = ('noindex' in self.options) + + self.names = [] # type: List[Any] + signatures = self.get_signatures() + for i, sig in enumerate(signatures): + # add a signature node for each signature in the current unit + # and add a reference target for it + signode = addnodes.desc_signature(sig, '') + signode['first'] = False + node.append(signode) + try: + # name can also be a tuple, e.g. (classname, objname); + # this is strictly domain-specific (i.e. 
no assumptions may + # be made in this base class) + name = self.handle_signature(sig, signode) + except ValueError: + # signature parsing failed + signode.clear() + signode += addnodes.desc_name(sig, sig) + continue # we don't want an index entry here + if name not in self.names: + self.names.append(name) + if not noindex: + # only add target and index entry if this is the first + # description of the object with this name in this desc block + self.add_target_and_index(name, sig, signode) + + contentnode = addnodes.desc_content() + node.append(contentnode) + if self.names: + # needed for association of version{added,changed} directives + self.env.temp_data['object'] = self.names[0] + self.before_content() + try: + example = get_example(self.names[0][0]) + platforms = get_platforms(self.names[0][0]) + side_effect = get_side_effect(self.names[0][0]) + except IndexError: + example = '' + platforms = '' + side_effect = '' + logger.warning(f'Error API names in {self.arguments[0]}.') + extra = side_effect + platforms + example + if extra: + if self.objtype == "method": + self.content.data.extend(extra) + else: + index_num = 0 + for num, i in enumerate(self.content.data): + if i.startswith('.. py:method::') or self.check_class_end(self.content.data[num:]): + index_num = num + break + if index_num: + count = len(self.content.data) + for i in extra: + self.content.data.insert(index_num-count, i) + else: + self.content.data.extend(extra) + try: + self.content.items.extend(self.extend_items(self.content.items[0][0], self.content.items[-1][1], len(extra))) + except IndexError: + logger.warning(f'{self.names[0][0]} has error format.') + self.state.nested_parse(self.content, self.content_offset, contentnode) + self.env.app.emit('object-description-transform', + self.domain, self.objtype, contentnode) + DocFieldTransformer(self).transform_all(contentnode) + self.env.temp_data['object'] = None + self.after_content() + return [self.indexnode, node] + + +class DefaultRole(SphinxDirective): + """ + Set the default interpreted text role. Overridden from docutils. + """ + + optional_arguments = 1 + final_argument_whitespace = False + + def run(self) -> List[Node]: + if not self.arguments: + docutils.unregister_role('') + return [] + role_name = self.arguments[0] + role, messages = roles.role(role_name, self.state_machine.language, + self.lineno, self.state.reporter) + if role: + docutils.register_role('', role) + self.env.temp_data['default_role'] = role_name + else: + literal_block = nodes.literal_block(self.block_text, self.block_text) + reporter = self.state.reporter + error = reporter.error('Unknown interpreted text role "%s".' % role_name, + literal_block, line=self.lineno) + messages += [error] + + return cast(List[nodes.Node], messages) + + +class DefaultDomain(SphinxDirective): + """ + Directive to (re-)set the default domain for this source file. 
+ """ + + has_content = False + required_arguments = 1 + optional_arguments = 0 + final_argument_whitespace = False + option_spec = {} # type: Dict + + def run(self) -> List[Node]: + domain_name = self.arguments[0].lower() + # if domain_name not in env.domains: + # # try searching by label + # for domain in env.domains.values(): + # if domain.label.lower() == domain_name: + # domain_name = domain.name + # break + self.env.temp_data['default_domain'] = self.env.domains.get(domain_name) + return [] + +from sphinx.directives.code import ( # noqa + Highlight, CodeBlock, LiteralInclude +) +from sphinx.directives.other import ( # noqa + TocTree, Author, VersionChange, SeeAlso, + TabularColumns, Centered, Acks, HList, Only, Include, Class +) +from sphinx.directives.patches import ( # noqa + Figure, Meta +) +from sphinx.domains.index import IndexDirective # noqa + +deprecated_alias('sphinx.directives', + { + 'Highlight': Highlight, + 'CodeBlock': CodeBlock, + 'LiteralInclude': LiteralInclude, + 'TocTree': TocTree, + 'Author': Author, + 'Index': IndexDirective, + 'VersionChange': VersionChange, + 'SeeAlso': SeeAlso, + 'TabularColumns': TabularColumns, + 'Centered': Centered, + 'Acks': Acks, + 'HList': HList, + 'Only': Only, + 'Include': Include, + 'Class': Class, + 'Figure': Figure, + 'Meta': Meta, + }, + RemovedInSphinx40Warning) + + +# backwards compatible old name (will be marked deprecated in 3.0) +DescDirective = ObjectDescription + + +def setup(app: "Sphinx") -> Dict[str, Any]: + directives.register_directive('default-role', DefaultRole) + directives.register_directive('default-domain', DefaultDomain) + directives.register_directive('describe', ObjectDescription) + # new, more consistent, name + directives.register_directive('object', ObjectDescription) + + app.add_event('object-description-transform') + + return { + 'version': 'builtin', + 'parallel_read_safe': True, + 'parallel_write_safe': True, + } + diff --git a/docs/xai/docs/_ext/overwriteviewcode.txt b/docs/xai/docs/_ext/overwriteviewcode.txt new file mode 100644 index 0000000000000000000000000000000000000000..130defb02fa41468e2bfda3b3843314c8cd0dcce --- /dev/null +++ b/docs/xai/docs/_ext/overwriteviewcode.txt @@ -0,0 +1,274 @@ +""" + sphinx.ext.viewcode + ~~~~~~~~~~~~~~~~~~~ + + Add links to module code in Python object descriptions. + + :copyright: Copyright 2007-2020 by the Sphinx team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +import traceback +import warnings +from typing import Any, Dict, Iterable, Iterator, Set, Tuple + +from docutils import nodes +from docutils.nodes import Element, Node + +import sphinx +from sphinx import addnodes +from sphinx.application import Sphinx +from sphinx.config import Config +from sphinx.deprecation import RemovedInSphinx30Warning +from sphinx.environment import BuildEnvironment +from sphinx.locale import _, __ +from sphinx.pycode import ModuleAnalyzer +from sphinx.util import get_full_modname, logging, status_iterator +from sphinx.util.nodes import make_refnode + + +logger = logging.getLogger(__name__) + + +def _get_full_modname(app: Sphinx, modname: str, attribute: str) -> str: + try: + return get_full_modname(modname, attribute) + except AttributeError: + # sphinx.ext.viewcode can't follow class instance attribute + # then AttributeError logging output only verbose mode. + logger.verbose('Didn\'t find %s in %s', attribute, modname) + return None + except Exception as e: + # sphinx.ext.viewcode follow python domain directives. 
+ # because of that, if there are no real modules exists that specified + # by py:function or other directives, viewcode emits a lot of warnings. + # It should be displayed only verbose mode. + logger.verbose(traceback.format_exc().rstrip()) + logger.verbose('viewcode can\'t import %s, failed with error "%s"', modname, e) + return None + + +def doctree_read(app: Sphinx, doctree: Node) -> None: + env = app.builder.env + if not hasattr(env, '_viewcode_modules'): + env._viewcode_modules = {} # type: ignore + if app.builder.name == "singlehtml": + return + if app.builder.name.startswith("epub") and not env.config.viewcode_enable_epub: + return + + def has_tag(modname, fullname, docname, refname): + entry = env._viewcode_modules.get(modname, None) # type: ignore + if entry is False: + return + + code_tags = app.emit_firstresult('viewcode-find-source', modname) + if code_tags is None: + try: + analyzer = ModuleAnalyzer.for_module(modname) + analyzer.find_tags() + except Exception: + env._viewcode_modules[modname] = False # type: ignore + return + + code = analyzer.code + tags = analyzer.tags + else: + code, tags = code_tags + + if entry is None or entry[0] != code: + entry = code, tags, {}, refname + env._viewcode_modules[modname] = entry # type: ignore + _, tags, used, _ = entry + if fullname in tags: + used[fullname] = docname + return True + + for objnode in doctree.traverse(addnodes.desc): + if objnode.get('domain') != 'py': + continue + names = set() # type: Set[str] + for signode in objnode: + if not isinstance(signode, addnodes.desc_signature): + continue + modname = signode.get('module') + fullname = signode.get('fullname') + if fullname and modname==None: + if fullname.split('.')[-1].lower() == fullname.split('.')[-1] and fullname.split('.')[-2].lower() != fullname.split('.')[-2]: + modname = '.'.join(fullname.split('.')[:-2]) + fullname = '.'.join(fullname.split('.')[-2:]) + else: + modname = '.'.join(fullname.split('.')[:-1]) + fullname = fullname.split('.')[-1] + fullname_new = fullname + # logger.warning(f'modename:{modname}') + # logger.warning(f'fullname:{fullname}') + refname = modname + if env.config.viewcode_follow_imported_members: + new_modname = app.emit_firstresult( + 'viewcode-follow-imported', modname, fullname, + ) + if not new_modname: + new_modname = _get_full_modname(app, modname, fullname) + modname = new_modname + # logger.warning(f'new_modename:{modname}') + if not modname: + continue + # fullname = signode.get('fullname') + # if fullname and modname==None: + fullname = fullname_new + if not has_tag(modname, fullname, env.docname, refname): + continue + if fullname in names: + # only one link per name, please + continue + names.add(fullname) + pagename = '_modules/' + modname.replace('.', '/') + inline = nodes.inline('', _('[源代码]'), classes=['viewcode-link']) + onlynode = addnodes.only(expr='html') + onlynode += addnodes.pending_xref('', inline, reftype='viewcode', refdomain='std', + refexplicit=False, reftarget=pagename, + refid=fullname, refdoc=env.docname) + signode += onlynode + + +def env_merge_info(app: Sphinx, env: BuildEnvironment, docnames: Iterable[str], + other: BuildEnvironment) -> None: + if not hasattr(other, '_viewcode_modules'): + return + # create a _viewcode_modules dict on the main environment + if not hasattr(env, '_viewcode_modules'): + env._viewcode_modules = {} # type: ignore + # now merge in the information from the subprocess + env._viewcode_modules.update(other._viewcode_modules) # type: ignore + + +def missing_reference(app: Sphinx, env: 
BuildEnvironment, node: Element, contnode: Node + ) -> Node: + # resolve our "viewcode" reference nodes -- they need special treatment + if node['reftype'] == 'viewcode': + return make_refnode(app.builder, node['refdoc'], node['reftarget'], + node['refid'], contnode) + + return None + + +def collect_pages(app: Sphinx) -> Iterator[Tuple[str, Dict[str, Any], str]]: + env = app.builder.env + if not hasattr(env, '_viewcode_modules'): + return + highlighter = app.builder.highlighter # type: ignore + urito = app.builder.get_relative_uri + + modnames = set(env._viewcode_modules) # type: ignore + + for modname, entry in status_iterator( + sorted(env._viewcode_modules.items()), # type: ignore + __('highlighting module code... '), "blue", + len(env._viewcode_modules), # type: ignore + app.verbosity, lambda x: x[0]): + if not entry: + continue + code, tags, used, refname = entry + # construct a page name for the highlighted source + pagename = '_modules/' + modname.replace('.', '/') + # highlight the source using the builder's highlighter + if env.config.highlight_language in ('python3', 'default', 'none'): + lexer = env.config.highlight_language + else: + lexer = 'python' + highlighted = highlighter.highlight_block(code, lexer, linenos=False) + # split the code into lines + lines = highlighted.splitlines() + # split off wrap markup from the first line of the actual code + before, after = lines[0].split('
<pre>')
+        lines[0:1] = [before + '<pre>', after]
+        # nothing to do for the last line; it always starts with </pre>
anyway + # now that we have code lines (starting at index 1), insert anchors for + # the collected tags (HACK: this only works if the tag boundaries are + # properly nested!) + maxindex = len(lines) - 1 + for name, docname in used.items(): + type, start, end = tags[name] + backlink = urito(pagename, docname) + '#' + refname + '.' + name + lines[start] = ( + '
<div class="viewcode-block" id="%s"><a class="viewcode-back" href="%s">%s</a>' % (name, backlink, _('[文档]')) + + lines[start]) + lines[min(end, maxindex)] += '</div>
' + # try to find parents (for submodules) + parents = [] + parent = modname + while '.' in parent: + parent = parent.rsplit('.', 1)[0] + if parent in modnames: + parents.append({ + 'link': urito(pagename, '_modules/' + + parent.replace('.', '/')), + 'title': parent}) + parents.append({'link': urito(pagename, '_modules/index'), + 'title': _('Module code')}) + parents.reverse() + # putting it all together + context = { + 'parents': parents, + 'title': modname, + 'body': (_('

<h1>Source code for %s</h1>
') % modname + + '\n'.join(lines)), + } + yield (pagename, context, 'page.html') + + if not modnames: + return + + html = ['\n'] + # the stack logic is needed for using nested lists for submodules + stack = [''] + for modname in sorted(modnames): + if modname.startswith(stack[-1]): + stack.append(modname + '.') + html.append('') + stack.append(modname + '.') + html.append('
<li><a href="%s">%s</a></li>\n' % ( + urito('_modules/index', '_modules/' + modname.replace('.', '/')), + modname)) + html.append('</ul>' * (len(stack) - 1)) + context = { + 'title': _('Overview: module code'), + 'body': (_('

<h1>All modules for which code is available</h1>
    ') + + ''.join(html)), + } + + yield ('_modules/index', context, 'page.html') + + +def migrate_viewcode_import(app: Sphinx, config: Config) -> None: + if config.viewcode_import is not None: + warnings.warn('viewcode_import was renamed to viewcode_follow_imported_members. ' + 'Please update your configuration.', + RemovedInSphinx30Warning, stacklevel=2) + + +def setup(app: Sphinx) -> Dict[str, Any]: + app.add_config_value('viewcode_import', None, False) + app.add_config_value('viewcode_enable_epub', False, False) + app.add_config_value('viewcode_follow_imported_members', True, False) + app.connect('config-inited', migrate_viewcode_import) + app.connect('doctree-read', doctree_read) + app.connect('env-merge-info', env_merge_info) + app.connect('html-collect-pages', collect_pages) + app.connect('missing-reference', missing_reference) + # app.add_config_value('viewcode_include_modules', [], 'env') + # app.add_config_value('viewcode_exclude_modules', [], 'env') + app.add_event('viewcode-find-source') + app.add_event('viewcode-follow-imported') + return { + 'version': sphinx.__display_version__, + 'env_version': 1, + 'parallel_read_safe': True + } diff --git a/docs/xai/docs/requirements.txt b/docs/xai/docs/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..49a77fdec3a5c745edd40eaa223883c31500e975 --- /dev/null +++ b/docs/xai/docs/requirements.txt @@ -0,0 +1,10 @@ +sphinx >= 2.2.1, <= 2.4.4 +docutils == 0.16 +myst_parser == 0.14.0 +sphinx-markdown-tables +sphinx_rtd_theme == 0.5.2 +numpy +nbsphinx +IPython +ipykernel +jieba diff --git a/docs/xai/docs/source_en/conf.py b/docs/xai/docs/source_en/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..1603e8bf6b59f180623ecd95d7adf6c4d6c5d6ae --- /dev/null +++ b/docs/xai/docs/source_en/conf.py @@ -0,0 +1,129 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys +import IPython +import re +sys.path.append(os.path.abspath('../_ext')) +from sphinx.ext import autodoc as sphinx_autodoc + +import mindspore_xai + +# -- Project information ----------------------------------------------------- + +project = 'MindSpore' +copyright = '2021, MindSpore' +author = 'MindSpore' + +# The full version, including alpha/beta/rc tags +release = 'master' + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.doctest', + 'sphinx.ext.intersphinx', + 'sphinx.ext.todo', + 'sphinx.ext.coverage', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'sphinx_markdown_tables', + 'myst_parser', + 'nbsphinx', + 'sphinx.ext.mathjax', + 'IPython.sphinxext.ipython_console_highlighting' +] + +source_suffix = { + '.rst': 'restructuredtext', + '.md': 'markdown', +} + +# Add any paths that contain templates here, relative to this directory. 
+templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [] + +suppress_warnings = [ + 'nbsphinx', +] + +pygments_style = 'sphinx' + +autodoc_inherit_docstrings = False + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + 'python': ('https://docs.python.org/', '../../../../resource/python_objects.inv'), + 'numpy': ('https://docs.scipy.org/doc/numpy/', '../../../../resource/numpy_objects.inv'), +} + +# Modify default signatures for autodoc. +autodoc_source_path = os.path.abspath(sphinx_autodoc.__file__) +autodoc_source_re = re.compile(r'stringify_signature\(.*?\)') +get_param_func_str = r"""\ +import re +import inspect as inspect_ + +def get_param_func(func): + try: + source_code = inspect_.getsource(func) + if func.__doc__: + source_code = source_code.replace(func.__doc__, '') + all_params_str = re.findall(r"def [\w_\d\-]+\(([\S\s]*?)(\):|\) ->.*?:)", source_code) + all_params = re.sub("(self|cls)(,|, )?", '', all_params_str[0][0].replace("\n", "").replace("'", "\"")) + return all_params + except: + return '' + +def get_obj(obj): + if isinstance(obj, type): + return obj.__init__ + + return obj +""" + +with open(autodoc_source_path, "r+", encoding="utf8") as f: + code_str = f.read() + code_str = autodoc_source_re.sub('"(" + get_param_func(get_obj(self.object)) + ")"', code_str, count=0) + exec(get_param_func_str, sphinx_autodoc.__dict__) + exec(code_str, sphinx_autodoc.__dict__) + +sys.path.append(os.path.abspath('../../../../resource/sphinx_ext')) +import anchor_mod +import nbsphinx_mod + + +sys.path.append(os.path.abspath('../../../../resource/search')) +import search_code + +sys.path.append(os.path.abspath('../../../../resource/custom_directives')) +from custom_directives import IncludeCodeDirective + +def setup(app): + app.add_directive('includecode', IncludeCodeDirective) diff --git a/docs/xai/docs/source_en/images/grad_cam_saliency.png b/docs/xai/docs/source_en/images/grad_cam_saliency.png new file mode 100644 index 0000000000000000000000000000000000000000..62ef2e908621dbcebf805151a1f8e37bd894884b Binary files /dev/null and b/docs/xai/docs/source_en/images/grad_cam_saliency.png differ diff --git a/docs/xai/docs/source_en/images/lime_tabular.png b/docs/xai/docs/source_en/images/lime_tabular.png new file mode 100644 index 0000000000000000000000000000000000000000..e0a05861c0f1c69a1d0f55e0b7b0aa7bc427201f Binary files /dev/null and b/docs/xai/docs/source_en/images/lime_tabular.png differ diff --git a/docs/xai/docs/source_en/images/rise_plus_saliency.png b/docs/xai/docs/source_en/images/rise_plus_saliency.png new file mode 100644 index 0000000000000000000000000000000000000000..3c0f4c43395dd2baa10a1d1441a73acb7aa9d6a5 Binary files /dev/null and b/docs/xai/docs/source_en/images/rise_plus_saliency.png differ diff --git a/docs/xai/docs/source_en/images/saliency_overlay.png b/docs/xai/docs/source_en/images/saliency_overlay.png new file mode 100644 index 0000000000000000000000000000000000000000..19fa0ee5cce1bc041462747debf679b162a3f08a Binary files /dev/null and b/docs/xai/docs/source_en/images/saliency_overlay.png differ diff 
--git a/docs/xai/docs/source_en/images/shap_gradient.png b/docs/xai/docs/source_en/images/shap_gradient.png new file mode 100644 index 0000000000000000000000000000000000000000..5ad88ddb6f32d1b50739892f8a2829431e8655e8 Binary files /dev/null and b/docs/xai/docs/source_en/images/shap_gradient.png differ diff --git a/docs/xai/docs/source_en/images/shap_kernel.png b/docs/xai/docs/source_en/images/shap_kernel.png new file mode 100644 index 0000000000000000000000000000000000000000..0cd7e1b5e1951dd6633b66afd2a5d7081fa83d90 Binary files /dev/null and b/docs/xai/docs/source_en/images/shap_kernel.png differ diff --git a/docs/xai/docs/source_en/images/xai_en.png b/docs/xai/docs/source_en/images/xai_en.png new file mode 100644 index 0000000000000000000000000000000000000000..a169b965dbbf4ea369279ef257c3e3ae77bd5bda Binary files /dev/null and b/docs/xai/docs/source_en/images/xai_en.png differ diff --git a/docs/xai/docs/source_en/index.rst b/docs/xai/docs/source_en/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..29e1d1a0d8796aeac66123119ff4052024c3c416 --- /dev/null +++ b/docs/xai/docs/source_en/index.rst @@ -0,0 +1,40 @@ +MindSpore XAI Documents +=========================== + +Currently, most deep learning models are black-box models with good performance but poor explainability. The MindSpore XAI - a MindSpore-based explainable AI toolbox - provides a variety of explanation and decision methods to help you better understand, trust, and improve models. It also evaluates the explanation methods from various dimensions, enabling you to compare and select methods best suited to your environment. + +.. raw:: html + + + +.. toctree:: + :glob: + :maxdepth: 1 + :caption: Installation + + installation + +.. toctree:: + :glob: + :maxdepth: 1 + :caption: Guide + + using_cv_explainers + using_cv_benchmarks + using_tabular_explainers + using_tabsim + using_tbnet + +.. toctree:: + :maxdepth: 1 + :caption: API References + + mindspore_xai.explainer + mindspore_xai.benchmark + mindspore_xai.tool + +.. toctree:: + :maxdepth: 1 + :caption: FAQ + + troubleshoot diff --git a/docs/xai/docs/source_en/installation.md b/docs/xai/docs/source_en/installation.md new file mode 100644 index 0000000000000000000000000000000000000000..cbefaebe068ff298d7843cfbf4f40b865c3ee5b3 --- /dev/null +++ b/docs/xai/docs/source_en/installation.md @@ -0,0 +1,55 @@ +# MindSpore XAI Installation + + + +## System Requirements + +- OS: EulerOS-aarch64, CentOS-aarch64, CentOS-x86, Ubuntu-aarch64 or Ubuntu-x86 +- Device: Ascend 910 or GPU CUDA 10.1, 11.1 +- Python 3.7.5 or 3.9.0 +- MindSpore 1.7 or above + +## Installing by pip + +Download the `.whl` package from [MindSpore XAI download page](https://www.mindspore.cn/versions/en) and install with `pip`. + +```bash +pip install mindspore_xai-{version}-py3-none-any.whl +``` + +## Installing from Source Code + +1. Download source code from gitee.com: + + ```bash + git clone https://gitee.com/mindspore/xai.git + ``` + +2. Install the dependency python modules: + + ```bash + cd xai + pip install -r requirements.txt + ``` + +3. Install the XAI module from source code: + + ```bash + python setup.py install + ``` + +4. 
Optionally, you may build a `.whl` package for installation without step 3: + + ```bash + bash package.sh + pip install output/mindspore_xai-{version}-py3-none-any.whl + ``` + +## Installation Verification + +Upon successful installation, importing 'mindspore_xai' module in Python will cause no error: + +```python +import mindspore_xai +print(mindspore_xai.__version__) +``` diff --git a/docs/xai/docs/source_en/mindspore_xai.benchmark.rst b/docs/xai/docs/source_en/mindspore_xai.benchmark.rst new file mode 100644 index 0000000000000000000000000000000000000000..cdc230be7fb5a8ddb6cf52a735b9e247c0db0dc3 --- /dev/null +++ b/docs/xai/docs/source_en/mindspore_xai.benchmark.rst @@ -0,0 +1,5 @@ +mindspore_xai.benchmark +======================= + +.. automodule:: mindspore_xai.benchmark + :members: diff --git a/docs/xai/docs/source_en/mindspore_xai.explainer.rst b/docs/xai/docs/source_en/mindspore_xai.explainer.rst new file mode 100644 index 0000000000000000000000000000000000000000..788e850f3ff56d21347009bc6a4c5f6c29f6f587 --- /dev/null +++ b/docs/xai/docs/source_en/mindspore_xai.explainer.rst @@ -0,0 +1,5 @@ +mindspore_xai.explainer +========================= + +.. automodule:: mindspore_xai.explainer + :members: diff --git a/docs/xai/docs/source_en/mindspore_xai.tool.rst b/docs/xai/docs/source_en/mindspore_xai.tool.rst new file mode 100644 index 0000000000000000000000000000000000000000..98d0533c178fc70e3c50cc6cc25952ea7a352b2c --- /dev/null +++ b/docs/xai/docs/source_en/mindspore_xai.tool.rst @@ -0,0 +1,5 @@ +mindspore_xai.tool +========================= + +.. automodule:: mindspore_xai.tool.cv + :members: diff --git a/docs/xai/docs/source_en/troubleshoot.md b/docs/xai/docs/source_en/troubleshoot.md new file mode 100644 index 0000000000000000000000000000000000000000..44455a80f148e0ca31a55f41a8fb72b65c567248 --- /dev/null +++ b/docs/xai/docs/source_en/troubleshoot.md @@ -0,0 +1,36 @@ +# Troubleshooting + + + +## Import Errors + +**Q: What can I do if libgomp `cannot allocate memory in static TLS block` error occurs when importing `mindspore_xai` or its subpackages?** + +A: You have to do the following steps: + +Reinstall scikit-learn 1.0.2: + +```bash +pip install --force-reinstall scikit-learn==1.0.2 +``` + +List all site-packages directories: + +```bash +python -m site +``` + +There is a list of directories in `sys.path` shown the previous, find the `scikit_learn.libs/` sub-directory in the directories of the list. +Once you located the `scikit_learn.libs/` directory, say which is underneath `/`, list files inside it: + +```bash +ls /scikit_learn.libs +``` + +There is a dynamical library libgomp inside with a filename like `libgomp-XXX.so.XXX`, append the absolute path to the environment variable `LD_PRELOAD`: + +```bash +export LD_PRELOAD=$LD_PRELOAD:/scikit_learn.libs/libgomp-XXX.so.XXX +``` + +Run your MindSpore XAI scripts again. diff --git a/docs/xai/docs/source_en/using_cv_benchmarks.md b/docs/xai/docs/source_en/using_cv_benchmarks.md new file mode 100644 index 0000000000000000000000000000000000000000..06a75025d0133de79c02fa1e22c8665a292893fd --- /dev/null +++ b/docs/xai/docs/source_en/using_cv_benchmarks.md @@ -0,0 +1,84 @@ +# Using CV Benchmarks + + + +## What are CV Benchmarks + +Benchmarks are algorithms evaluating the goodness of saliency maps from explainers. MindSpore XAI currently provides 4 benchmarks for image classification scenario: `Robustness`, `Faithfulness`, `ClassSensitivity` and `Localization`. 
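+
+The sections below give worked examples for `Robustness` and `Localization`; `Faithfulness` and `ClassSensitivity` follow the same `evaluate()` pattern. The sketch below only illustrates that pattern — it reuses the `num_classes`, `grad_cam` and `boat_image` objects prepared in the Preparations section that follows, and the constructor arguments are assumptions modelled on the `Robustness` API, so verify them against the API reference before use.
+
+```python
+from mindspore.nn import Softmax
+from mindspore_xai.benchmark import Faithfulness, ClassSensitivity
+
+# constructor arguments assumed to mirror Robustness(num_classes, activation_fn)
+faithfulness = Faithfulness(num_classes, activation_fn=Softmax())
+# ClassSensitivity is class agnostic, assumed to need no constructor arguments
+class_sensitivity = ClassSensitivity()
+
+# Faithfulness is class specific: pass the target class id, as with Robustness
+faithfulness_score = faithfulness.evaluate(grad_cam, boat_image, targets=3)
+
+# ClassSensitivity is class agnostic: no `targets` argument is passed
+class_sensitivity_score = class_sensitivity.evaluate(grad_cam, boat_image)
+```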
+ +## Preparations + +The complete code of the tutorial below is [using_cv_benchmarks.py](https://gitee.com/mindspore/xai/blob/r1.8/examples/using_cv_benchmarks.py). + +Please follow the [Downloading Data Package](https://www.mindspore.cn/xai/docs/en/r1.8/using_cv_explainers.html#downloading-data-package) instructions to download the necessary files for the tutorial. + +With the tutorial package, we have to get the sample image, trained classifier, explainer and optionally the saliency map ready: + +```python +# have to change the current directory to xai/examples/ first +import mindspore as ms +from mindspore_xai.explainer import GradCAM + +from common.resnet import resnet50 +from common.dataset import load_image_tensor + +# only PYNATIVE_MODE is supported +ms.set_context(mode=ms.PYNATIVE_MODE) + +# 20 classes +num_classes = 20 + +# load the trained classifier +net = resnet50(num_classes) +param_dict = ms.load_checkpoint("xai_examples_data/ckpt/resnet50.ckpt") +ms.load_param_into_net(net, param_dict) + +# [1, 3, 224, 224] Tensor +boat_image = load_image_tensor('xai_examples_data/test/boat.jpg') + +# explainer +grad_cam = GradCAM(net, layer='layer4') + +# 3 is the class id of 'boat' +saliency = grad_cam(boat_image, targets=3) +``` + +## Using Robustness + +`Robustness` is the simplest benchmark, it perturbs the inputs by adding random noise and outputs the maximum sensitivity as evaluation score from the perturbations: + +```python +from mindspore.nn import Softmax +from mindspore_xai.benchmark import Robustness + +# the classifier use Softmax as activation function +robustness = Robustness(num_classes, activation_fn=Softmax()) +# the 'saliency' argument is optional +score = robustness.evaluate(grad_cam, boat_image, targets=3, saliency=saliency) +``` + +The returned `score` is a 1D tensor with only one float value for an 1xCx224x224 image tensor. + +## Using Faithfulness and ClassSensitivity + +The ways of using `Faithfulness` and `ClassSensitivity` are very similar to `Robustness`. However, `ClassSensitivity` is class agnostic, `targets` can not be specified. + +## Using Localization + +If the object region or bounding box is provided, `Localization` can be used. It evaluates base on how many saliency pixels fall inside the object region: + +```python +import numpy as np +import mindspore as ms +from mindspore_xai.benchmark import Localization + +# top-left:80,66 bottom-right:223,196 is the bounding box of a boat +mask = np.zeros([1, 1, 224, 224]) +mask[:, :, 66:196, 80:223] = 1 + +mask = ms.Tensor(mask, dtype=ms.float32) + +localization = Localization(num_classes) + +score = localization.evaluate(grad_cam, boat_image, targets=3, mask=mask) +``` diff --git a/docs/xai/docs/source_en/using_cv_explainers.md b/docs/xai/docs/source_en/using_cv_explainers.md new file mode 100644 index 0000000000000000000000000000000000000000..41673c3b90f188ddb0b396e46d93b3050aa4df51 --- /dev/null +++ b/docs/xai/docs/source_en/using_cv_explainers.md @@ -0,0 +1,246 @@ +# Using CV Explainers + + + +## What are CV Explainers + +Explainers are algorithms explaining the decisions made by AI models. MindSpore XAI currently provides 7 explainers for image classification scenario. Saliency maps (or heatmaps) are the outputs, their brightness represents the importance of the corresponding regions on the original image. + +A saliency map overlay on top of the original image: + +![saliency_overlay](./images/saliency_overlay.png) + +There are 2 categories of explainers: gradient based and perturbation based. 
The gradient based explainers rely on the backpropagation method to compute the pixel importance while the perturbation based explainers exploit random perturbations on the original images. + +| Explainer | Category | +|:------------------:|:---------------:| +| Gradient | gradient | +| GradCAM | gradient | +| GuidedBackprop | gradient | +| Deconvolution | gradient | +| Occlusion | perturbation | +| RISE | perturbation | +| RISEPlus | perturbation | + +## Preparations + +### Downloading Data Package + +First of all, we have to download the data package and put it underneath the `xai/examples/` directory of a local XAI [source package](https://gitee.com/mindspore/xai): + +```bash +wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/xai_examples_data.tar.gz +tar -xf xai_examples_data.tar.gz + +git clone https://gitee.com/mindspore/xai +mv xai_examples_data xai/examples/ +``` + +`xai/examples/` files: + +```bash +xai/examples/ +├── xai_examples_data/ +│ ├── ckpt/ +│ │ ├── resent50.ckpt +│ ├── train/ +│ └── test/ +├── common/ +│ ├── dataset.py +│ └── resnet.py +├── using_cv_explainers.py +├── using_rise_plus.py +└── using_cv_benchmarks.py +``` + +- `xai_examples_data/`: The extracted data package. +- `xai_examples_data/ckpt/resent50.ckpt`: ResNet50 checkpoint file. +- `xai_examples_data/test`: Test dataset. +- `xai_examples_data/train`: Training dataset. +- `common/dataset.py`: Dataset loader. +- `common/resnet.py`: ResNet model definitions. +- `using_cv_explainers.py`: Example of using explainers. +- `using_rise_plus.py`: Example of using RISEPlus explainer. +- `using_cv_benchmarks.py`: Example of using benchmarks. + +### Preparing Python Environment + +The complete code of the tutorial below is [using_cv_explainers.py](https://gitee.com/mindspore/xai/blob/r1.8/examples/using_cv_explainers.py). + +In order to explain an image classification predication, we have to have a trained CNN network (`nn.Cell`) and an image to be examined: + +```python +# have to change the current directory to xai/examples/ first +import mindspore as ms +from common.resnet import resnet50 +from common.dataset import load_image_tensor + +# only PYNATIVE_MODE is supported +ms.set_context(mode=ms.PYNATIVE_MODE) + +# 20 classes +num_classes = 20 + +# load the trained classifier +net = resnet50(num_classes) +param_dict = ms.load_checkpoint("xai_examples_data/ckpt/resnet50.ckpt") +ms.load_param_into_net(net, param_dict) + +# [1, 3, 224, 224] Tensor +boat_image = load_image_tensor("xai_examples_data/test/boat.jpg") +``` + +## Using GradCAM + +`GradCAM` is a typical and effective gradient based explainer: + +```python +from PIL import Image +import mindspore as ms +from mindspore_xai.explainer import GradCAM +from mindspore_xai.visual.cv import saliency_to_image + +# usually specify the last convolutional layer +grad_cam = GradCAM(net, layer="layer4") + +# 3 is the class id of 'boat' +saliency = grad_cam(boat_image, targets=3, show=False) + +# convert the saliency map to a PIL.Image.Image object +boat_img = Image.open("xai_examples_data/test/boat.jpg") +saliency_to_image(saliency, boat_img) +``` + +The returned `saliency` is a 1x1x224x224 tensor for an 1xCx224x224 image tensor, which stores all pixel importances (range:[0.0, 1.0]) to the classification decision of 'boat'. Users may specify any class to be explained. + +![grad_cam_saliency](./images/grad_cam_saliency.png) + +### Batch Explanation + +For gradient based explainers, batch explanation is usually more efficient. 
Other explainers may also batch the evaluations: + +```python +from common.dataset import load_dataset + +test_ds = load_dataset('xai_examples_data/test').batch(4) + +for images, labels in test_ds: + saliencies = grad_cam(images, targets=ms.Tensor([3, 3, 3, 3], dtype=ms.int32)) + # other custom operations ... +``` + +The returned `saliencies` is a 4x1x224x224 tensor for a 4xCx224x224 batched image tensor. + +### Using Other Explainers + +The ways of using other explainers are very similar to `GradCAM`, except `RISEPlus`. + +## Using RISEPlus + +The complete code of the tutorial below is [using_rise_plus.py](https://gitee.com/mindspore/xai/blob/r1.8/examples/using_rise_plus.py). + +`RISEPlus` is based on `RISE` with an introduction of Out-of-Distribution(OoD) detector. It solves the degeneration problem of `RISE` on samples that the classifier had never seem the similar in training. + +First, we need to train an OoD detector(`OoDNet`) with the classifier training dataset: + +```python +# have to change the current directory to xai/examples/ first +import mindspore as ms +from mindspore.nn import Softmax, SoftmaxCrossEntropyWithLogits +from mindspore_xai.tool.cv import OoDNet +from mindspore_xai.explainer import RISEPlus +from common.dataset import load_dataset +from common.resnet import resnet50 + +# only PYNATIVE_MODE is supported +ms.set_context(mode=ms.PYNATIVE_MODE) + +num_classes = 20 + +# classifier training dataset +train_ds = load_dataset('xai_examples_data/train').batch(4) + +# load the trained classifier +net = resnet50(num_classes) +param_dict = ms.load_checkpoint('xai_examples_data/ckpt/resnet50.ckpt') +ms.load_param_into_net(net, param_dict) + +ood_net = OoDNet(underlying=net, num_classes=num_classes) + +# use SoftmaxCrossEntropyWithLogits as loss function if the activation function of +# the classifier is Softmax, use BCEWithLogitsLoss if the activation function is Sigmoid +ood_net.train(train_ds, loss_fn=SoftmaxCrossEntropyWithLogits()) + +ms.save_checkpoint(ood_net, 'ood_net.ckpt') +``` + +The classifier for `OoDNet` must be a subclass of `nn.Cell`, in `__init__()` which must: + +- defines an `int` member attribute named `num_features` as the number of feature values to be returned by the feature layer. + +- defines a `bool` member attribute named `output_features` with `False` as initial value, OoDNet tells the classifier to return the feature tensor in `construct()` by setting `output_features` to `True`. + +A LeNet5 example of underlying classifier: + +```python +from mindspore import nn +from mindspore.common.initializer import Normal + +class MyLeNet5(nn.Cell): + + def __init__(self, num_class, num_channel): + super(MyLeNet5, self).__init__() + + # must add the following 2 attributes to your model + self.num_features = 84 # no. 
of features, int + self.output_features = False # output features flag, bool + + self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) + self.fc2 = nn.Dense(120, self.num_features, weight_init=Normal(0.02)) + self.fc3 = nn.Dense(self.num_features, num_class, weight_init=Normal(0.02)) + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + + # return the features tensor if output_features is True + if self.output_features: + return x + + x = self.fc3(x) + return x +``` + +Now we can use `RISEPlus` with the trained `OoDNet`: + +```python +from PIL import Image +from mindspore_xai.visual.cv import saliency_to_image + +# create a new classifier as the underlying when loading OoDNet from a checkpoint +ood_net = OoDNet(underlying=resnet50(num_classes), num_classes=num_classes) +param_dict = ms.load_checkpoint('ood_net.ckpt') +ms.load_param_into_net(ood_net, param_dict) + +rise_plus = RISEPlus(ood_net=ood_net, network=net, activation_fn=Softmax()) +saliency = rise_plus(boat_image, targets=3, show=False) + +boat_img = Image.open("xai_examples_data/test/boat.jpg") +saliency_to_image(saliency, boat_img) +``` + +The returned `saliency` is an 1x1x224x224 tensor for an 1xCx224x224 image tensor. + +![rise_plus_saliency](./images/rise_plus_saliency.png) diff --git a/docs/xai/docs/source_en/using_tabsim.md b/docs/xai/docs/source_en/using_tabsim.md new file mode 100644 index 0000000000000000000000000000000000000000..b41b3bb8d59a16cbfecacaac50d137fbd9a98d6b --- /dev/null +++ b/docs/xai/docs/source_en/using_tabsim.md @@ -0,0 +1,176 @@ +# Using TabSim Data Simulator + + + +## Introduction + +Sometimes, it is impossible to get sufficient amount of data for modeling or algorithm development, TabSim can be used for capturing the distribution of the tabular data and generating simulated data afterward. + +The complete code of the tutorial below is [using_tabsim.py](https://gitee.com/mindspore/xai/blob/r1.8/examples/using_tabsim.py). + +## Installation + +TabSim is part of the XAI package, no extra installation is required besides [MindSpore](https://mindspore.cn/install/en) and [XAI](https://www.mindspore.cn/xai/docs/en/r1.8/installation.html). + +## User Flow + +There are 2 phases in the TabSim user flow: + +1. Digestion: Analyzing the real tabular data, capturing the statistic characteristics and output the digest file. Accomplished by the commandline tool `mindspore_xai tabdig`. +2. Simulation: Generating simulated data according to the statistics stores in the digested file. Accomplished by the commandline tool `mindspore_xai tabsim`. + +## Digestion Phase + +```bash +mindspore_xai tabdig [--bins ] [--clip-sd ] +``` + +``: Path of the real CSV table to be simulated. + +``: Path of the digest file to be saved. + +``: [optional] Number of bins [2 - 32] for discretizing numeric columns, default: 10 + +``: [optional] Number of standard deviations away from the mean that defines the outliers, outlier values +will be clipped. default: 3, setting to 0 or less will disable the value clipping. 
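+
+For example, assuming the `real_table.csv` and `digest.json` file names used in the Digestion Example below, a digestion run that also sets both optional flags might look like this (the flag values are illustrative only):
+
+```bash
+mindspore_xai tabdig real_table.csv digest.json --bins 16 --clip-sd 3
+```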
+ +### File Format of Real Data + +The real data must be a CSV file with header that contains the names and type of all columns. + +Header pattern: `|,|,|,...` + +``: Column name, allowed pattern: `[0-9a-zA-Z_\-]+` + +``: Column type, options: 'int', 'float', 'str', 'cat' + +- 'int': Integers +- 'float': Float numbers +- 'str': Strings, allowed pattern: `[0-9a-zA-Z_\-\+\.]*` +- 'cat': Catergorical values, the underlying data type is integer without order + +'int' and 'float' are numeric, while 'str' and 'cat' are discrete columns. There are at most 256 distinct values allowed +in each discrete column. + +Optionally, users may specify at most one label column by adding a '*' before a discrete column (numeric columns are not +allowed). + +Header example: `col_A|int,col_B|float,col_C|str,*col_D|cat` + +It is recommended users randomly pick around 1 million records from the real database to form the real data file for the +statistic accuracy and the memory constraints. + +### File Format of Digest Files + +Digest files are clear text json, they stores the name, type and value distributions of columns without any actual record. +Users should never modify the digest file manually; otherwise that may corrupt it. + +### Digestion Example + +We used the [Iris](https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html) dataset for the +demonstration. This dataset consists of 3 different types of irises’ petal and sepal lengths. The Python code below +writes tabular data to `real_table.csv`. + +```python +import sklearn.datasets + +iris = sklearn.datasets.load_iris() +features = iris.data +labels = iris.target +# save the tabular data to file +header = 'sepal_length|float,sepal_width|float,petal_length|float,petal_width|float,*class|cat' +with open('real_table.csv', 'w') as f: + f.write(header + '\n') + for i in range(len(labels)): + for feat in features[i]: + f.write("{},".format(feat)) + f.write("{}\n".format(labels[i])) +``` + +Content of `real_table.csv`: + +```text +sepal_length|float,sepal_width|float,petal_length|float,petal_width|float,*class|cat +5.1,3.5,1.4,0.2,0 +4.9,3.0,1.4,0.2,0 +4.7,3.2,1.3,0.2,0 +4.6,3.1,1.5,0.2,0 +5.0,3.6,1.4,0.2,0 +5.4,3.9,1.7,0.4,0 +4.6,3.4,1.4,0.3,0 +5.0,3.4,1.5,0.2,0 +4.4,2.9,1.4,0.2,0 +... +``` + +Then, we analyze the real tabular data, capturing the statistic characteristics and output it to `digest.json`. + +```bash +mindspore_xai tabdig real_table.csv digest.json +``` + +Content of `digest.json`: + +```json +{ + "label_col_idx": 4, + "columns": [ + { + "name": "sepal_length", + "idx": 0, + "ctype": "float", + "dtype": "float", + "is_numeric": true, + "is_label": false, + ... +``` + +## Simulation Phase + +```bash +mindspore_xai tabsim [--batch-size ] [--noise ] +``` + +``: Path of the digest file of the real data. + +``: Path of the simulated CSV table. + +``: Number of rows to be generated to ``. + +``: [optional] Number of rows in each batch, default: 10000 + +``: [optional] Noise level (0.0-1.0) of value picking probabilities, 0.0 means 100% follows the digested joint +distributions, higher the noise level more even the probabilities. default: 0 + +### File Format of Simulated Data + +The simulated CSV file has a similar format to the real data CSV file, but the header is different: + +`,,,...` + +It contains no `` and no `*` while the column order remains the same. 
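+
+For instance, the Iris header used in the Digestion Example above maps to the simulated header as follows — the `|<type>` suffixes and the `*` label marker are dropped while the column order is preserved:
+
+```text
+real data header:      sepal_length|float,sepal_width|float,petal_length|float,petal_width|float,*class|cat
+simulated data header: sepal_length,sepal_width,petal_length,petal_width,class
+```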
+ +### Simulation Example + +Here we generate 200000 rows of simulated data according to the statistics stored in the digested file, the simulated +data is output to `sim_table.csv`. + +```bash +mindspore_xai tabsim digest.json sim_table.csv 200000 +``` + +Content of `sim_table.csv`: + +```text +sepal_length,sepal_width,petal_length,petal_width,class +5.577184113278916,2.600922272560204,4.432243573999988,1.3937476921377445,1 +6.723739024436704,2.7995789972671985,4.093195099230183,1.377081159510022,1 +4.787110003892638,2.8994714750972608,1.221068662892122,0.18023497892950327,0 +5.47601589088659,2.683719381022501,4.429520567795243,1.44376166769605,1 +5.713634033969561,2.238437659593092,4.468051986603512,1.5218876291352155,1 +6.014412107785783,2.921972441210267,4.066770696930024,0.9183809029577147,1 +6.188386742135447,2.92122446931648,5.288862927543273,1.4537708701756062,2 +7.394485586937094,2.867479423550221,5.730391070749579,1.998759192383688,2 +5.468839597899383,2.8957462954323083,4.4090170094158525,1.502682955942951,1 +... +``` diff --git a/docs/xai/docs/source_en/using_tabular_explainers.md b/docs/xai/docs/source_en/using_tabular_explainers.md new file mode 100644 index 0000000000000000000000000000000000000000..af60703c60d6f313be509e3f9a99fc2d77d948d6 --- /dev/null +++ b/docs/xai/docs/source_en/using_tabular_explainers.md @@ -0,0 +1,183 @@ +# Using Tabular Explainers + + + +## Introduction + +In this tutorial we explain the tabular data classification using 3 different explainers, including `LIMETabular`, +`SHAPKernel`, and `SHAPGradient`. + +All explainers support `PYNATIVE_MODE`. All explainers except `SHAPGradient` support `GRAPH_MODE`. + +| Explainer | PYNATIVE_MODE | GRAPH_MODE | +|:------------:|:-------------------:|:------------------:| +| LIMETabular | Supported | Supported | +| SHAPKernel | Supported | Supported | +| SHAPGradient | Supported | | + +The complete code of the tutorial below is [using_tabular_explainers.py](https://gitee.com/mindspore/xai/blob/r1.8/examples/using_tabular_explainers.py). + +## Import Dataset + +We use the [Iris](https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html) dataset for the demonstration. +These data sets consist of 3 different types of irises’ petal and sepal lengths. + +```python +import sklearn.datasets +import mindspore as ms + +iris = sklearn.datasets.load_iris() + +# feature_names: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'] +feature_names = iris.feature_names +# class_names: ['setosa', 'versicolor', 'virginica'] +class_names = list(iris.target_names) + +# convert data and labels from numpy array to mindspore tensor +# use the first 100 samples +data = ms.Tensor(iris.data, ms.float32)[:100] +labels = ms.Tensor(iris.target, ms.int32)[:100] + +# explain the first sample +inputs = data[:1] +# explain the label 'setosa'(class index 0) +targets = 0 +``` + +## Import Model + +Here we define a simple linear classifier. 
+ +```python +import numpy as np +import mindspore.nn as nn + + +class LinearNet(nn.Cell): + def __init__(self): + super(LinearNet, self).__init__() + # input features: 4 + # output classes: 3 + self.linear = nn.Dense(4, 3, activation=nn.Softmax()) + + def construct(self, x): + x = self.linear(x) + return x + + +net = LinearNet() + +# load pre-trained parameters +weight = np.array([[0.648, 1.440, -2.05, -0.977], [0.507, -0.276, -0.028, -0.626], [-1.125, -1.183, 2.099, 1.605]]) +bias = np.array([0.308, 0.343, -0.652]) +net.linear.weight.set_data(ms.Tensor(weight, ms.float32)) +net.linear.bias.set_data(ms.Tensor(bias, ms.float32)) +``` + +## Using LIMETabular + +`LIMETabular` approximates the machine learning model with a local, interpretable model to explain each individual +prediction. + +```python +from mindspore_xai.explainer import LIMETabular + +# convert features to feature stats +feature_stats = LIMETabular.to_feat_stats(data, feature_names=feature_names) +# initialize the explainer +lime = LIMETabular(net, feature_stats, feature_names=feature_names, class_names=class_names) +# explain +lime_outputs = lime(inputs, targets, show=True) +print("LIMETabular:") +for i, exps in enumerate(lime_outputs): + for exp in exps: + print("Explanation for sample {} class {}:".format(i, class_names[targets])) + print(exp, '\n') +``` + +output: +> LIMETabular: +> +> Explanation for sample 0 class setosa: +> +> [('petal length (cm) <= 1.60', 0.8182714590301656), +> ('sepal width (cm) > 3.30', 0.0816516722404966), ('petal width (cm) <= 0.30', 0.03557190104069489), +> ('sepal length (cm) <= 5.10', -0.021441399016492325)] + +![lime_tabular](./images/lime_tabular.png) + +`LIMETabular` also supports a callable function, for example: + +```python +def predict_fn(x): + return net(x) + + +# initialize the explainer +lime = LIMETabular(predict_fn, feature_stats, feature_names=feature_names, class_names=class_names) +``` + +## Using SHAPKernel + +`SHAPKernel` is a method that uses a special weighted linear regression to compute the importance of each feature. + +```python +from mindspore_xai.explainer import SHAPKernel + +# initialize the explainer +shap_kernel = SHAPKernel(net, data, feature_names=feature_names, class_names=class_names) +# explain +shap_kernel_outputs = shap_kernel(inputs, targets, show=True) +print("SHAPKernel:") +for i, exps in enumerate(shap_kernel_outputs): + for exp in exps: + print("Explanation for sample {} class {}:".format(i, class_names[targets])) + print(exp, '\n') +``` + +output: +> SHAPKernel: +> +> Explanation for sample 0 class setosa: +> +> [-0.00403276 0.03651359 0.59952676 0.01399141] + +![shap_kernel](./images/shap_kernel.png) + +`SHAPKernel` also supports a callable function, for example: + +```python +# initialize the explainer +shap_kernel = SHAPKernel(predict_fn, data, feature_names=feature_names, class_names=class_names) +``` + +## Using SHAPGradient + +`SHAPGradient` explains a model using expected gradients (an extension of integrated gradients). + +```python +from mindspore_xai.explainer import SHAPGradient +import mindspore as ms + +# Gradient only works under PYNATIVE_MODE. 
ms.set_context(mode=ms.PYNATIVE_MODE)
# initialize the explainer
shap_gradient = SHAPGradient(net, data, feature_names=feature_names, class_names=class_names)
# explain
shap_gradient_outputs = shap_gradient(inputs, targets, show=True)
print("SHAPGradient:")
for i, exps in enumerate(shap_gradient_outputs):
    for exp in exps:
        print("Explanation for sample {} class {}:".format(i, class_names[targets]))
        print(exp, '\n')
```

output:

> SHAPGradient:
>
> Explanation for sample 0 class setosa:
>
> [-0.0112452 0.08389313 0.47006473 0.0373782 ]

![shap_gradient](./images/shap_gradient.png)

diff --git a/docs/xai/docs/source_en/using_tbnet.md b/docs/xai/docs/source_en/using_tbnet.md
new file mode 100644
index 0000000000000000000000000000000000000000..abf6abaec9614decac48f04594bc9ab73cec39b6
--- /dev/null
+++ b/docs/xai/docs/source_en/using_tbnet.md
@@ -0,0 +1,287 @@
# Using TB-Net Whitebox Recommendation Model

## What is TB-Net

TB-Net is a white-box recommendation model. It constructs subgraphs in a knowledge graph from the interactions between users and items as well as the item features, and then computes the paths in those subgraphs with a bidirectional conduction algorithm. This yields explainable recommendation results.

Paper: Shendi Wang, Haoyang Li, Xiao-Hui Li, Caleb Chen Cao, Lei Chen. Tower Bridge Net (TB-Net): Bidirectional Knowledge Graph Aware Embedding Propagation for Explainable Recommender Systems

## Preparations

### Downloading Data Package

First, download the data package and put it under the `models/whitebox/tbnet` directory of a local XAI [source package](https://gitee.com/mindspore/xai):

```bash
wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
tar -xf tbnet_data.tar.gz

git clone https://gitee.com/mindspore/xai.git
mv data xai/models/whitebox/tbnet
```

`xai/models/whitebox/tbnet/` files:

```bash
.
└─tbnet
  ├─README.md
  ├─README_CN.md
  ├─data
  │ └─steam                 # Steam user purchase history dataset
  │   ├─LICENSE
  │   ├─config.json         # hyper-parameters and training configuration
  │   ├─src_infer.csv       # source datafile for inference
  │   ├─src_test.csv        # source datafile for evaluation
  │   └─src_train.csv       # source datafile for training
  ├─src
  │ ├─dataset.py            # dataset loader
  │ ├─embedding.py          # embedding module
  │ ├─metrics.py            # model metrics
  │ ├─path_gen.py           # data preprocessor
  │ ├─recommend.py          # result aggregator
  │ └─tbnet.py              # TB-Net architecture
  ├─export.py               # export MINDIR/AIR script
  ├─preprocess.py           # data pre-processing script
  ├─eval.py                 # evaluation script
  ├─infer.py                # inference and explaining script
  ├─train.py                # training script
  └─tbnet_config.py         # configuration reader
```

### Preparing Python Environment

TB-Net is part of the XAI package, so no extra installation is required besides [MindSpore](https://mindspore.cn/install/en) and [XAI](https://www.mindspore.cn/xai/docs/en/r1.8/installation.html). GPUs are supported.

## Data Pre-processing

The complete example code of this step is [preprocess.py](https://gitee.com/mindspore/xai/blob/r1.8/models/whitebox/tbnet/preprocess.py).

Before training TB-Net, we have to convert the source datafiles to relation path data.
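For orientation, here is a small, made-up source datafile that follows the format described in the next subsection (the user/item IDs and attribute values below are invented purely for illustration and are not part of the shipped steam dataset):

```text
user,item,rating,developer,genre,category,release_year
u001,i1001,p,dev_a,action;indie,single_player,2016
u001,i1002,x,dev_b,strategy,,2018
u001,i1003,x,dev_c,casual;puzzle,multi_player,2017
u002,i1001,p,dev_a,action;indie,single_player,2016
```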

### Source Datafile Format

The source datafiles of the steam dataset all share the same CSV format with headers:

`user,item,rating,developer,genre,category,release_year`

The first 3 columns must be present, in this exact order, with the following meanings:

- `user`: String, the user ID. Records of the same user must occupy consecutive rows in a single file; splitting a user's records across different files will give misleading results.
- `item`: String, the item ID.
- `rating`: Character, either `c` (the user had interactions with the item, e.g. clicked it, but did not purchase it), `p` (the user purchased the item) or `x` (other items).

(Remark: There is no `c` rating item in the steam dataset.)

Since the order and meaning of these columns are fixed, their names do not matter; users may choose other names like `uid,iid,act`, etc.

The later columns `developer,genre,category,release_year` hold the item's string attribute IDs. Users should decide the column names (i.e. the relation names) and keep them consistent across all source datafiles. There must be at least one attribute column, and there is no upper limit on the number of attribute columns. An attribute may hold more than one value; multiple values should be separated by `;`. Leaving an attribute blank means the item has no such attribute.

The content of the source datafiles differs slightly depending on their purpose:

- `src_train.csv`: For training. The numbers of rows with `p` rating and with `c` + `x` rating should be made roughly equal by re-sampling; there is no need to list all items for every user.
- `src_test.csv`: For evaluation. Very similar to `src_train.csv` but with a smaller amount of data.
- `src_infer.csv`: For inference. Must contain data of ONLY ONE user, and ALL `c`, `p` and `x` rating items should be listed. In [preprocess.py](https://gitee.com/mindspore/xai/blob/r1.8/models/whitebox/tbnet/preprocess.py), only the `c` and `x` items are put into the path data as recommendation candidates.

### Converting to Relation Path Data

```python
import io
import json
from src.path_gen import PathGen

path_gen = PathGen(per_item_paths=39)

path_gen.generate("./data/steam/src_train.csv", "./data/steam/train.csv")

# save the id maps for later use by Recommender during inference
with io.open("./data/steam/id_maps.json", mode="w", encoding="utf-8") as f:
    json.dump(path_gen.id_maps(), f, indent=4)

# treat items and references first encountered in src_test.csv and src_infer.csv as unseen entities,
# the dummy internal id 0 will be assigned to them
path_gen.grow_id_maps = False

path_gen.generate("./data/steam/src_test.csv", "./data/steam/test.csv")

# for inference, only take the interacted ('c') and other ('x') items as candidate items,
# the purchased ('p') items won't be recommended.
# assume there is only one user in src_infer.csv
path_gen.subject_ratings = "cx"

path_gen.generate("./data/steam/src_infer.csv", "./data/steam/infer.csv")
```

`PathGen` is responsible for converting a source datafile into relation path data.

### Relation Path Data Format

Relation path data are header-less CSV files (all integer values), with columns:

`,