diff --git a/profiler/msprof_analyze/precheck/common/file.py b/profiler/msprof_analyze/precheck/common/file.py
index 15837a0884fd79a9404de3995ba396846f605e02..c4204937b15c0458df2aa8c3f88a04ea63eab04e 100644
--- a/profiler/msprof_analyze/precheck/common/file.py
+++ b/profiler/msprof_analyze/precheck/common/file.py
@@ -69,6 +69,13 @@ class File:
             except Exception as e:
                 logger.error(f"Failed to create file: {e}")
 
+    @staticmethod
+    def get_json_files(file_dir, end_type: str):
+        for root, _, files in os.walk(file_dir):
+            for file in files:
+                if file.endswith(end_type):
+                    yield os.path.join(root, file)
+
 
 class FileOpen:
     """
diff --git a/profiler/msprof_analyze/precheck/env_check/analyze.py b/profiler/msprof_analyze/precheck/env_check/analyze.py
index 37274f83a757233760455b053153fd7e32d2d8ab..0123b6d321a3fa69297c5b29baf0ae0c286117c9 100644
--- a/profiler/msprof_analyze/precheck/env_check/analyze.py
+++ b/profiler/msprof_analyze/precheck/env_check/analyze.py
@@ -21,6 +21,15 @@ from msprof_analyze.precheck.common.constant import Constant
 class TimeAnalyze():
     def __init__(self, run_time):
         self.run_time = run_time
+        self.slow_time = []
+        self.slow_rank = []
+
+    @staticmethod
+    def get_key(dictionary, value):
+        for key, val in dictionary.items():
+            if val == value:
+                return key
+        return None
 
     @staticmethod
     def get_key(dictionary, value):
@@ -33,15 +42,6 @@ class TimeAnalyze():
         if not self.run_time:
             logging.ERROR("Running time is undefined.")
             return None
-
-        slow_rank = None
-        slow_time = None
-        mean_time = 0
-        max_ratio = 0
-
-        # ranks and values of the extreme time costs
-        slow_time = []
-        slow_rank = []
 
         # compute the gap between the fast and slow ranks
         try:
@@ -56,8 +56,8 @@ class TimeAnalyze():
 
         for run_time_value in self.run_time.values():
             if run_time_value > mean_time * (1 + Constant.RATIO_THRESHOLD):
-                slow_time.append(run_time_value)
-                slow_rank.append(self.get_key(self.run_time, run_time_value))
+                self.slow_time.append(run_time_value)
+                self.slow_rank.append(self.get_key(self.run_time, run_time_value))
 
         # check whether a problem exists
         if max_ratio > Constant.RATIO_THRESHOLD:
@@ -65,7 +65,34 @@ class TimeAnalyze():
         else:
             isproblem = False
 
-        return slow_rank, slow_time, max_ratio, isproblem
+        return self.slow_rank, self.slow_time, max_ratio, isproblem
+
+    def bandwidth_analyze(self):
+        if not self.run_time:
+            logging.error("Running time is undefined.")
+            return None
+
+        try:
+            mean_value = sum(self.run_time) / len(self.run_time)
+        except ZeroDivisionError as e:
+            raise RuntimeError("The input parameter is undefined.") from e
+        try:
+            min_ratio = (min(self.run_time) - mean_value) / mean_value
+        except ZeroDivisionError as e:
+            raise RuntimeError("The input parameter has a value of zero.") from e
+
+        for indx, run_time_value in enumerate(self.run_time):
+            if run_time_value < mean_value * (1 - Constant.RATIO_THRESHOLD):
+                self.slow_time.append(run_time_value)
+                self.slow_rank.append(indx)
+
+        # check whether a problem exists
+        if min_ratio < -Constant.RATIO_THRESHOLD:
+            isproblem = True
+        else:
+            isproblem = False
+        return self.slow_rank, self.slow_time, min_ratio, isproblem
+
 
 
 class Timer:
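Both analyzers in analyze.py apply the same mean-plus-threshold rule: time_analyze flags dict entries above mean * (1 + RATIO_THRESHOLD), and the new bandwidth_analyze mirrors it for list entries below mean * (1 - RATIO_THRESHOLD). A minimal standalone sketch of the rule follows; 0.05 is a stand-in for Constant.RATIO_THRESHOLD, whose real value lives in the repo's constant module.

# Standalone sketch of the outlier rule used by TimeAnalyze (0.05 is an assumed threshold).
RATIO_THRESHOLD = 0.05

def find_slow_ranks(run_time):
    """Return (slow_ranks, slow_times, max_ratio, is_problem) for a {rank: seconds} dict."""
    mean_time = sum(run_time.values()) / len(run_time)
    max_ratio = (max(run_time.values()) - mean_time) / mean_time
    slow = {rank: t for rank, t in run_time.items()
            if t > mean_time * (1 + RATIO_THRESHOLD)}
    return list(slow.keys()), list(slow.values()), max_ratio, max_ratio > RATIO_THRESHOLD

# Example: rank 3 is about 24% slower than the mean, so it is flagged.
print(find_slow_ranks({0: 1.00, 1: 1.01, 2: 0.99, 3: 1.35}))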
diff --git a/profiler/msprof_analyze/precheck/env_check/communication_check.py b/profiler/msprof_analyze/precheck/env_check/communication_check.py
index 4c6308e09f2789f6f6bef3098599769e369329c1..c613a6b6ecf0837736979f120257de64120f8ccd 100644
--- a/profiler/msprof_analyze/precheck/env_check/communication_check.py
+++ b/profiler/msprof_analyze/precheck/env_check/communication_check.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import OrderedDict
 import os
 from typing import List, Optional
 import argparse
@@ -24,6 +25,7 @@ import torch_npu
 from torch import distributed as dist
 from msprof_analyze.precheck.distributed_cluster.distributed_cluster_base import DistributedClusterBase
 from msprof_analyze.precheck.env_check.environment_check import HardwareCheck
+from msprof_analyze.precheck.env_check.analyze import TimeAnalyze
 from msprof_analyze.precheck.common.file import File, FileOpen, FdOpen
 
 
@@ -40,40 +42,47 @@ class CommunicationCheck(HardwareCheck):
         self.cp = args.context_parallel_size
         self.ep = args.expert_model_parallel_size
         self.no_share_storage = args.no_shared_storage
-        self.tmp_path = '/tmp/p2p_test'
+        self.iteration = 5
         self.validate_parallel_strategy()
+        self.world_size = self.proc_num * self.nodes
+        self.node_rank = args.node_rank
 
     @staticmethod
     def analyze(data_path, args):
         """
         Walk all JSON files under the given path and collect the values of the specified keys
         """
-        file_name: str = 'communication_check.json'
-        index_name = ['dp_avg_cost', 'pp_avg_cost', 'tp_avg_cost', 'cp_avg_cost', 'ep_avg_cost']
-        result = {}
-
-        # build the paths of all JSON files
-        def get_json_files(file_dir):
-            for root, _, files in os.walk(file_dir):
-                for file in files:
-                    if file.endswith('.json'):
-                        yield os.path.join(root, file)
-        json_files = get_json_files(data_path)
-        for file_path in json_files:
-            try:
-                # open and load the JSON file
-                with FileOpen(file_path, 'r', encoding='utf-8') as f:
-                    data = json.load(f)
-                    for key in data:
-                        if key in index_name:
-                            result.setdefault(key, []).append(data[key])
-            except json.JSONDecodeError:
-                logging.error(f"Error decoding JSON in {file_path}")
-            except Exception as e:
-                logging.error(f"An error occurred while processing {file_path}: {e}")
-        with FdOpen(os.path.join(args.output, file_name), 'w', encoding='utf-8') as f:
-            json.dump(result, f, indent=4)
-
+        logging.info("Starting the communication analysis")
+        file_name: str = 'communication_worse.txt'
+        total_dict = {}
+        result_dict = []
+
+        json_files = list(File.get_json_files(data_path, 'comm.json'))
+        if not json_files:
+            raise FileNotFoundError(f"No JSON files found in {data_path}")
+        for file in json_files:
+            with FileOpen(file, 'r') as f:
+                data = json.load(f.file_reader)
+                total_dict.update(data)
+        # compute the mean and analyze the current results
+        if len(total_dict) == 0:
+            raise ValueError(f"No data found in {data_path}")
+        # reuse the time analysis helper
+        cpu_analyze = TimeAnalyze(total_dict)
+        slow_rank, slow_time, _, isproblem = cpu_analyze.time_analyze()
+        # print and save the results
+        if not isproblem:
+            logging.info("No abnormal communication found")
+        else:
+            for rank, time in zip(slow_rank, slow_time):
+                log_info = f"The communication between ranks {rank} took {time} ms."
+                result_dict.append(log_info)
+            logging.info("The results with performance problems " +
+                         f"can be read in file {file_name} at path {data_path}")
+            with FdOpen(os.path.join(data_path, file_name), 'w') as f:
+                for line in result_dict:
+                    f.write(line + '\n')
+        logging.info("Finished the communication analysis.")
 
     def validate_parallel_strategy(self):
         if self.tp * self.pp * self.cp * self.ep == 0:
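For reference on the hunk below: the NPU events are recorded around the isend/irecv launches, while the requests are only waited on after the loop, so the measured interval may not cover the full transfer. A sketch of timing one exchange around the waits instead, built only from the same torch.distributed and torch_npu calls the patch already uses (illustrative, not part of the patch):

import torch_npu
from torch import distributed as dist

def time_p2p_pair(src_rank, dst_rank, tensor_send, tensor_recv):
    """Return the elapsed time (ms) for one send/recv exchange with dst_rank (sketch only)."""
    start = torch_npu.npu.Event(enable_timing=True)
    end = torch_npu.npu.Event(enable_timing=True)
    start.record()
    # Issue the pair in opposite order on the two sides so both ranks are not sending first.
    if src_rank < dst_rank:
        ops = [dist.isend(tensor_send, dst=dst_rank), dist.irecv(tensor_recv, src=dst_rank)]
    else:
        ops = [dist.irecv(tensor_recv, src=dst_rank), dist.isend(tensor_send, dst=dst_rank)]
    for op in ops:
        op.wait()
    end.record()
    torch_npu.npu.synchronize()  # make both events safe to read
    return start.elapsed_time(end)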
""" - send_ops = [] - recv_ops = [] - if not self.is_pipeline_last_rank(comm_ranks): - send_ops.append( - dist.isend( - tensor=tensor_send, - dst=self.get_rank_to_next(ranks=comm_ranks), - group=proc_group - ) - ) - - if not self.is_pipeline_first_rank(comm_ranks): - tensor_recv_next = torch.empty(tensor_send.size(), dtype=torch.float32).npu() - recv_ops.append( - dist.irecv( - tensor=tensor_recv_next, - src=self.get_rank_to_prev(ranks=comm_ranks), - group=proc_group - ) - ) - - # 等待接收操作完成 - for op in recv_ops: - op.wait() - - if not self.is_pipeline_first_rank(comm_ranks): - send_ops.append( - dist.isend( - tensor=tensor_recv_next, - dst=self.get_rank_to_prev(ranks=comm_ranks), - group=proc_group - ) - ) - - if not self.is_pipeline_last_rank(comm_ranks): - recv_ops.append( - dist.irecv( - tensor=tensor_recv, - src=self.get_rank_to_next(ranks=comm_ranks), - group=proc_group - ) - ) - - # 等待所有发送和接收操作完成 - for op in send_ops + recv_ops: + if len(target_ranks) <= 1: + return None + comm_ops = [] + p2p_results = {} + + for target_rank in target_ranks: + send_tensor = tensor_send.clone() + recv_tensor = tensor_recv.clone() + start = torch_npu.npu.Event(enable_timing=True) + end = torch_npu.npu.Event(enable_timing=True) + start.record() + if src_rank < target_rank: + send_op = dist.isend(send_tensor, dst=target_rank) + recv_op = dist.irecv(recv_tensor, src=target_rank) + comm_ops.extend([send_op, recv_op]) + else: + recv_op = dist.irecv(recv_tensor, src=target_rank) + send_op = dist.isend(send_tensor, dst=target_rank) + comm_ops.extend([recv_op, send_op]) + end.record() + torch_npu.npu.synchronize() + elapsed_time = start.elapsed_time(end) + p2p_results[f"{src_rank}-{target_rank}"] = elapsed_time + + for op in comm_ops: op.wait() + if send_tensor is not None: + del send_tensor + if recv_tensor is not None: + del recv_tensor + return p2p_results def finalize(self): self.dcb.destroy_comm() + def initialize_tensor(self, + tensor_size=(3, 1024, 4096), + tensor_type=torch.float32): + """initialize the tensor with the given size and type""" + init_send_tensor = torch.randn(tensor_size, dtype=tensor_type).npu() + init_recv_tensor = torch.empty(tensor_size, dtype=tensor_type).npu() + return init_send_tensor, init_recv_tensor + def get_file_path(self, file_path, file_name): """get the file path of the given file name""" if not os.path.exists(file_path): logging.info(f"path is not exist, creating output dir: {file_path}") os.makedirs(file_path) return os.path.join(file_path, file_name) - - # 获取dp大小 - def get_dp_size(self): - # 返回world_size除以tp、pp、cp、ep的乘积 - return (self.nodes * self.proc_num) // (self.tp * self.pp * self.cp) - - def get_rank_to_next(self, ranks): - cur_index = ranks.index(dist.get_rank()) - return ranks[(cur_index + 1) % len(ranks)] - - def get_rank_to_prev(self, ranks): - cur_index = ranks.index(dist.get_rank()) - return ranks[(cur_index - 1) % len(ranks)] - - def get_comm_group(self, group_ranks): - if group_ranks is None: - raise ValueError("The group is None, please check again") - for sub in group_ranks: - if dist.get_rank() in sub: - return sub - raise ValueError("Failed to find the current group") - - def is_pipeline_first_rank(self, group): - return dist.get_rank() == group[0] - - def is_pipeline_last_rank(self, group): - return dist.get_rank() == group[-1] - - def comm_group_create(self): - group_names_list = [] - group_info_dict = {} - if self.get_dp_size() > 1: - group_names_list.append("dp") - group_info_dict["dp"] = [self.dcb.get_data_parallel_group_ranks(), - 
-                                     self.dcb.get_data_parallel_group()]
-        if self.tp > 1:
-            group_names_list.append("tp")
-            group_info_dict["tp"] = [self.dcb.get_tensor_parallel_group_ranks(),
-                                     self.dcb.get_tensor_parallel_group()]
-        if self.pp > 1:
-            group_names_list.append("pp")
-            group_info_dict["pp"] = [self.dcb.get_pipeline_parallel_group_ranks(),
-                                     self.dcb.get_pipeline_parallel_group()]
-        if self.ep > 1:
-            group_names_list.append("ep")
-            group_info_dict["ep"] = [self.dcb.get_expert_parallel_group_ranks(),
-                                     self.dcb.get_expert_parallel_group()]
-        if self.cp > 1:
-            group_names_list.append("cp")
-            group_info_dict["cp"] = [self.dcb.get_context_parallel_group_ranks(),
-                                     self.dcb.get_context_parallel_group()]
-        return group_names_list, group_info_dict
+
+    def build_comm_group(self, self_rank):
+        '''Merge the communication groups and find all p2p peer ranks'''
+        dp_ranks = self.dcb.get_data_parallel_group_ranks()
+        tp_ranks = self.dcb.get_tensor_parallel_group_ranks()
+        pp_ranks = self.dcb.get_pipeline_parallel_group_ranks()
+        cp_ranks = self.dcb.get_context_parallel_group_ranks()
+        ep_ranks = self.dcb.get_expert_parallel_group_ranks()
+        groups = [dp_ranks, tp_ranks, pp_ranks, cp_ranks, ep_ranks]
+        target_ranks = set()
+        for group in groups:
+            target_ranks.update(group)
+        target_ranks.remove(self_rank)
+        return list(target_ranks)
+
 
     def get_group_name(self):
         if len(self.group_names_list) == 0:
             return []
         return self.group_names_list
 
-    def collect(self, data_path, iteration: int = 5, batch_size=(32, 4096, 8192)):
+    def collect(self, data_path):
         """Perform point-to-point communication between ranks in the group"""
         self.initialize()
-        group_names_list, group_info_dict = self.comm_group_create()
-        local_cost = {"rank": dist.get_rank()}
-        if len(group_names_list) == 0:
-            raise ValueError("The tensor-parallel, pipeline-parallel, data-parallel, "
                              "expert-parallel, context-parallel is set less than 2, "
                              "no peer communication can be created")
-        tensor_send = torch.randn(batch_size, dtype=torch.float32).npu()
-        tensor_recv = torch.empty(batch_size, dtype=torch.float32).npu()
-        for group_name in group_names_list:
-            ranks, proc_group = group_info_dict[group_name]
-            local_cost[group_name] = []
-            if len(ranks) <= 1:
-                raise ValueError(f"Failed to start communication group {group_name},"
                                  f"since the group is {ranks}.")
-            for i in range(iteration):
-                if dist.get_rank() == 0:
-                    logging.info(f">>Start communication: {group_name}, iteration: {i}")
-                tensor_send.uniform_(-1, 1)
-                if i == 0:
-                    dist.barrier()
-                timer = Timer()
-                timer.start()
-                self.perform_p2p_communication_async(ranks, tensor_send, tensor_recv, proc_group)
-                timer.stop()
-                local_cost[group_name].append(timer.delta)
-            local_cost[f"{group_name}_avg_cost"] = sum(local_cost[group_name]) / iteration
-        self.dump_tmp_file(local_cost, data_path)
+        rank = dist.get_rank()
+        if rank == 0:
+            logging.info("Starting to collect communication data")
+        target_ranks = self.build_comm_group(rank)
+        result = {}
+        init_send_tensor, init_recv_tensor = self.initialize_tensor()
+        dist.barrier()
+        for _ in range(self.iteration):
+            result_p2p = self.perform_p2p_communication_async(rank, target_ranks, init_send_tensor, init_recv_tensor)
+            for key, value in result_p2p.items():
+                if key in result:
+                    result[key] += (value / self.iteration)
+                else:
+                    result[key] = (value / self.iteration)
+
+        self.dump_tmp_file(result, data_path)
+        dist.barrier()
+
+        if dist.get_rank() == 0:
+            logging.info("Finished collecting communication data")
         self.finalize()
 
     def dump_tmp_file(self, data, save_dir):
-        """save data to file as worker_x_rank_y.json"""
"""save data to file as node_x_rank_y.json""" cur_rank = dist.get_rank() - worker_id = os.getenv("GROUP_RANK") - if worker_id is None: - raise ValueError("GROUP_ID environment variable is not set.") + if self.node_rank is None: + raise ValueError("NODE_RANK environment variable is not set.") if data is None: raise ValueError(f"data is not created by rank {cur_rank}, please check this") - dump_path = self.get_file_path(save_dir, f"worker_{worker_id}_rank_{cur_rank}.json") + dump_path = self.get_file_path(save_dir, f"node_{self.node_rank}_rank_{cur_rank}_comm.json") try: with FdOpen(dump_path, 'w') as f: json.dump(data, f, indent=4) diff --git a/profiler/msprof_analyze/precheck/env_check/io_check.py b/profiler/msprof_analyze/precheck/env_check/io_check.py index cc7d0d4e0d20b804aa8644874bbbc94e4e6af06b..01c8c7b2d87c67b3bd0f8b33db294687e99055b5 100644 --- a/profiler/msprof_analyze/precheck/env_check/io_check.py +++ b/profiler/msprof_analyze/precheck/env_check/io_check.py @@ -19,7 +19,10 @@ import logging import shutil import json import numpy as np +import torch +from torch import distributed as dist from msprof_analyze.precheck.env_check.environment_check import HardwareCheck +from msprof_analyze.precheck.env_check.analyze import TimeAnalyze from msprof_analyze.precheck.common.file import File, FileOpen, FdOpen @@ -28,69 +31,59 @@ class IOCheck(HardwareCheck): def __init__(self, args): super().__init__(args) - self.global_worker_id = os.getenv("GROUP_RANK") - self.local_rank = os.getenv("LOCAL_RANK") + self.node_rank = args.node_rank + self.local_rank = args.local_rank self.output = args.output - self.tmp_work_dir = '/tmp/data_file' - self.results = {self.global_worker_id: { - 'ckpt_read': [], - 'log_read': [], - 'ckpt_write': [], - 'log_write': [], - } - } + self.tmp_work_dir = args.output + self.results = {'node rank': [self.node_rank], + 'ckpt_read': [], + 'log_read': [], + 'ckpt_write': [], + 'log_write': []} @staticmethod def analyze(data_path, args): """ 遍历指定路径下的所有JSON文件,并获取指定键对应的value,并将结果保存到指定的文件中 """ - if int(os.getenv("RANK")) != 0: - return - data_to_merge = {} + logging.info("Start to analyze the IO collect data") + total_data = {} + file_name = 'IO_check.txt' + result = [] # 遍历指定路径下的所有JSON文件 - def get_json_files(data_path): - for root, _, files in os.walk(data_path): - for file in files: - if file.endswith('.json'): - yield os.path.join(root, file) - files_dir = get_json_files(data_path) + files_dir = list(File.get_json_files(data_path, 'io.json')) + # 按照测试的类别‘ckpt_read’,‘log_read’等汇聚数据 for file_path in files_dir: - try: - with FileOpen(file_path, "r") as file: - file_data = json.load(file) - for key, value in file_data.items(): - if key in data_to_merge and isinstance(value, dict): - for k, v in value.items(): - data_to_merge[key][k].extend(v) - else: - data_to_merge[key] = value - except FileNotFoundError: - logging.error(f"文件 {file_path} 未找到。") - except json.JSONDecodeError: - logging.error(f"文件 {file_path} 不是有效的 JSON 文件。") - - def process_dict(data): - """处理字典数据""" - if data is None: - return None - for key, value in data.items(): - if isinstance(value, dict): - data[key] = process_dict(value) - else: - try: - if len(value) > 0: - data[key] = sum(value) / len(value) - else: - data[key] = None - except TypeError: - data[key] = value - return data - - data_to_merge = process_dict(data_to_merge) - with FdOpen(os.path.join(kwargs.get('output', 'final.json')), 'w') as f: - json.dump(data_to_merge, f, indent=4) + with FileOpen(file_path, "r") as file: + file_data = 
diff --git a/profiler/msprof_analyze/precheck/env_check/io_check.py b/profiler/msprof_analyze/precheck/env_check/io_check.py
index cc7d0d4e0d20b804aa8644874bbbc94e4e6af06b..01c8c7b2d87c67b3bd0f8b33db294687e99055b5 100644
--- a/profiler/msprof_analyze/precheck/env_check/io_check.py
+++ b/profiler/msprof_analyze/precheck/env_check/io_check.py
@@ -19,6 +19,10 @@ import logging
 import shutil
 import json
 import numpy as np
+import torch
+from torch import distributed as dist
 from msprof_analyze.precheck.env_check.environment_check import HardwareCheck
+from msprof_analyze.precheck.env_check.analyze import TimeAnalyze
 from msprof_analyze.precheck.common.file import File, FileOpen, FdOpen
 
 
@@ -28,69 +31,59 @@ class IOCheck(HardwareCheck):
 
     def __init__(self, args):
         super().__init__(args)
-        self.global_worker_id = os.getenv("GROUP_RANK")
-        self.local_rank = os.getenv("LOCAL_RANK")
+        self.node_rank = args.node_rank
+        self.local_rank = args.local_rank
         self.output = args.output
-        self.tmp_work_dir = '/tmp/data_file'
-        self.results = {self.global_worker_id: {
-            'ckpt_read': [],
-            'log_read': [],
-            'ckpt_write': [],
-            'log_write': [],
-            }
-        }
+        self.tmp_work_dir = args.output
+        self.results = {'node rank': [self.node_rank],
+                        'ckpt_read': [],
+                        'log_read': [],
+                        'ckpt_write': [],
+                        'log_write': []}
 
     @staticmethod
     def analyze(data_path, args):
         """
        Walk all JSON files under the given path, collect the values of the specified keys, and save the result to the output file
         """
-        if int(os.getenv("RANK")) != 0:
-            return
-        data_to_merge = {}
+        logging.info("Starting to analyze the collected IO data")
+        total_data = {}
+        file_name = 'IO_check.txt'
+        result = []
 
         # walk all JSON files under the given path
-        def get_json_files(data_path):
-            for root, _, files in os.walk(data_path):
-                for file in files:
-                    if file.endswith('.json'):
-                        yield os.path.join(root, file)
-        files_dir = get_json_files(data_path)
+        files_dir = list(File.get_json_files(data_path, 'io.json'))
+        # aggregate the data by test category, e.g. 'ckpt_read', 'log_read'
         for file_path in files_dir:
-            try:
-                with FileOpen(file_path, "r") as file:
-                    file_data = json.load(file)
-                    for key, value in file_data.items():
-                        if key in data_to_merge and isinstance(value, dict):
-                            for k, v in value.items():
-                                data_to_merge[key][k].extend(v)
-                        else:
-                            data_to_merge[key] = value
-            except FileNotFoundError:
-                logging.error(f"File {file_path} not found.")
-            except json.JSONDecodeError:
-                logging.error(f"File {file_path} is not a valid JSON file.")
-
-        def process_dict(data):
-            """process the dictionary data"""
-            if data is None:
-                return None
-            for key, value in data.items():
-                if isinstance(value, dict):
-                    data[key] = process_dict(value)
-                else:
-                    try:
-                        if len(value) > 0:
-                            data[key] = sum(value) / len(value)
-                        else:
-                            data[key] = None
-                    except TypeError:
-                        data[key] = value
-            return data
-
-        data_to_merge = process_dict(data_to_merge)
-        with FdOpen(os.path.join(kwargs.get('output', 'final.json')), 'w') as f:
-            json.dump(data_to_merge, f, indent=4)
+            with FileOpen(file_path, "r") as file:
+                file_data = json.load(file.file_reader)
+                for key, value in file_data.items():
+                    total_data.setdefault(key, []).extend(value)
+        for key, value in total_data.items():
+            if key == 'node rank':
+                continue
+            if len(value) == 0:
+                logging.error(f'No data was collected for {key}, please check this item')
+            cpu_analyze = TimeAnalyze(value)
+            index, low_band, min_ratio, _ = cpu_analyze.bandwidth_analyze()
+            for i, idx in enumerate(index):
+                node_rank = total_data['node rank'][idx] if idx < len(total_data['node rank']) else 'Unknown'
+                band = low_band[i]
+                log_info = f'the node rank {node_rank} may have a problem, ' + \
+                           f'since the collected value {band} is below the mean ' + \
+                           f'by a ratio of {min_ratio}'
+                result.append(log_info)
+
+        if result:
+            with FdOpen(os.path.join(data_path, file_name), 'w') as f:
+                for line in result:
+                    f.write(line + '\n')
+            logging.info(f"The result can be read in file {file_name} at path {data_path}")
+        else:
+            logging.info("IO check is OK...")
+        logging.info("Finished the IO analysis.")
+
 
     def generate_log_message(self, index):
         """Generate a log message"""
@@ -102,24 +95,24 @@
         """Generate the file path"""
         if file_type == "txt":
             return os.path.join(self.tmp_work_dir,
-                                f'worker_{self.global_worker_id}_data_{index}.txt')
+                                f'node_{self.node_rank}_data_{index}.txt')
         elif file_type == "npy":
             return os.path.join(self.tmp_work_dir,
-                                f'worker_{self.global_worker_id}_data_{index}.npy')
+                                f'node_{self.node_rank}_data_{index}.npy')
         else:
             raise ValueError(f"file type {file_type} is not included in the list [txt, npy]")
 
-    def generate_random_weights(self, shape=(4, 2048, 4096)):
+    def generate_random_weights(self, shape=(16, 2048, 4096)):
         """Generate random weights"""
         return np.random.randn(*shape).astype(np.float32)
 
     def save_data(self, index, data, file_type='npy'):
         if file_type == 'txt':
-            file_path = os.path.join(self.tmp_work_dir, f'worker_{self.global_worker_id}_data_{index}.txt')
+            file_path = os.path.join(self.tmp_work_dir, f'node_{self.node_rank}_data_{index}.txt')
             with FdOpen(file_path, 'w') as f:
                 f.write(str(data))
         elif file_type == 'npy':
-            file_path = os.path.join(self.tmp_work_dir, f'worker_{self.global_worker_id}_data_{index}.npy')
+            file_path = os.path.join(self.tmp_work_dir, f'node_{self.node_rank}_data_{index}.npy')
             np.save(file_path, data)
         else:
             raise ValueError(f"file type {file_type} is not included in the list [txt, npy]")
@@ -131,7 +124,7 @@ class IOCheck(HardwareCheck):
     def calculate_speed(self, start_time, end_time, data_size):
         """Calculate the read/write speed"""
         elapsed_time = end_time - start_time
-        return data_size / elapsed_time / (1024 ** 2) 
+        return data_size / elapsed_time / (1024 ** 2)
 
     def generate_file_like_log(self, iteration: int = 10000, scaler: int = 100):
         """Generate a data file and calculate the write speed"""
@@ -140,7 +133,7 @@
         start_time = time.time()
         total_data_size = 0
 
-        with FdOpen(file_path, 'w', encoding='utf-8') as file:
+        with FdOpen(file_path, 'w') as file:
             for i in range(iteration):
                 log_message = self.generate_log_message(i) * scaler
                 file.write(log_message)
@@ -148,7 +141,7 @@
                 total_data_size += len(log_message)
         end_time = time.time()
         write_speed = self.calculate_speed(start_time, end_time, total_data_size)
-        self.results[self.global_worker_id]['log_write'].append(write_speed)
+        self.results['log_write'].append(write_speed)
         return file_path
 
     def read_file_like_log(self, file_path=None):
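A caveat on the read path changed in the next hunk: the rewritten loop keeps only the last line in data, so the data_size fed to calculate_speed reflects a single line rather than the whole file. A sketch of a variant that accumulates the byte count while still reading line by line (illustrative only, not part of the patch):

import time

def measure_log_read_speed(file_path):
    """Read a text file line by line and return the read speed in MB/s (sketch only)."""
    data_size = 0
    start_time = time.time()
    with open(file_path, 'r') as file:
        for line in file:
            data_size += len(line.encode('utf-8'))  # count bytes without keeping the lines
    end_time = time.time()
    return data_size / (end_time - start_time) / (1024 ** 2)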
@@ -158,13 +151,13 @@
         try:
             start_time = time.time()
             data = ""
-            with FileOpen(file_path, 'r', encoding='utf-8') as file:
+            with open(file_path, 'r') as file:
                 for line in file:
-                    data += line
+                    data = line
             end_time = time.time()
             data_size = len(data.encode('utf-8'))
             read_speed = self.calculate_speed(start_time, end_time, data_size)
-            self.results[self.global_worker_id]['log_read'].append(read_speed)
+            self.results['log_read'].append(read_speed)
             return data
         except FileNotFoundError:
             logging.error(f"File {file_path} not found.")
@@ -185,7 +178,7 @@
                 np.save(file, weight)
         end_time = time.time()
         write_speed = self.calculate_speed(start_time, end_time, total_data_size)
-        self.results[self.global_worker_id]['ckpt_write'].append(write_speed)
+        self.results['ckpt_write'].append(write_speed)
         return file_path
 
     def read_file_like_ckpt(self, file_path=None):
@@ -193,13 +186,12 @@
         if file_path is None:
             file_path = self.get_file_path(self.local_rank, file_type='npy')
         try:
+            data_size = os.path.getsize(file_path)
             start_time = time.time()
-            with FileOpen(file_path, 'r') as f:
-                data = np.load(f)
-            end_time = time.time()
-            data_size = data.nbytes
+            data = np.load(file_path)
+            end_time = time.time()
             read_speed = self.calculate_speed(start_time, end_time, data_size)
-            self.results[self.global_worker_id]['ckpt_read'].append(read_speed)
+            self.results['ckpt_read'].append(read_speed)
             return data
         except FileNotFoundError:
             logging.error(f"File {file_path} not found.")
@@ -229,30 +221,31 @@
     def collect(self, data_path):
         """Simulate IO operations and collect the results"""
-        logging.info("======Statrting IO operation...")
-        logging.info("=====Starting log file...")
+        if self.node_rank == 0:
+            logging.info("Starting IO operation...")
         try:
             log_file_path = self.generate_file_like_log()
             self.read_file_like_log(log_file_path)
         except Exception as e:
             logging.error(f"Error reading log file: {e}")
-        logging.info("=====Starting file operation...")
+
         try:
             ckpt_file_path = self.generate_file_like_ckpt()
             self.read_file_like_ckpt(ckpt_file_path)
         except Exception as e:
             logging.error(f"Error reading checkpoint file: {e}")
-        logging.info("======Finished operation finished")
+
         """Generate temporary results for analysis"""
-        file_name = os.path.join(data_path, f"worker_{self.global_worker_id}_{self.local_rank}.json")
+        file_name = os.path.join(data_path,
+                                 f"node_{self.node_rank}_{self.local_rank}_io.json")
         try:
             with FdOpen(file_name, 'w') as f:
                 json.dump(self.results, f, indent=4)
-            logging.info(f"Results have been saved to {data_path}")
         except Exception as e:
             logging.error(f"Error saving results to {data_path}: {e}")
         self.clean_cache_file('npy')
         self.clean_cache_file('txt')
+        if self.node_rank == 0:
+            logging.info("Finished the IO operation")
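For completeness, the throughput figures collected above are in MB/s: calculate_speed divides the byte count by the elapsed seconds and by 1024 ** 2, and read_file_like_ckpt now takes the byte count from os.path.getsize before loading. A worked example with invented numbers:

# Worked example of the MB/s conversion used by IOCheck.calculate_speed (numbers invented).
data_size = 256 * 1024 ** 2      # 256 MiB written or read
elapsed_time = 2.0               # seconds
speed = data_size / elapsed_time / (1024 ** 2)
print(f"{speed:.1f} MB/s")       # -> 128.0 MB/s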