diff --git a/profiler/precheck/LICENSE b/profiler/precheck/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a7e3dc07659db1518d51374024ea561019cf27d7 --- /dev/null +++ b/profiler/precheck/LICENSE @@ -0,0 +1,14 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/profiler/precheck/README.md b/profiler/precheck/README.md new file mode 100644 index 0000000000000000000000000000000000000000..570b67ae9651d98b2619ad5dd15b0ffde5c81c7f --- /dev/null +++ b/profiler/precheck/README.md @@ -0,0 +1,36 @@ +# precheck + +#### 介绍 +precheck1.0是一个一键性能预检工具,能够快速分析集群计算与通信性能是否达到标杆值。运行完成后能够在命令行窗口与文件夹中生成对应的性能数据。 + +#### 软件架构 +- analyse——analyse 主要分析模块,数据打屏模块,csv数据生成模块 +- common——config 设定算子名称与算子对应性能标杆 + ——constant 设定标定数值 + ——utils 数据打屏模块 +- entrance——entrance 数据采集模块,主程序 +- manager——group_manager 通信域构建与环境变量收集模块 +- pre_check——check 主入口 +- test_op 各算子测试模块 + +#### 安装教程 +1. 确保本机已经安装了昇腾NPU卡驱动包与对应的CANN包。 +2. 克隆代码仓库(请替换``为实际的仓库URL):`git clone ` +3. 新建conda环境,选择python版本为python=3.8:`conda create -n ms_pre_check python=3.8 conda activate ms_pre_check` +4. 进入profiler/precheck主文件夹中,运行:`pip install -r requirements.txt` + +#### 使用说明 + +##### 单机使用 +修改`run.sh`中的命令为: +`torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 --master_addr="127.0.0.1" --master_port=29500 ./pre_check/check.py` + +##### 多机使用 +修改`run.sh`中的命令为: +`torchrun --nnodes= --nproc_per_node= --node_rank= --master_addr="" --master_port= ./pre_check/check.py` +其中 +- `--nnodes=` 多机节点的总数量 +- `--nproc_per_node=` 单机内卡的数量 +- `--node_rank=` 机器的优先级,按0到(n-1)在每台机器上依次排序(假设总共有n台机器) +- `--master_addr=` 优先级为0的节点的IP地址 +- `--master_port=` 设置的端口,不被占用不冲突即可 \ No newline at end of file diff --git a/profiler/precheck/analyze/__init__.py b/profiler/precheck/analyze/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/precheck/analyze/analyze.py b/profiler/precheck/analyze/analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..893b56de3cd9b9674d8af4863d4136895648d4c4 --- /dev/null +++ b/profiler/precheck/analyze/analyze.py @@ -0,0 +1,90 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from common.constant import Constant +from common.config import OpLevel,DataLevel,Config +from common.utils import get_current_time_str,create_csv_writer +from manager.group_manager import GroupManager +from prettytable import PrettyTable +class Analysis: + + @staticmethod + def _add_data(tensor_data: list) -> list: + ''' + tensor input data to dict + ''' + keys=[ + Constant.RANK_ID, Constant.E2E_TIME, Constant.TOTAL_TIME, Constant.FREE_TIME, + Constant.MAX_TIME, Constant. MIN_TIME,Constant. SIZE,Constant. BANDWIDTH, Constant.COUNT + ] + result=[] + for op_index, tensor in enumerate(tensor_data): + converted_dict = dict(zip(keys, tensor)) + op_type = OpLevel(op_index).name + converted_dict[Constant.OP_TYPE] = op_type + if op_type in (OpLevel.MUL.name, OpLevel.MATMUL.name,): + converted_dict[Constant.BANDWIDTH] = Constant.NA + converted_dict[Constant.BANDWIDTH_BENCHMARK] = Constant.NA + else: + bench_mark = Config.BANDWIDTH_BENCHMARK_TABLE.get(op_type).get(GroupManager().get_rank_size()) + converted_dict[Constant.BANDWIDTH_BENCHMARK] = "PASS" if converted_dict[Constant.BANDWIDTH] >= bench_mark else "FAIL" + result.append(converted_dict) + return result + + @staticmethod + def _write_csv(csv_path: str, rank: int, complete_data: list): + ''' + 创建csv文件 + ''' + if not complete_data: + return + headers = list(complete_data[0].keys()) if rank == 0 else [] + list_data = [list(data.values()) for data in complete_data] + create_csv_writer(csv_path, headers, list_data) + + @staticmethod + def _print_result(self, complete_data: list): + table_data={} + for card_data in complete_data: + rank_id = card_data['RANK_ID'] + op_type = card_data['OP_TYPE'] + bandwidth_benchmark = card_data['BANDWIDTH_BENCHMARK'] + if rank_id not in table_data: + table_data[rank_id] = {} + table_data[rank_id][op_type]=bandwidth_benchmark + all_op_types = [card_data['OP_TYPE'] for card_data in complete_data] + table_rows=[ + [rank_id] + [table_data[rank_id].get(op_type, 'N/A') for op_type in all_op_types] for rank_id in sorted(table_data.keys()) + ] + headers=['RANK_ID'] + all_op_types + if rank_id == 0: + table = PrettyTable(headers) + table.add_rows(table_rows) + table.border = False + + def analyze(self, rank: int, gather_list: list) -> bool: + ''' + Analyze 主要对外入口 + ''' + if rank != 0: + return True + csv_path=os.path.abspath(f"./analyze_{get_current_time_str()}.csv") + for node,node_tensor in enumerate(gather_list): + node_tensor_list = node_tensor.tolist() + for rank_id,tensor_data in enumerate(node_tensor_list): + complete_data = self._add_data(tensor_data) + self._print_result(complete_data) + self._write_csv(csv_path,node * GroupManager().get_rank_size() + rank_id, complete_data) + return True \ No newline at end of file diff --git a/profiler/precheck/common/__init__.py b/profiler/precheck/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/precheck/common/path_manager.py b/profiler/precheck/common/path_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..a9c948149371619bc11cc9e0dba0812b1688278e --- /dev/null +++ b/profiler/precheck/common/path_manager.py @@ -0,0 +1,205 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +import shutil +import platform + +from .constant import Constant + + +class PathManager: + MAX_PATH_LENGTH = 4096 + MAX_FILE_NAME_LENGTH = 255 + DATA_FILE_AUTHORITY = 0o640 + DATA_DIR_AUTHORITY = 0o750 + WINDOWS = "windows" + + @classmethod + def check_input_directory_path(cls, path: str): + """ + Function Description: + check whether the path is valid, some businesses can accept a path that does not exist, + so the function do not verify whether the path exists + Parameter: + path: the path to check, whether the incoming path is absolute or relative depends on the business + Exception Description: + when invalid data throw exception + """ + cls.input_path_common_check(path) + base_name = os.path.basename(path) + if os.path.isfile(path): + msg = f"Invalid input path which is a file path: {base_name}" + raise RuntimeError(msg) + + @classmethod + def check_input_file_path(cls, path: str): + """ + Function Description: + check whether the file path is valid, some businesses can accept a path that does not exist, + so the function do not verify whether the path exists + Parameter: + path: the file path to check, whether the incoming path is absolute or relative depends on the business + Exception Description: + when invalid data throw exception + """ + cls.input_path_common_check(path) + base_name = os.path.basename(path) + if os.path.isdir(path): + msg = f"Invalid input path which is a directory path: {base_name}" + raise RuntimeError(msg) + + @classmethod + def check_path_length(cls, path: str): + if len(path) > cls.MAX_PATH_LENGTH: + raise RuntimeError("Length of input path exceeds the limit.") + path_split_list = path.split("/") + for path in path_split_list: + path_list = path.split("\\") + for name in path_list: + if len(name) > cls.MAX_FILE_NAME_LENGTH: + raise RuntimeError("Length of input path exceeds the limit.") + + @classmethod + def input_path_common_check(cls, path: str): + cls.check_path_length(path) + + if os.path.islink(path): + msg = f"Invalid input path which is a soft link." + raise RuntimeError(msg) + + if platform.system().lower() == cls.WINDOWS: + pattern = r'(\.|:|\\|/|_|-|\s|[~0-9a-zA-Z\u4e00-\u9fa5])+' + else: + pattern = r'(\.|/|_|-|\s|[~0-9a-zA-Z])+' + if not re.fullmatch(pattern, path): + msg = f"Invalid input path." + raise RuntimeError(msg) + + @classmethod + def check_path_owner_consistent(cls, path: str): + """ + Function Description: + check whether the path belong to process owner + Parameter: + path: the path to check + Exception Description: + when invalid path, prompt the user + """ + base_name = os.path.basename(path) + if not os.path.exists(path): + msg = f"Invalid path: {base_name}" + raise RuntimeError(msg) + if platform.system().lower() == cls.WINDOWS: + return + if os.stat(path).st_uid != os.getuid(): + check_msg = input("The path does not belong to you, do you want to continue? [y/n]") + if check_msg.lower() != "y": + raise RuntimeError("The user choose not to continue.") + + @classmethod + def check_path_writeable(cls, path): + """ + Function Description: + check whether the path is writable + Parameter: + path: the path to check + Exception Description: + when invalid data throw exception + """ + cls.check_path_owner_consistent(path) + if os.path.islink(path): + msg = f"Invalid path which is a soft link." + raise RuntimeError(msg) + base_name = os.path.basename(path) + if not os.access(path, os.W_OK): + msg = f"The path permission check failed: {base_name}" + raise RuntimeError(msg) + + @classmethod + def check_path_readable(cls, path): + """ + Function Description: + check whether the path is writable + Parameter: + path: the path to check + Exception Description: + when invalid data throw exception + """ + cls.check_path_owner_consistent(path) + if os.path.islink(path): + msg = f"Invalid path which is a soft link." + raise RuntimeError(msg) + base_name = os.path.basename(path) + if not os.access(path, os.R_OK): + msg = f"The path permission check failed: {base_name}" + raise RuntimeError(msg) + + @classmethod + def remove_path_safety(cls, path: str): + base_name = os.path.basename(path) + msg = f"Failed to remove path: {base_name}" + cls.check_path_writeable(path) + if os.path.islink(path): + raise RuntimeError(msg) + if os.path.exists(path): + try: + shutil.rmtree(path) + except Exception as err: + raise RuntimeError(msg) from err + + @classmethod + def make_dir_safety(cls, path: str): + base_name = os.path.basename(path) + msg = f"Failed to make directory: {base_name}" + if os.path.islink(path): + raise RuntimeError(msg) + if os.path.exists(path): + return + try: + os.makedirs(path, mode=cls.DATA_DIR_AUTHORITY) + except Exception as err: + raise RuntimeError(msg) from err + + @classmethod + def create_file_safety(cls, path: str): + base_name = os.path.basename(path) + msg = f"Failed to create file: {base_name}" + if os.path.islink(path): + raise RuntimeError(msg) + if os.path.exists(path): + return + try: + os.close(os.open(path, os.O_WRONLY | os.O_CREAT, cls.DATA_FILE_AUTHORITY)) + except Exception as err: + raise RuntimeError(msg) from err + + @classmethod + def get_realpath(cls, path: str) -> str: + if os.path.islink(path): + msg = f"Invalid input path which is a soft link." + raise RuntimeError(msg) + return os.path.realpath(path) + + @classmethod + def check_file_size(cls, file_path: str): + if not os.path.exists(file_path): + raise FileNotFoundError(f"The file {file_path} does not exists.") + file_size = os.path.getsize(file_path) + if file_size > Constant.MAX_FILE_SIZE_5_GB: + check_msg = input( + f"The file({file_path}) size exceeds the preset max value. Continue reading the file? [y/n]") + if check_msg.lower() != "y": + raise RuntimeError(f"[WARNING] The user choose not to read the file: {file_path}") diff --git a/profiler/precheck/common/utils.py b/profiler/precheck/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4a93ca43b558d199c59fa03a3dd67f3f6b43f31d --- /dev/null +++ b/profiler/precheck/common/utils.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import csv +import os +from datetime import datetime +from datetime import timezone +from common.constant import Constant +from path_manager import PathManager + +def create_csv_writer(csv_file_path: str, headers: list,data: list): + PathManager.check_path_writeable(csv_file_path) + with FdOpen(csv_file_path, newline='', operate = 'a') as _csv_file: + writer = csv.writer(_csv_file) + if headers: + writer.writerow(headers) + slice_count = len(data) // Constant.DATA_LEN + for index in range(slice_count): + writer.writerows(data[index * Constant.DATA_LEN:(index + 1) * Constant.DATA_LEN]) + writer.writerows(data[slice_count * Constant.DATA_LEN:]) + +def get_current_time_str() -> str: + utc_time = datetime.now(tz=timezone.utc) + current_time = utc_time.replace(tzinfo=timezone.utc).astimezone(tz=None) + return current_time.strftime("%Y%m%d%H%M%S") + +class FdOpen: + def __init__(self: any, file_path: str, flags: int = Constant.WRITE_FLAGS, mode: int = Constant.WRITE_MODES, operate: str = 'w', newline: str = None) -> None: + self.file_path = file_path + self.flags = flags + self.mode = mode + self.operate = operate + self.newline = newline + self.fd = None + self.file_open = None + + def __enter__(self: any) -> any: + self.fd = os.open(self.file_path, self.flags, self.mode) + if self.newline is None: + self.file_open = os.fdopen(self.fd, self.operate) + else: + self.file_open = os.fdopen(self.fd, self.operate, newline=self.newline) + return self.file_open + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.file_open: + self.file_open.close() + elif self.fd: + os.close(self.fd) \ No newline at end of file diff --git a/profiler/precheck/entrance/__init__.py b/profiler/precheck/entrance/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/precheck/entrance/entrance.py b/profiler/precheck/entrance/entrance.py new file mode 100644 index 0000000000000000000000000000000000000000..eff09c4f231daf5a5818d3cb848b6402ee28c152 --- /dev/null +++ b/profiler/precheck/entrance/entrance.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import os +from typing import List + +import torch +import torch_npu +import torch.distributed as dist + +from analyze.analyze import Analysis +from common.config import Config +from manager.group_manager import GroupManager +from test_op.test_op import Statistics + +class Entrance: + statistics_list = [] + + @staticmethod + def generate_tensor(statistics_list: List[Statistics], local_rank) -> any: + return torch.Tensor([[s.rank_id, s.e2e_time, s.total_time, s.free_time, s.max_time, s.min_time, + s.size, s.bandwidth,s.count ] for s in statistics_list]).npu(local_rank) + + @staticmethod + def gather_rank_data(tensor, local_rank) -> list: + dist.barrier() + gather_list = [] + rank_size = GroupManager().get_rank_size() + rank = GroupManager().get_rank() + if local_rank == 0: + for _ in range(rank_size): + gather_list.append((torch.zeros(len(Config.CHECK_LIST),9)).npu(local_rank)) + dist.gather(tensor, gather_list = gather_list, dst= rank // rank_size * rank_size,group = GroupManager().get_local_group()) + if rank % rank_size == 0: + rank_tensor = torch.stack(gather_list) + gather_list = [] + if rank == 0: + for _ in range(GroupManager().get_world_size() // rank_size): + gather_list.append((torch.zeros(rank_size, len(Config.CHECK_LIST),9)).npu(local_rank)) + dist.gather(rank_tensor, gather_list = gather_list, dst= 0, group=GroupManager().get_gather_group()) + return gather_list + + def run(self): + GroupManager() + rank = GroupManager().get_rank() + local_rank = GroupManager().get_local_rank() + index=1 + while True: + statistics_list = ['']* len(Config.CHECK_LIST) + if Config.CHECK_LIST is None: + break + for op_name, op_func in Config.CHECK_LIST.items(): + statistics_list[op_name.op_func] = op_func(rank=local_rank).run() + tensor = self.generate_tensor(statistics_list, local_rank) + gather_list = self.gather_rank_data(tensor, local_rank) + if Analysis().analyze(rank=rank, gather_list=gather_list): + break + index += 1 + + def main(self: any) -> None: + torch.npu.set_device(int(os.environ['LOCAL_RANK'])) + dist.init_process_group(backend='hccl', rank=int(os.environ['RANK']), world_size=int(os.environ['WORLD_SIZE']), timeout=datetime.timedelta(seconds=1800)) + self.run() \ No newline at end of file diff --git a/profiler/precheck/manager/__init__.py b/profiler/precheck/manager/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/precheck/manager/group_manager.py b/profiler/precheck/manager/group_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..1737789ddb8433ff4d51e0583214e55395009e84 --- /dev/null +++ b/profiler/precheck/manager/group_manager.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import os +import torch +import torch.distributed as dist +from common.singleton import singleton +@singleton +class GroupManager: + _initialized = False + def __init__(self:any)->None: + if not self._initialized: + self._initialized = True + self._rank = int(os.environ['RANK']) + self._local_rank = int(os.environ['LOCAL_RANK']) + self._world_size = int(os.environ['WORLD_SIZE']) + self._group_rank = int(os.environ['GROUP_RANK']) + self._rank_size = int(os.environ['LOCAL_WORLD_SIZE']) + self._local_group = None + self._gather_group = None + + def get_rank(self): + return self._rank + + def get_local_rank(self): + return self._local_rank + + def get_world_size(self): + return self._world_size + + def get_rank_size(self): + return self._rank_size + + def get_local_group(self): + if self._local_group is None: + groups=[x for x in range(self._group_rank * self._rank_size , (self._group_rank+1) * self._rank_size )] + self._local_group = dist.new_group(ranks = groups, timeout = datetime.timedelta(seconds=1800)) + return self._local_group + + def get_gather_group(self): + if self._gather_group is None: + groups = [x for x in range(self._world_size) if x % self._rank_size == 0] + self._gather_group = dist.new_group(ranks = groups, timeout = datetime.timedelta(seconds=1800)) + return self._gather_group + + \ No newline at end of file diff --git a/profiler/precheck/pre_check/__init__.py b/profiler/precheck/pre_check/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/precheck/pre_check/check.py b/profiler/precheck/pre_check/check.py new file mode 100644 index 0000000000000000000000000000000000000000..87232b13400df6631760ebd2632e63dc4fca5c77 --- /dev/null +++ b/profiler/precheck/pre_check/check.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import os +import sys +if __name__ == '__main__': + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),'..'))) + MODEL_PATH = "entrance.entrance" + ENTRANCE_CLASS = "Entrance" + os.umask(0o027) + entrance_module = importlib.import_module(MODEL_PATH) + if hasattr(entrance_module,ENTRANCE_CLASS): + getattr(entrance_module,ENTRANCE_CLASS)().main() \ No newline at end of file diff --git a/profiler/precheck/run.sh b/profiler/precheck/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..47b65c6156242cc3de70bf0074e62ccfd4fe59d9 --- /dev/null +++ b/profiler/precheck/run.sh @@ -0,0 +1 @@ +torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 --master_addr="127.0.0.1" --master_port=29500 ./pre_check/check.py \ No newline at end of file diff --git a/profiler/precheck/test_op/__init__.py b/profiler/precheck/test_op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/precheck/test_op/test_op.py b/profiler/precheck/test_op/test_op.py new file mode 100644 index 0000000000000000000000000000000000000000..29ba17836cc4a8286ff189cd033187fc9301fcaa --- /dev/null +++ b/profiler/precheck/test_op/test_op.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from dataclasses import dataclass +from common.constant import Constant + +@dataclass +class Statistics: + rank_id: int = -1 + e2e_time: float = 0.0 + total_time: float = 0.0 + free_time: float = 0.0 + max_time: float = 0.0 + min_time: float = sys.float_info.max + size: float = 0.0 + bandwidth: float = 0.0 + count: int = 0 + +class TestOp: + def __init__(self : any, pre_train: int=10, train: int=100, rank: int = 0) -> None: + self.pre_train = pre_train + self.train = train + self.rank = rank + @staticmethod + def calculate_tensor_size_g(tensor: any) -> float: + return tensor.numel() * tensor.element_size() / Constant.BYTE_SIZE / Constant.BYTE_SIZE / Constant.BYTE_SIZE + + def run(self): + pass \ No newline at end of file