From 3ed8f40eb8a0dbb072d6a3009d51962af4c35a60 Mon Sep 17 00:00:00 2001 From: huxiaobo Date: Sat, 17 Aug 2024 10:05:50 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=8A=A0=E5=85=A5MultiData=E6=A8=A1?= =?UTF-8?q?=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../visualization/compare/graph_comparator.py | 18 ++--- .../visualization/compare/mode_adapter.py | 12 ++-- .../pytorch/visualization/graph/base_node.py | 71 ++++++++++++++----- .../pytorch/visualization/graph/multi_data.py | 58 +++++++++++++++ .../msprobe/pytorch/visualization/utils.py | 4 ++ 5 files changed, 131 insertions(+), 32 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/pytorch/visualization/graph/multi_data.py diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py index 2465cf1ae0..576d02d912 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py @@ -52,12 +52,12 @@ class GraphComparator: compare_in_dict[item[0]] = item precision_status, precision_index, other_dict = ( self.ma.parse_result(node, [compare_in_dict, compare_out_dict])) - node.data[GraphConst.JSON_STATUS_KEY] = precision_status - node.data[GraphConst.JSON_INDEX_KEY] = precision_index - node.data.update(other_dict) + node.update_data({GraphConst.JSON_STATUS_KEY: precision_status}) + node.update_data({GraphConst.JSON_INDEX_KEY: precision_index}) + node.update_data(other_dict) if not precision_status: - self.ma.add_error_key(node.output_data) - node.get_suggestions() + self.ma.add_error_key(node.get_output_data()) + node.add_suggestions() def _parse_param(self, data_paths, stack_path, output_path): self.dump_path_param = { @@ -81,11 +81,11 @@ class GraphComparator: compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()} for node in self.ma.compare_nodes: precision_status, precision_index, _ = self.ma.parse_result(node, [compare_data_dict]) - node.data[GraphConst.JSON_STATUS_KEY] = precision_status - node.data[GraphConst.JSON_INDEX_KEY] = precision_index + node.update_data({GraphConst.JSON_STATUS_KEY: precision_status}) + node.update_data({GraphConst.JSON_INDEX_KEY: precision_index}) if not precision_status: - self.ma.add_error_key(node.output_data) - node.get_suggestions() + self.ma.add_error_key(node.get_output_data) # 留着,应该被ut查出来 + node.add_suggestions() def _compare_nodes(self, node_n): #递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比 diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py index b3f286c28f..d3d181286e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py @@ -119,21 +119,21 @@ class ModeAdapter: """ other_dict = {} if self.is_md5_compare(): - precision_status_in = ModeAdapter._add_md5_compare_data(node.input_data, compare_data_dict[0]) - precision_status_out = ModeAdapter._add_md5_compare_data(node.output_data, compare_data_dict[1]) + precision_status_in = ModeAdapter._add_md5_compare_data(node.get_input_data(), compare_data_dict[0]) + precision_status_out = ModeAdapter._add_md5_compare_data(node.get_output_data(), compare_data_dict[1]) # 所有输入输出md5对比通过,这个节点才算通过 precision_status = precision_status_in and precision_status_out precision_index = 1 if precision_status else 0 other_result = CompareConst.PASS if precision_status else CompareConst.DIFF other_dict[CompareConst.RESULT] = other_result elif self.is_summary_compare(): - precision_status_in, precision_index_in = ModeAdapter._add_summary_compare_data(node.input_data, compare_data_dict[0]) - precision_status_out, precision_index_out = ModeAdapter._add_summary_compare_data(node.output_data, compare_data_dict[1]) + precision_status_in, precision_index_in = ModeAdapter._add_summary_compare_data(node.get_input_data(), compare_data_dict[0]) + precision_status_out, precision_index_out = ModeAdapter._add_summary_compare_data(node.get_output_data(), compare_data_dict[1]) precision_status = precision_status_in and precision_status_out precision_index = min(precision_index_in, precision_index_out) else: - min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict[0]) - min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict[0]) + min_thousandth_in = ModeAdapter._add_real_compare_data(node.get_input_data(), compare_data_dict[0]) + min_thousandth_out = ModeAdapter._add_real_compare_data(node.get_output_data(), compare_data_dict[0]) if min_thousandth_in and min_thousandth_out: change_percentage = abs(min_thousandth_in - min_thousandth_out) else: diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py index e8c86e243e..24264dd41d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py @@ -14,6 +14,7 @@ # limitations under the License. from msprobe.pytorch.visualization.graph.node_op import NodeOp +from msprobe.pytorch.visualization.graph.multi_data import MultiData from msprobe.pytorch.visualization.utils import Suggestions, GraphConst from msprobe.pytorch.visualization.builder.msprobe_adapter import format_node_data, compare_data @@ -22,14 +23,11 @@ class BaseNode: def __init__(self, node_op, node_id, up_node=None): self.op = node_op self.id = node_id - self.data = {} - self.output_data = {} - self.input_data = {} self.upnode = None self.add_upnode(up_node) self.subnodes = [] self.matched_node_link = [] - self.suggestions = {} + self.multi_data = MultiData() def __str__(self): info = f'id:\t{self.id}' @@ -39,26 +37,28 @@ class BaseNode: """ 用来判断两个节点是否可以被匹配上,认为结构上是否一致 """ - if not compare_data(self.input_data, other.input_data): + if not compare_data(self.get_input_data(), other.get_input_data()): return False - if not compare_data(self.output_data, other.output_data): + if not compare_data(self.get_output_data(), other.get_output_data()): return False return True - def get_suggestions(self): + def add_suggestions(self): """ 精度疑似有问题时,提供一些建议 """ + suggestions = self.get_suggestions() + # 这里应该有写入测试 if self.op == NodeOp.module: - self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.Module - self.suggestions[Suggestions.DUMP] = Suggestions.DUMP_URL + suggestions[GraphConst.SUGGEST_KEY] = Suggestions.Module + suggestions[Suggestions.PTDBG] = Suggestions.PTDBG_URL elif self.op == NodeOp.function_api: - self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API - self.suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL + suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API + suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL def set_input_output(self, input_data, output_data): - self.input_data = input_data - self.output_data = output_data + self.set_input_data(input_data) + self.set_output_data(output_data) def add_upnode(self, node): """ @@ -86,13 +86,15 @@ class BaseNode: result = {} result['id'] = self.id result['node_type'] = self.op.value - result['data'] = self.data - result['output_data'] = format_node_data(self.output_data) - result['input_data'] = format_node_data(self.input_data) + result['data'] = self.get_data_dict() + output_data = self.get_output_data_dict() + result['output_data'] = [format_node_data(data) for data in output_data] + input_data = self.get_input_data_dict() + result['input_data'] = [format_node_data(data) for data in input_data] result['upnode'] = self.upnode.id if self.upnode else 'None' result['subnodes'] = [node.id for node in self.subnodes] result['matched_node_link'] = self.matched_node_link - result['suggestions'] = self.suggestions + result['suggestions'] = self.get_suggestions_dict() return result def get_ancestors(self): @@ -105,3 +107,38 @@ class BaseNode: ancestors.append(current_node.id) current_node = current_node.upnode return list(reversed(ancestors)) + + # 提供封装好的四个数据获取,方便外部调用 + def get_input_data(self): + return self.multi_data.get_data(GraphConst.MULTI_INPUT_KEY) + + def get_output_data(self): + return self.multi_data.get_data(GraphConst.MULTI_OUTPUT_KEY) + + def get_suggestions(self): + return self.multi_data.get_data(GraphConst.MULTI_SUGGEST_KEY) + + def get_data(self): + return self.multi_data.get_data(GraphConst.MULTI_DATA_KEY) + + def set_input_data(self, value): + return self.multi_data.set_data(GraphConst.MULTI_INPUT_KEY, value) + + def set_output_data(self, value): + return self.multi_data.set_data(GraphConst.MULTI_OUTPUT_KEY, value) + + def update_data(self, new_dict): + data = self.get_data() + data.update(new_dict) + + def get_input_data_dict(self): + return self.multi_data.to_dict(GraphConst.MULTI_INPUT_KEY) + + def get_output_data_dict(self): + return self.multi_data.to_dict(GraphConst.MULTI_OUTPUT_KEY) + + def get_data_dict(self): + return self.multi_data.to_dict(GraphConst.MULTI_DATA_KEY) + + def get_suggestions_dict(self): + return self.multi_data.to_dict(GraphConst.MULTI_SUGGEST_KEY) diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/multi_data.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/multi_data.py new file mode 100644 index 0000000000..0b628c3332 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/multi_data.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.pytorch.visualization.utils import GraphConst + + +class MultiData: + def __init__(self): + self._data = [] + self.set_repeat_time(1) + + def set_repeat_time(self, repeat_time): + if repeat_time <= 0: + return + self._repeat_time = repeat_time + self._data.clear() + self._data = [{ + GraphConst.MULTI_DATA_KEY: {}, + GraphConst.MULTI_OUTPUT_KEY: {}, + GraphConst.MULTI_INPUT_KEY: {}, + GraphConst.MULTI_SUGGEST_KEY: {}, + } for _ in range(repeat_time)] + self.set_index(0) + + def get_repeat_time(self): + return self._repeat_time + + def set_index(self, index): + if index < 0 or index >= self._repeat_time: + raise Exception('The index is out of range.') + self._index = index + + def get_data(self, key): + if key not in self._data[self._index]: + raise Exception(f'{key} is not in the key of MultiData') + return self._data[self._index].get(key, {}) + + def set_data(self, key, value): + if key not in self._data[self._index]: + raise Exception(f'{key} is not in the key of MultiData') + self._data[self._index][key] = value + + def to_dict(self, key): + if key not in self._data[self._index]: + raise Exception(f'{key} is not in the key of MultiData') + return [data[key] for data in self._data] diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py index a1b72ee850..6d607804e3 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py @@ -127,3 +127,7 @@ class GraphConst: SUMMARY_INDEX_LIST = [CompareConst.MAX_DIFF, CompareConst.MIN_DIFF, CompareConst.MEAN_DIFF, CompareConst.NORM_DIFF, CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR] + MULTI_DATA_KEY = 'data' + MULTI_OUTPUT_KEY = 'output_data' + MULTI_INPUT_KEY = 'input_data' + MULTI_SUGGEST_KEY = 'suggestions' -- Gitee From f98f25d1ce24eb879f923ef57b2066474747d7f2 Mon Sep 17 00:00:00 2001 From: huxiaobo Date: Sat, 17 Aug 2024 12:54:10 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= =?UTF-8?q?=E6=9B=B4=E6=96=B0=E5=92=8Cbugfix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../visualization/compare/graph_comparator.py | 2 +- .../pytorch/visualization/graph/base_node.py | 3 +- .../builder/test_graph_builder.py | 4 +- .../compare/test_graph_comparator.py | 16 +++++ .../compare/test_mode_adapter.py | 3 +- .../visualization/graph/test_base_node.py | 16 ++--- .../visualization/graph/test_graph.py | 6 +- .../visualization/graph/test_multi_data.py | 61 +++++++++++++++++++ 8 files changed, 94 insertions(+), 17 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_multi_data.py diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py index 576d02d912..d275416d10 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py @@ -84,7 +84,7 @@ class GraphComparator: node.update_data({GraphConst.JSON_STATUS_KEY: precision_status}) node.update_data({GraphConst.JSON_INDEX_KEY: precision_index}) if not precision_status: - self.ma.add_error_key(node.get_output_data) # 留着,应该被ut查出来 + self.ma.add_error_key(node.get_output_data()) node.add_suggestions() def _compare_nodes(self, node_n): diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py index 24264dd41d..76de28a622 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py @@ -48,10 +48,9 @@ class BaseNode: 精度疑似有问题时,提供一些建议 """ suggestions = self.get_suggestions() - # 这里应该有写入测试 if self.op == NodeOp.module: suggestions[GraphConst.SUGGEST_KEY] = Suggestions.Module - suggestions[Suggestions.PTDBG] = Suggestions.PTDBG_URL + suggestions[Suggestions.DUMP] = Suggestions.DUMP_URL elif self.op == NodeOp.function_api: suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py index 66eceea4b2..4cc6277048 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py @@ -48,5 +48,5 @@ class TestGraphBuilder(unittest.TestCase): data_dict = {"node1": {}} node = GraphBuilder._create_or_get_node(self.graph, data_dict, node_op, "node1") self.assertIn("node1", self.graph.node_map) - self.assertEqual(node.input_data, {}) - self.assertEqual(node.output_data, {}) + self.assertEqual(node.get_input_data(), {}) + self.assertEqual(node.get_output_data(), {}) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py index bece5380f0..80f9e742b1 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py @@ -2,6 +2,8 @@ import unittest from unittest.mock import patch from msprobe.pytorch.visualization.compare.graph_comparator import GraphComparator from msprobe.pytorch.visualization.graph.graph import Graph +from msprobe.pytorch.visualization.graph.base_node import BaseNode +from msprobe.pytorch.visualization.graph.node_op import NodeOp from msprobe.pytorch.visualization.utils import GraphConst @@ -30,3 +32,17 @@ class TestGraphComparator(unittest.TestCase): 'is_print_compare_log': True }) self.assertEqual(self.comparator.output_path, self.output_path) + + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test_bad_precision_status(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + node = BaseNode(NodeOp.function_api, 'test_api') + self.comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + self.comparator.ma.add_error_key(node.get_output_data()) + node.add_suggestions() + suggestions = node.get_suggestions() + self.assertEqual(suggestions['text'], '此api精度比对结果疑似异常,请使用msprobe工具的预检功能对api进行精度检测') diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py index 7883a09a34..fd63e973b5 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py @@ -1,4 +1,5 @@ import unittest +import json from unittest.mock import patch, MagicMock from msprobe.pytorch.visualization.compare.mode_adapter import ModeAdapter from msprobe.pytorch.visualization.graph.base_node import BaseNode, NodeOp @@ -58,4 +59,4 @@ class TestModeAdapter(unittest.TestCase): def test_get_tool_tip(self): self.adapter.compare_mode = GraphConst.MD5_COMPARE tips = self.adapter.get_tool_tip() - self.assertEqual(tips, {'md5': ToolTip.MD5}) + self.assertEqual(tips, json.dumps({'md5': ToolTip.MD5})) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py index 544950f358..1115a84cec 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py @@ -21,15 +21,15 @@ class TestBaseNode(unittest.TestCase): self.assertEqual(self.node, other_node) def test_get_suggestions(self): - self.node.get_suggestions() - self.assertIn(GraphConst.SUGGEST_KEY, self.node.suggestions) + self.node.add_suggestions() + self.assertIn(GraphConst.SUGGEST_KEY, self.node.get_suggestions()) def test_set_input_output(self): input_data = {'input1': 'value1'} output_data = {'output1': 'value2'} self.node.set_input_output(input_data, output_data) - self.assertEqual(self.node.input_data, input_data) - self.assertEqual(self.node.output_data, output_data) + self.assertEqual(self.node.get_input_data(), input_data) + self.assertEqual(self.node.get_output_data(), output_data) def test_add_upnode(self): self.node = BaseNode(self.node_op, self.node_id) @@ -49,13 +49,13 @@ class TestBaseNode(unittest.TestCase): expected_result = { 'id': self.node_id, 'node_type': self.node_op.value, - 'data': {}, - 'output_data': {}, - 'input_data': {}, + 'data': [{}], + 'output_data': [{}], + 'input_data': [{}], 'upnode': self.up_node.id, 'subnodes': [], 'matched_node_link': [], - 'suggestions': {} + 'suggestions': [{}] } self.assertEqual(self.node.to_dict(), expected_result) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py index 19d0987434..0811851b18 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py @@ -45,6 +45,6 @@ class TestGraph(unittest.TestCase): node_a = BaseNode(self.node_op, self.node_id) result = {} graph.dfs(node_a, result) - self.assertEqual(result, {'node_id': {'id': 'node_id', 'node_type': 0, 'data': {}, - 'output_data': {}, 'input_data': {}, 'upnode': 'None', 'subnodes': [], - 'matched_node_link': [], 'suggestions': {}}}) + self.assertEqual(result, {'node_id': {'id': 'node_id', 'node_type': 0, 'data': [{}], + 'output_data': [{}], 'input_data': [{}], 'upnode': 'None', 'subnodes': [], + 'matched_node_link': [], 'suggestions': [{}]}}) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_multi_data.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_multi_data.py new file mode 100644 index 0000000000..99c37321f4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_multi_data.py @@ -0,0 +1,61 @@ +import unittest +import pytest +from msprobe.pytorch.visualization.graph.multi_data import MultiData +from msprobe.pytorch.visualization.utils import GraphConst + +class TestMultiData(unittest.TestCase): + + def setUp(self): + self.multi_data = MultiData() + self.key_list = [GraphConst.MULTI_DATA_KEY, GraphConst.MULTI_OUTPUT_KEY, + GraphConst.MULTI_INPUT_KEY, GraphConst.MULTI_SUGGEST_KEY] + + def test_init(self): + self.assertEqual(self.multi_data._repeat_time, 1) + for key in self.key_list: + self.assertEqual(self.multi_data._data[0][key], {}) + + def test_independence_in_repeats(self): + test_key = self.key_list[1] + self.multi_data.set_repeat_time(3) + # index0下写入数据 + self.multi_data.set_data(test_key, {'index0':1}) + self.assertEqual(self.multi_data.get_data(test_key), {'index0':1}) + # 切换到index2下检查数据,写入数据 + self.multi_data.set_index(2) + self.assertEqual(self.multi_data.get_data(test_key), {}) + self.multi_data.set_data(test_key, {'index2':1}) + self.assertEqual(self.multi_data.get_data(test_key), {'index2':1}) + # 切换会index0,检查数据独立性 + self.multi_data.set_index(0) + self.assertEqual(self.multi_data.get_data(test_key), {'index0':1}) + + def test_bad_case(self): + self.assertNotEqual(self.multi_data._repeat_time, 5) + # 正确的重复次数设置 + self.multi_data.set_repeat_time(5) + self.assertEqual(self.multi_data._repeat_time, 5) + # 不正确的重复次数设置 + self.multi_data.set_repeat_time(0) + self.assertEqual(self.multi_data._repeat_time, 5) + # 正确的index设置 + self.assertNotEqual(self.multi_data._index, 2) + self.multi_data.set_index(2) + self.assertEqual(self.multi_data._index, 2) + # 不正确的index设置 + with pytest.raises(Exception): + self.multi_data.set_index(5) + self.assertEqual(self.multi_data._index, 2) + + def test_data_output(self): + test_key = self.key_list[1] + self.multi_data.set_repeat_time(3) + # 数据写入 + self.multi_data.set_data(test_key, {'index0':1}) + self.multi_data.set_index(1) + self.multi_data.set_data(test_key, {'index1':1}) + self.multi_data.set_index(2) + self.multi_data.set_data(test_key, {'index2':1}) + # 批量数据输出 + result = self.multi_data.to_dict(test_key) + self.assertEqual(result, [{'index0':1}, {'index1':1}, {'index2':1}]) -- Gitee