From 09fdbacb1bb4d5ee8c362546c5ad683b87057ca0 Mon Sep 17 00:00:00 2001 From: s00820771 Date: Sun, 20 Apr 2025 12:03:16 +0800 Subject: [PATCH 1/8] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81td=20profiling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspeed_rl/trainer/utils/transfer_dock.py | 101 ++++- .../trainer/utils/transfer_dock_profiler.py | 348 ++++++++++++++++++ 2 files changed, 435 insertions(+), 14 deletions(-) create mode 100644 mindspeed_rl/trainer/utils/transfer_dock_profiler.py diff --git a/mindspeed_rl/trainer/utils/transfer_dock.py b/mindspeed_rl/trainer/utils/transfer_dock.py index e368566..13c2229 100644 --- a/mindspeed_rl/trainer/utils/transfer_dock.py +++ b/mindspeed_rl/trainer/utils/transfer_dock.py @@ -13,6 +13,8 @@ from torch import Tensor from torch.nn import functional as F from torch.nn.utils.rnn import pad_sequence from mindspeed_rl.utils.loggers import Loggers +from mindspeed_rl.trainer.utils.transfer_dock_profiler import TransferDockProfiler, get_var_memory_size + logger = Loggers("transfer_dock") @@ -68,6 +70,16 @@ class TransferDock(ABC): self.timeout = timeout if timeout is not None else 300 self.timeout_interval = timeout_interval if timeout_interval is not None else 5 + self.profiler = TransferDockProfiler(self) + if self.profiler is not None: + self.profiler.record_memory_statistic(key='experience_data_frame', + value=get_var_memory_size(self.experience_data)) + self.profiler.record_memory_statistic(key='experience_data_status_frame', + value=get_var_memory_size(self.experience_data_status)) + self.profiler.record_memory_statistic(key='experience_data_status', + value=sum([get_var_memory_size(value) + for value in self.experience_data_status.values()])) + def _put( self, experience_columns: List[str], @@ -181,7 +193,9 @@ class TransferDock(ABC): data_ready = sum(itemgetter(*indexes)(self.experience_data_status[single_column])) == len(indexes) start_time = time.time() + data_ready_flag = data_ready while not data_ready: + self.profiler and self.profiler.record_wait_time_begin(single_column) elapsed_time = time.time() - start_time if ( elapsed_time > self.timeout @@ -195,6 +209,8 @@ class TransferDock(ABC): data_ready = sum( itemgetter(*indexes)(self.experience_data_status[single_column]) ) == len(indexes) + if not data_ready_flag and self.profiler: + self.profiler.record_wait_time_end(single_column) def _clear_experience_data_and_status(self, indexes=None): """Clear data and data status in TransferDock. @@ -211,36 +227,74 @@ class TransferDock(ABC): key: torch.zeros(self.max_len, dtype=torch.int32) for key in self.experience_columns } + self.profiler and self.profiler.clear_experience_data_statistic() else: for key in self.experience_columns: self.experience_data_status[key][indexes] = 0 for key in self.experience_columns: for idx in indexes: + self.profiler and self.profiler.decrement_memory_statistic( + key=f'experience_data_value_{key}', + value=get_var_memory_size(self.experience_data[key][idx])) self.experience_data[key][idx] = None - def get_experience_data(self): - """Get all data in TransferDock. + def get_experience_len(self): + """Get the maximum length of data in TransferDock. - Returns: Data dict. + Returns: The maximum length of data. """ - return self.experience_data + return self.max_len - def get_experience_status(self): - """Get all data status in TransferDock. + def get_experience_status(self, column: str = None, index: int = None): + """get experience status - Returns: Data status dict. + Args: + column: status column + index: try to get column index + + Returns: status information """ - return self.experience_data_status + if column is None: + return self.experience_data_status - def get_experience_len(self): - """Get the maximum length of data in TransferDock. + if column not in self.experience_data_status: + logger.warning(f"{column} not in self.experience_data_status") + return None - Returns: The maximum length of data. + if index is not None: + if index >= self.max_len: + logger.warning(f"{index} exceeds the Transfer Dock range {self.max_len}") + return None + return self.experience_data_status[column][index] + + return self.experience_data_status[column] + + def get_experience_data(self, column: str = None, index: int = None): + """get experience data + + Args: + column: data column + index: try to get column index + + Returns: data value """ - return self.max_len + if not column: + return self.experience_data + + if column not in self.experience_data: + logger.warning(f"{column} not in self.experience_data") + return None + + if index: + if index >= self.max_len: + logger.warning(f"{index} exceeds the Transfer Dock range {self.max_len}") + return None + return self.experience_data[column][index] + + return self.experience_data[column] @ray.remote(max_concurrency=100, num_cpus=10) @@ -326,9 +380,11 @@ class GRPOTransferDock(TransferDock): self.metrics = metrics def get_metrics(self): + self.profiler and self.profiler.increment_remote_caller() return self.metrics def update_metrics(self, key="", value=None, cumulate=False): + self.profiler and self.profiler.increment_remote_caller() self.metrics.update(key, value, cumulate=cumulate) def get_experience( @@ -356,6 +412,8 @@ class GRPOTransferDock(TransferDock): Returns: Data dict and row numbers. """ + self.profiler and self.profiler.increment_remote_caller() + self.profiler and self.profiler.increment_get_data_count(experience_columns) if consumer not in self.experience_consumers: raise ValueError( f"get experience ERROR: {consumer} not in TD experience_consumers {self.experience_consumers}" @@ -417,11 +475,13 @@ class GRPOTransferDock(TransferDock): """ + self.profiler and self.profiler.increment_remote_caller() if not indexes: raise ValueError( "put experience into TD without indexes, indexes must be provided" ) experience_columns, experience = trans_input_to_experience(data_dict) + self.profiler and self.profiler.increment_put_data_count(experience_columns) self._put(experience_columns, experience, indexes) def put_prompts_experience( @@ -436,7 +496,7 @@ class GRPOTransferDock(TransferDock): Returns: None """ - + self.profiler and self.profiler.increment_remote_caller() prompts = batch["prompts"] prompt_length = [] for prompt in prompts: @@ -463,7 +523,7 @@ class GRPOTransferDock(TransferDock): {"prompt_length": prompt_length, "prompts": prompts}, **add_vals ) experience_columns, experience = trans_input_to_experience(data_dict) - + self.profiler and self.profiler.increment_put_data_count(experience_columns) self._put(experience_columns, experience, indexes) def _sample_ready_index( @@ -484,6 +544,8 @@ class GRPOTransferDock(TransferDock): """ + self.profiler and self.profiler.record_wait_lock_time_begin( + consumer, experience_count, experience_columns) with self.consumer_sampling_lock[consumer]: not_consumed_indexes = self.experience_consumer_status[consumer] == 0 data_ready_indexes = torch.all( @@ -501,6 +563,8 @@ class GRPOTransferDock(TransferDock): experience_columns, usable_indexes, experience_count, target_seq_len ) self.experience_consumer_status[consumer][sampled_indexes] = 1 + self.profiler and self.profiler.record_wait_lock_time_end( + consumer, experience_count, experience_columns) else: sampled_indexes = None @@ -524,6 +588,8 @@ class GRPOTransferDock(TransferDock): Returns: Sampled row numbers. """ + self.profiler and self.profiler.record_wait_lock_time_begin( + consumer, experience_count, experience_columns) experience_count_n_samples = experience_count // self.n_samples_per_prompt with self.consumer_sampling_lock[consumer]: experience_consumer_status_n_samples = ( @@ -577,6 +643,9 @@ class GRPOTransferDock(TransferDock): self.experience_consumer_status[consumer][sampled_indexes] = 1 + self.profiler and self.profiler.record_wait_lock_time_end( + consumer, experience_count, experience_columns) + return sampled_indexes def all_consumed(self, consumer: str): @@ -588,6 +657,7 @@ class GRPOTransferDock(TransferDock): Returns: True or False. """ + self.profiler and self.profiler.increment_remote_caller() return self.experience_consumer_status[consumer].sum() == self.max_len def clear(self): @@ -596,6 +666,7 @@ class GRPOTransferDock(TransferDock): Returns: None """ + self.profiler and self.profiler.increment_remote_caller() self.experience_consumer_status = { key: torch.zeros(self.max_len, dtype=torch.int32) for key in self.experience_consumers @@ -609,11 +680,13 @@ class GRPOTransferDock(TransferDock): Returns: Consumer status dict. """ + self.profiler and self.profiler.increment_remote_caller() return self.experience_consumer_status def batch_balencing_sampler( self, experience_columns, usable_indexes, experience_count, target_seq_len=None ): + self.profiler and self.profiler.increment_remote_caller() if target_seq_len is None: weights = torch.ones(len(usable_indexes)) else: diff --git a/mindspeed_rl/trainer/utils/transfer_dock_profiler.py b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py new file mode 100644 index 0000000..785ad3d --- /dev/null +++ b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py @@ -0,0 +1,348 @@ +# coding=utf-8 +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. + + +import os +import sys +from typing import List +import threading +import time +import datetime +import inspect +import torch +import psutil +import acl + +from mindspeed_rl.utils.loggers import Loggers + + +logger = Loggers("transfer_dock_profiler") + + +def get_var_memory_size(var): + if isinstance(var, torch.Tensor): + return var.element_size() * var.nelement() + elif isinstance(var, list): + return sum([element.element_size() * element.nelement() + for element in var if isinstance(element, torch.Tensor)]) + if isinstance(var, dict): + memory_size = 0 + for value in var.values(): + if isinstance(value, torch.Tensor): + memory_size += value.element_size() * value.nelement() + elif isinstance(value, list): + memory_size += sum([element.element_size() * element.nelement() + for element in value if isinstance(element, torch.Tensor)]) + if memory_size > 0: + return memory_size + + try: + from pympler import asizeof + return asizeof.asizeof(var) + except ImportError: + return sys.getsizeof(var) + + +def get_current_time(): + now = datetime.datetime.now() + return (f"{now.year}/{now.month:02d}/{now.day:02d} " + f"{now.hour:02d}:{now.minute:02d}:{now.second:02d}.{now.microsecond:06d}") + + +class TransferDockProfiler: + """ + TransferDockProfiler is a profiler tools for TransferDock + """ + def __init__(self, td_instance, + profiler_interval_secs: int = 5, + profiler_level: int = 2, + profiler_file_path: str = 'td_profiling.log', + profiler_control_file: str = '.profiling_ctrl'): + """TransferDockProfiler initialize + + Args: + profiler_interval_secs: task schedule interval, seconds + profiler_level: profiling level, higher number with more detail + profiler_file_path: if not None, dump data on screen, else write to file + profiler_control_file: for dynamic control profiler output data + """ + self.td_instance = td_instance + self.profiler_interval_secs = profiler_interval_secs + self.profiler_level = profiler_level + self.profiler_file_path = profiler_file_path + self.profiler_control_file = profiler_control_file + self.memory_statistic = {} + self.op_statistic = { + 'put_experience': 0, + 'put_experience_columns': {}, + 'get_experience': 0, + 'get_experience_columns': {}, + 'get_metrics': 0, + 'update_metrics': 0, + 'all_consumed': 0, + 'clear': 0, + 'get_consumer_status': 0, + } + self.conflict_lock_time = {} + self.wait_experience_time = {} + self.format_key_pad = 64 + self.split_line = '-' * (self.format_key_pad + 16) + if os.path.exists(self.profiler_file_path): + os.remove(self.profiler_file_path) + self.start_profiling_timer() + + def start_profiling_timer(self): + threading.Timer(self.profiler_interval_secs, self.interval_task).start() + + def interval_task(self): + self.output_data('+' * 80) + profiling_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + self.output_data(f"PROFILING TIME: {profiling_time}") + self.output_data('+' * 80) + self.dump_host_os_information() + self.dump_device_memory() + self.dump_memory_statistic() + self.dump_op_statistic() + if self.profiler_level >= 1: + self.dump_conflict_lock_time() + self.dump_wait_experience_time() + if self.profiler_level >= 2: + self.dump_custom_statistic() + self.start_profiling_timer() + + def record_memory_statistic(self, key: str, value: int): + """ + memory profiling + """ + if key not in self.memory_statistic: + self.memory_statistic[key] = value + else: + self.memory_statistic[key] += value + + def decrement_memory_statistic(self, key: str, value: int): + """ + memory profiling + """ + if key in self.memory_statistic: + if value <= self.memory_statistic[key]: + self.memory_statistic[key] -= value + + def output_data(self, data: str): + if self.profiler_file_path is None: + logger.info(data) + else: + with open(self.profiler_file_path, 'a') as f: + f.write(data + '\n') + + def dump_device_memory(self): + ret = acl.init() + if ret != 0: + return + device_count, ret = acl.rt.get_device_count() + if ret != 0: + return + + self.output_data(f"dump host operation system information:") + self.output_data(self.split_line) + for i in range(device_count): + ret = acl.rt.set_device(i) + if ret != 0: + continue + free_mem_hbm, total_mem_hbm, ret_hbm = acl.rt.get_mem_info(1) + if ret_hbm == 0: + hbm_total = f" device {i} HBM total".ljust(self.format_key_pad) + hbm_free = f" device {i} HBM free".ljust(self.format_key_pad) + hbm_used = f" device {i} HBM use".ljust(self.format_key_pad) + self.output_data(f"{hbm_total}: {total_mem_hbm / (1024 ** 3):.2f} GB") + self.output_data(f"{hbm_free}: {free_mem_hbm / (1024 ** 3):.2f} GB") + self.output_data(f"{hbm_used}: {(total_mem_hbm - free_mem_hbm) / (1024 ** 3):.2f} GB") + acl.rt.reset_device(i) + acl.finalize() + self.output_data(self.split_line) + + def dump_host_os_information(self): + process = psutil.Process(os.getpid()) + memory_info = psutil.virtual_memory() + cpu_percent = psutil.cpu_percent(interval=1, percpu=True) + core_percent = '' + for idx, percent in enumerate(cpu_percent): + core_percent += f"core {idx}: {percent}% " + if (idx + 9) % 8 == 0: + core_percent += '\n' + + self.output_data(f"dump host operation system information:") + self.output_data(self.split_line) + td_rss_memory = " td process rss memory(MB)".ljust(self.format_key_pad) + td_cpu_percent = " td process cpu percent".ljust(self.format_key_pad) + os_total_memory = " os total memory(GB)".ljust(self.format_key_pad) + os_available_memory = " os available memory(GB)".ljust(self.format_key_pad) + os_used_memory = " os used memory(GB)".ljust(self.format_key_pad) + os_memory_percent = " os memory percent".ljust(self.format_key_pad) + os_cpu_total_percent = " os cpu total percent".ljust(self.format_key_pad) + self.output_data(f"{td_rss_memory} : {process.memory_info().rss / (1024 ** 2):.2f} MB") + self.output_data(f"{td_cpu_percent} : {process.cpu_percent(interval=1):.2f}%") + self.output_data(f"{os_total_memory} : {memory_info.total / (1024 ** 3):.2f} GB") + self.output_data(f"{os_available_memory} : {memory_info.available / (1024 ** 3):.2f} GB") + self.output_data(f"{os_used_memory} : {memory_info.used / (1024 ** 3):.2f} GB") + self.output_data(f"{os_memory_percent} : {memory_info.percent}%") + self.output_data(f"{os_cpu_total_percent} : {psutil.cpu_count() * 100}%") + self.output_data(self.split_line) + + def dump_memory_statistic(self): + total_memory = sum(self.memory_statistic.values()) + self.output_data(f"dump td alloc memory statistic:") + self.output_data(self.split_line) + self.output_data(f" total memory: {total_memory}") + self.output_data(self.split_line) + for key, value in self.memory_statistic.items(): + self.output_data(f" memory {key.ljust(self.format_key_pad - 8)}: {value}") + self.output_data(self.split_line) + + def dump_op_statistic(self): + self.output_data(f"dump operation statistic:") + self.output_data(self.split_line) + for key, value in self.op_statistic.items(): + if key == 'put_experience_columns' or key == 'get_experience_columns': + continue + self.output_data(f" operation {key.ljust(self.format_key_pad - 11)}: {value}") + + for key, value in self.op_statistic['put_experience_columns'].items(): + self.output_data(f" put_experience_columns {key.ljust(self.format_key_pad - 26)}: {value}") + for key, value in self.op_statistic['get_experience_columns'].items(): + self.output_data(f" get_experience_columns {key.ljust(self.format_key_pad - 26)}: {value}") + self.output_data(self.split_line) + + def dump_conflict_lock_time(self): + self.output_data(f"dump get experience sample ready detail:") + self.output_data(self.split_line) + for key, value in self.conflict_lock_time.items(): + for record in value: + self.output_data(f" consumer {key.ljust(self.format_key_pad - 10)}: " + f"{record['begin']} - {record['end']}") + self.output_data(self.split_line) + + def dump_wait_experience_time(self): + self.output_data(f"dump wait experience detail:") + self.output_data(self.split_line) + for key, value in self.wait_experience_time.items(): + for record in value: + self.output_data(f" wait experience {key.ljust(self.format_key_pad - 17)}: " + f"{record['begin']} - {record['end']}") + + def do_get_experience_status(self, para: List[str]): + if len(para) == 1: + status = self.td_instance.get_experience_status() + elif len(para) == 2: + status = self.td_instance.get_experience_status(column=para[1]) + elif len(para) == 3: + status = self.td_instance.get_experience_status(column=para[1], index=int(para[2])) + else: + logger.warning(f"invalid command {para}. usage: get_experience_status [column] [index]") + return + + self.output_data(f"command: {' '.join(para)}") + self.output_data(self.split_line) + self.output_data(f"result:") + self.output_data(f"{status}") + self.output_data(self.split_line) + + def do_get_experience_data(self, para: List[str]): + if len(para) == 1: + data = self.td_instance.get_experience_data() + elif len(para) == 2: + data = self.td_instance.get_experience_data(column=para[1]) + elif len(para) == 3: + data = self.td_instance.get_experience_data(column=para[1], index=int(para[2])) + else: + logger.warning(f"invalid command {para}. usage: get_experience_data [column] [index]") + return + + self.output_data(f"command: {' '.join(para)}") + self.output_data(self.split_line) + self.output_data(f"result:") + self.output_data(f"{data}") + self.output_data(self.split_line) + + def parse_command(self, content: str): + command_list = content.split(' ') + command_execute_map = { + 'get_experience_status': self.do_get_experience_status, + 'get_experience_data': self.do_get_experience_data, + } + if command_list[0] not in command_execute_map: + self.output_data(f"command {command_list[0]} is unknown") + return + command_execute_map[command_list[0]](command_list) + + def dump_custom_statistic(self): + if self.profiler_control_file is None: + return + if not os.path.exists(self.profiler_control_file): + return + + with open(self.profiler_control_file, 'r') as f: + lines = f.readlines() + for line in lines: + line = line.strip() + self.parse_command(line) + + def increment_put_data_count(self, experience_columns: List[str]): + for column in experience_columns: + if column not in self.op_statistic['put_experience_columns']: + self.op_statistic['put_experience_columns'][column] = 0 + self.op_statistic['put_experience_columns'][column] += 1 + + def increment_get_data_count(self, experience_columns: List[str]): + for column in experience_columns: + if column not in self.op_statistic['get_experience_columns']: + self.op_statistic['get_experience_columns'][column] = 0 + self.op_statistic['get_experience_columns'][column] += 1 + + def increment_remote_caller(self): + current_frame = inspect.currentframe() + function_name = current_frame.f_back.f_code.co_name + if function_name not in self.op_statistic: + self.op_statistic[function_name] = 0 + print(f"set call increment_remote_caller {function_name}") + self.op_statistic[function_name] += 1 + + def clear_experience_data_statistic(self): + for key in self.memory_statistic: + if key.startswith('experience_data_value_'): + self.memory_statistic[key] = 0 + + def record_wait_time_begin(self, single_column: str): + if single_column not in self.wait_experience_time: + self.wait_experience_time[single_column] = [] + self.wait_experience_time[single_column].append({"begin": get_current_time(), "end": None}) + + def record_wait_time_end(self, single_column: str): + if single_column not in self.wait_experience_time: + logger.warning(f"not found record wait time begin") + return + last_record = self.wait_experience_time[single_column][-1] + if 'begin' not in last_record and 'end' not in last_record: + logger.warning(f"not found last record wait time begin and end") + return + last_record['end'] = get_current_time() + + def record_wait_lock_time_begin(self, consumer: str, + experience_count: int, + experience_columns: List[str]): + key = f"{consumer}_{experience_count}_[{'/'.join(experience_columns)}]" + if key not in self.conflict_lock_time: + self.conflict_lock_time[key] = [] + if len(self.conflict_lock_time[key]) > 0: + if (self.conflict_lock_time[key][-1]['begin'] is not None and + self.conflict_lock_time[key][-1]['end'] is None): + return + self.conflict_lock_time[key].append({"begin": get_current_time(), "end": None}) + + def record_wait_lock_time_end(self, consumer: str, + experience_count: int, + experience_columns: List[str]): + key = f"{consumer}_{experience_count}_[{'/'.join(experience_columns)}]" + if key not in self.conflict_lock_time: + logger.warning(f"not found record wait lock time begin") + if len(self.conflict_lock_time[key]) > 0: + self.conflict_lock_time[key][-1]['end'] = get_current_time() -- Gitee From f75ea9f39b86b3a387743933142658d746d179a1 Mon Sep 17 00:00:00 2001 From: s00820771 Date: Sun, 20 Apr 2025 12:08:18 +0800 Subject: [PATCH 2/8] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81td=20profiling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspeed_rl/trainer/utils/transfer_dock_profiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspeed_rl/trainer/utils/transfer_dock_profiler.py b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py index 785ad3d..ad58f82 100644 --- a/mindspeed_rl/trainer/utils/transfer_dock_profiler.py +++ b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py @@ -7,7 +7,7 @@ import sys from typing import List import threading import time -import datetime +from datetime import datetime, timezone import inspect import torch import psutil @@ -44,7 +44,7 @@ def get_var_memory_size(var): def get_current_time(): - now = datetime.datetime.now() + now = datetime.datetime.now(tz=timezone.utc) return (f"{now.year}/{now.month:02d}/{now.day:02d} " f"{now.hour:02d}:{now.minute:02d}:{now.second:02d}.{now.microsecond:06d}") -- Gitee From 3ff6bd5c3b9aebd20605d471ad52ec7cfed2f710 Mon Sep 17 00:00:00 2001 From: s00820771 Date: Sun, 20 Apr 2025 12:09:38 +0800 Subject: [PATCH 3/8] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81td=20profiling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspeed_rl/trainer/utils/transfer_dock_profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspeed_rl/trainer/utils/transfer_dock_profiler.py b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py index ad58f82..49a649b 100644 --- a/mindspeed_rl/trainer/utils/transfer_dock_profiler.py +++ b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py @@ -44,7 +44,7 @@ def get_var_memory_size(var): def get_current_time(): - now = datetime.datetime.now(tz=timezone.utc) + now = datetime.now(tz=timezone.utc) return (f"{now.year}/{now.month:02d}/{now.day:02d} " f"{now.hour:02d}:{now.minute:02d}:{now.second:02d}.{now.microsecond:06d}") -- Gitee From f66e4a5715290dee9319ea5303839ec117700a56 Mon Sep 17 00:00:00 2001 From: s00820771 Date: Sun, 20 Apr 2025 12:19:55 +0800 Subject: [PATCH 4/8] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81td=20profiling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspeed_rl/trainer/utils/transfer_dock_profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspeed_rl/trainer/utils/transfer_dock_profiler.py b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py index 49a649b..e3dcfc3 100644 --- a/mindspeed_rl/trainer/utils/transfer_dock_profiler.py +++ b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py @@ -185,7 +185,7 @@ class TransferDockProfiler: self.output_data(f"{os_available_memory} : {memory_info.available / (1024 ** 3):.2f} GB") self.output_data(f"{os_used_memory} : {memory_info.used / (1024 ** 3):.2f} GB") self.output_data(f"{os_memory_percent} : {memory_info.percent}%") - self.output_data(f"{os_cpu_total_percent} : {psutil.cpu_count() * 100}%") + self.output_data(f"{os_cpu_total_percent} : {psutil.cpu_percent(interval=1)}%") self.output_data(self.split_line) def dump_memory_statistic(self): -- Gitee From 258e302aff5c3fe7dcf9ef8e345bfb402afcfcc9 Mon Sep 17 00:00:00 2001 From: s00820771 Date: Sun, 20 Apr 2025 12:31:51 +0800 Subject: [PATCH 5/8] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81td=20profiling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspeed_rl/trainer/utils/transfer_dock.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mindspeed_rl/trainer/utils/transfer_dock.py b/mindspeed_rl/trainer/utils/transfer_dock.py index 13c2229..5c85dc0 100644 --- a/mindspeed_rl/trainer/utils/transfer_dock.py +++ b/mindspeed_rl/trainer/utils/transfer_dock.py @@ -132,6 +132,9 @@ class TransferDock(ABC): for i, index in enumerate(indexes): self.experience_data[single_column][index] = experience[column_idx][i] self.experience_data_status[single_column][index] = 1 + self.profiler and self.profiler.record_memory_statistic( + key=f'experience_data_value_{single_column}', + value=get_var_memory_size(self.experience_data[single_column][index])) def _get(self, experience_columns: List[str], indexes: List[int]): """Get data based on row and column numbers. -- Gitee From 55efa84a6135c31e46b563fa8934bc507ff3a61e Mon Sep 17 00:00:00 2001 From: s00820771 Date: Mon, 21 Apr 2025 11:26:26 +0800 Subject: [PATCH 6/8] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81td=20profiling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../trainer/utils/transfer_dock_profiler.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mindspeed_rl/trainer/utils/transfer_dock_profiler.py b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py index e3dcfc3..070577a 100644 --- a/mindspeed_rl/trainer/utils/transfer_dock_profiler.py +++ b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py @@ -56,7 +56,7 @@ class TransferDockProfiler: def __init__(self, td_instance, profiler_interval_secs: int = 5, profiler_level: int = 2, - profiler_file_path: str = 'td_profiling.log', + profiler_file_path: str = 'outputs/td_profiling.log', profiler_control_file: str = '.profiling_ctrl'): """TransferDockProfiler initialize @@ -73,6 +73,7 @@ class TransferDockProfiler: self.profiler_control_file = profiler_control_file self.memory_statistic = {} self.op_statistic = { + 'put_prompts_experience': 0, 'put_experience': 0, 'put_experience_columns': {}, 'get_experience': 0, @@ -87,6 +88,9 @@ class TransferDockProfiler: self.wait_experience_time = {} self.format_key_pad = 64 self.split_line = '-' * (self.format_key_pad + 16) + dir_path = os.path.dirname(self.profiler_file_path) + if not os.path.exists(dir_path): + os.makedirs(dir_path) if os.path.exists(self.profiler_file_path): os.remove(self.profiler_file_path) self.start_profiling_timer() @@ -142,7 +146,7 @@ class TransferDockProfiler: if ret != 0: return - self.output_data(f"dump host operation system information:") + self.output_data(f"dump device memory information:") self.output_data(self.split_line) for i in range(device_count): ret = acl.rt.set_device(i) @@ -153,9 +157,9 @@ class TransferDockProfiler: hbm_total = f" device {i} HBM total".ljust(self.format_key_pad) hbm_free = f" device {i} HBM free".ljust(self.format_key_pad) hbm_used = f" device {i} HBM use".ljust(self.format_key_pad) - self.output_data(f"{hbm_total}: {total_mem_hbm / (1024 ** 3):.2f} GB") - self.output_data(f"{hbm_free}: {free_mem_hbm / (1024 ** 3):.2f} GB") - self.output_data(f"{hbm_used}: {(total_mem_hbm - free_mem_hbm) / (1024 ** 3):.2f} GB") + self.output_data(f"{hbm_total} : {total_mem_hbm / (1024 ** 3):.2f} GB") + self.output_data(f"{hbm_free} : {free_mem_hbm / (1024 ** 3):.2f} GB") + self.output_data(f"{hbm_used} : {(total_mem_hbm - free_mem_hbm) / (1024 ** 3):.2f} GB") acl.rt.reset_device(i) acl.finalize() self.output_data(self.split_line) -- Gitee From 412ec815a56fc18cfee8635c3ad3cb594cd44755 Mon Sep 17 00:00:00 2001 From: s00820771 Date: Mon, 21 Apr 2025 15:07:12 +0800 Subject: [PATCH 7/8] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81td=20profiling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../trainer/utils/transfer_dock_profiler.py | 139 ++++++++++-------- 1 file changed, 79 insertions(+), 60 deletions(-) diff --git a/mindspeed_rl/trainer/utils/transfer_dock_profiler.py b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py index 070577a..0c53150 100644 --- a/mindspeed_rl/trainer/utils/transfer_dock_profiler.py +++ b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py @@ -72,6 +72,7 @@ class TransferDockProfiler: self.profiler_file_path = profiler_file_path self.profiler_control_file = profiler_control_file self.memory_statistic = {} + self.memory_statistic_lock = threading.Lock() self.op_statistic = { 'put_prompts_experience': 0, 'put_experience': 0, @@ -84,8 +85,12 @@ class TransferDockProfiler: 'clear': 0, 'get_consumer_status': 0, } + self.put_experience_lock = threading.Lock() + self.get_experience_lock = threading.Lock() self.conflict_lock_time = {} + self.conflict_lock_time_lock = threading.Lock() self.wait_experience_time = {} + self.wait_experience_time_lock = threading.Lock() self.format_key_pad = 64 self.split_line = '-' * (self.format_key_pad + 16) dir_path = os.path.dirname(self.profiler_file_path) @@ -118,18 +123,20 @@ class TransferDockProfiler: """ memory profiling """ - if key not in self.memory_statistic: - self.memory_statistic[key] = value - else: - self.memory_statistic[key] += value + with self.memory_statistic_lock: + if key not in self.memory_statistic: + self.memory_statistic[key] = value + else: + self.memory_statistic[key] += value def decrement_memory_statistic(self, key: str, value: int): """ memory profiling """ - if key in self.memory_statistic: - if value <= self.memory_statistic[key]: - self.memory_statistic[key] -= value + with self.memory_statistic_lock: + if key in self.memory_statistic: + if value <= self.memory_statistic[key]: + self.memory_statistic[key] -= value def output_data(self, data: str): if self.profiler_file_path is None: @@ -193,14 +200,15 @@ class TransferDockProfiler: self.output_data(self.split_line) def dump_memory_statistic(self): - total_memory = sum(self.memory_statistic.values()) - self.output_data(f"dump td alloc memory statistic:") - self.output_data(self.split_line) - self.output_data(f" total memory: {total_memory}") - self.output_data(self.split_line) - for key, value in self.memory_statistic.items(): - self.output_data(f" memory {key.ljust(self.format_key_pad - 8)}: {value}") - self.output_data(self.split_line) + with self.memory_statistic_lock: + total_memory = sum(self.memory_statistic.values()) + self.output_data(f"dump td alloc memory statistic:") + self.output_data(self.split_line) + self.output_data(f" total memory: {total_memory}") + self.output_data(self.split_line) + for key, value in self.memory_statistic.items(): + self.output_data(f" memory {key.ljust(self.format_key_pad - 8)}: {value}") + self.output_data(self.split_line) def dump_op_statistic(self): self.output_data(f"dump operation statistic:") @@ -210,28 +218,32 @@ class TransferDockProfiler: continue self.output_data(f" operation {key.ljust(self.format_key_pad - 11)}: {value}") - for key, value in self.op_statistic['put_experience_columns'].items(): - self.output_data(f" put_experience_columns {key.ljust(self.format_key_pad - 26)}: {value}") - for key, value in self.op_statistic['get_experience_columns'].items(): - self.output_data(f" get_experience_columns {key.ljust(self.format_key_pad - 26)}: {value}") + with self.put_experience_lock: + for key, value in self.op_statistic['put_experience_columns'].items(): + self.output_data(f" put_experience_columns {key.ljust(self.format_key_pad - 26)}: {value}") + with self.get_experience_lock: + for key, value in self.op_statistic['get_experience_columns'].items(): + self.output_data(f" get_experience_columns {key.ljust(self.format_key_pad - 26)}: {value}") self.output_data(self.split_line) def dump_conflict_lock_time(self): self.output_data(f"dump get experience sample ready detail:") self.output_data(self.split_line) - for key, value in self.conflict_lock_time.items(): - for record in value: - self.output_data(f" consumer {key.ljust(self.format_key_pad - 10)}: " - f"{record['begin']} - {record['end']}") + with self.conflict_lock_time_lock: + for key, value in self.conflict_lock_time.items(): + for record in value: + self.output_data(f" consumer {key.ljust(self.format_key_pad - 10)}: " + f"{record['begin']} - {record['end']}") self.output_data(self.split_line) def dump_wait_experience_time(self): self.output_data(f"dump wait experience detail:") self.output_data(self.split_line) - for key, value in self.wait_experience_time.items(): - for record in value: - self.output_data(f" wait experience {key.ljust(self.format_key_pad - 17)}: " - f"{record['begin']} - {record['end']}") + with self.wait_experience_time_lock: + for key, value in self.wait_experience_time.items(): + for record in value: + self.output_data(f" wait experience {key.ljust(self.format_key_pad - 17)}: " + f"{record['begin']} - {record['end']}") def do_get_experience_status(self, para: List[str]): if len(para) == 1: @@ -291,16 +303,18 @@ class TransferDockProfiler: self.parse_command(line) def increment_put_data_count(self, experience_columns: List[str]): - for column in experience_columns: - if column not in self.op_statistic['put_experience_columns']: - self.op_statistic['put_experience_columns'][column] = 0 - self.op_statistic['put_experience_columns'][column] += 1 + with self.put_experience_lock: + for column in experience_columns: + if column not in self.op_statistic['put_experience_columns']: + self.op_statistic['put_experience_columns'][column] = 0 + self.op_statistic['put_experience_columns'][column] += 1 def increment_get_data_count(self, experience_columns: List[str]): - for column in experience_columns: - if column not in self.op_statistic['get_experience_columns']: - self.op_statistic['get_experience_columns'][column] = 0 - self.op_statistic['get_experience_columns'][column] += 1 + with self.get_experience_lock: + for column in experience_columns: + if column not in self.op_statistic['get_experience_columns']: + self.op_statistic['get_experience_columns'][column] = 0 + self.op_statistic['get_experience_columns'][column] += 1 def increment_remote_caller(self): current_frame = inspect.currentframe() @@ -311,42 +325,47 @@ class TransferDockProfiler: self.op_statistic[function_name] += 1 def clear_experience_data_statistic(self): - for key in self.memory_statistic: - if key.startswith('experience_data_value_'): - self.memory_statistic[key] = 0 + with self.memory_statistic_lock: + for key in self.memory_statistic: + if key.startswith('experience_data_value_'): + self.memory_statistic[key] = 0 def record_wait_time_begin(self, single_column: str): - if single_column not in self.wait_experience_time: - self.wait_experience_time[single_column] = [] - self.wait_experience_time[single_column].append({"begin": get_current_time(), "end": None}) + with self.wait_experience_time_lock: + if single_column not in self.wait_experience_time: + self.wait_experience_time[single_column] = [] + self.wait_experience_time[single_column].append({"begin": get_current_time(), "end": None}) def record_wait_time_end(self, single_column: str): - if single_column not in self.wait_experience_time: - logger.warning(f"not found record wait time begin") - return - last_record = self.wait_experience_time[single_column][-1] - if 'begin' not in last_record and 'end' not in last_record: - logger.warning(f"not found last record wait time begin and end") - return - last_record['end'] = get_current_time() + with self.wait_experience_time_lock: + if single_column not in self.wait_experience_time: + logger.warning(f"not found record wait time begin") + return + last_record = self.wait_experience_time[single_column][-1] + if 'begin' not in last_record and 'end' not in last_record: + logger.warning(f"not found last record wait time begin and end") + return + last_record['end'] = get_current_time() def record_wait_lock_time_begin(self, consumer: str, experience_count: int, experience_columns: List[str]): key = f"{consumer}_{experience_count}_[{'/'.join(experience_columns)}]" - if key not in self.conflict_lock_time: - self.conflict_lock_time[key] = [] - if len(self.conflict_lock_time[key]) > 0: - if (self.conflict_lock_time[key][-1]['begin'] is not None and - self.conflict_lock_time[key][-1]['end'] is None): - return - self.conflict_lock_time[key].append({"begin": get_current_time(), "end": None}) + with self.conflict_lock_time_lock: + if key not in self.conflict_lock_time: + self.conflict_lock_time[key] = [] + if len(self.conflict_lock_time[key]) > 0: + if (self.conflict_lock_time[key][-1]['begin'] is not None and + self.conflict_lock_time[key][-1]['end'] is None): + return + self.conflict_lock_time[key].append({"begin": get_current_time(), "end": None}) def record_wait_lock_time_end(self, consumer: str, experience_count: int, experience_columns: List[str]): key = f"{consumer}_{experience_count}_[{'/'.join(experience_columns)}]" - if key not in self.conflict_lock_time: - logger.warning(f"not found record wait lock time begin") - if len(self.conflict_lock_time[key]) > 0: - self.conflict_lock_time[key][-1]['end'] = get_current_time() + with self.conflict_lock_time_lock: + if key not in self.conflict_lock_time: + logger.warning(f"not found record wait lock time begin") + if len(self.conflict_lock_time[key]) > 0: + self.conflict_lock_time[key][-1]['end'] = get_current_time() -- Gitee From a9bd979246327216783e2d753481171614a5d59d Mon Sep 17 00:00:00 2001 From: s00820771 Date: Tue, 22 Apr 2025 21:20:44 +0800 Subject: [PATCH 8/8] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81td=20profiling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspeed_rl/trainer/utils/transfer_dock.py | 16 +- .../trainer/utils/transfer_dock_profiler.py | 177 +++++++++++------- tests/ut/utils/test_transfer_dock.py | 4 + 3 files changed, 117 insertions(+), 80 deletions(-) diff --git a/mindspeed_rl/trainer/utils/transfer_dock.py b/mindspeed_rl/trainer/utils/transfer_dock.py index 5c85dc0..ae053ce 100644 --- a/mindspeed_rl/trainer/utils/transfer_dock.py +++ b/mindspeed_rl/trainer/utils/transfer_dock.py @@ -72,13 +72,12 @@ class TransferDock(ABC): self.profiler = TransferDockProfiler(self) if self.profiler is not None: - self.profiler.record_memory_statistic(key='experience_data_frame', + self.profiler.record_memory_statistic(key='experience_data_total', value=get_var_memory_size(self.experience_data)) - self.profiler.record_memory_statistic(key='experience_data_status_frame', - value=get_var_memory_size(self.experience_data_status)) self.profiler.record_memory_statistic(key='experience_data_status', value=sum([get_var_memory_size(value) - for value in self.experience_data_status.values()])) + for value in self.experience_data_status.values()]) + + get_var_memory_size(self.experience_data_status)) def _put( self, @@ -198,7 +197,6 @@ class TransferDock(ABC): start_time = time.time() data_ready_flag = data_ready while not data_ready: - self.profiler and self.profiler.record_wait_time_begin(single_column) elapsed_time = time.time() - start_time if ( elapsed_time > self.timeout @@ -212,8 +210,6 @@ class TransferDock(ABC): data_ready = sum( itemgetter(*indexes)(self.experience_data_status[single_column]) ) == len(indexes) - if not data_ready_flag and self.profiler: - self.profiler.record_wait_time_end(single_column) def _clear_experience_data_and_status(self, indexes=None): """Clear data and data status in TransferDock. @@ -660,7 +656,11 @@ class GRPOTransferDock(TransferDock): Returns: True or False. """ - self.profiler and self.profiler.increment_remote_caller() + if self.profiler: + self.profiler.increment_remote_caller() + self.profiler.increment_all_consumed_counter() + if self.profiler.is_need_trace_profiling(): + self.profiler.trace_task() return self.experience_consumer_status[consumer].sum() == self.max_len def clear(self): diff --git a/mindspeed_rl/trainer/utils/transfer_dock_profiler.py b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py index 0c53150..a80fa69 100644 --- a/mindspeed_rl/trainer/utils/transfer_dock_profiler.py +++ b/mindspeed_rl/trainer/utils/transfer_dock_profiler.py @@ -4,11 +4,12 @@ import os import sys -from typing import List +from typing import List, Dict import threading -import time from datetime import datetime, timezone import inspect +import csv +import glob import torch import psutil import acl @@ -22,17 +23,13 @@ logger = Loggers("transfer_dock_profiler") def get_var_memory_size(var): if isinstance(var, torch.Tensor): return var.element_size() * var.nelement() - elif isinstance(var, list): + elif isinstance(var, list) and isinstance(var[0], torch.Tensor): return sum([element.element_size() * element.nelement() for element in var if isinstance(element, torch.Tensor)]) if isinstance(var, dict): memory_size = 0 for value in var.values(): - if isinstance(value, torch.Tensor): - memory_size += value.element_size() * value.nelement() - elif isinstance(value, list): - memory_size += sum([element.element_size() * element.nelement() - for element in value if isinstance(element, torch.Tensor)]) + memory_size += get_var_memory_size(value) if memory_size > 0: return memory_size @@ -54,7 +51,7 @@ class TransferDockProfiler: TransferDockProfiler is a profiler tools for TransferDock """ def __init__(self, td_instance, - profiler_interval_secs: int = 5, + profiler_interval_secs: int = 1, profiler_level: int = 2, profiler_file_path: str = 'outputs/td_profiling.log', profiler_control_file: str = '.profiling_ctrl'): @@ -87,38 +84,57 @@ class TransferDockProfiler: } self.put_experience_lock = threading.Lock() self.get_experience_lock = threading.Lock() - self.conflict_lock_time = {} - self.conflict_lock_time_lock = threading.Lock() - self.wait_experience_time = {} - self.wait_experience_time_lock = threading.Lock() + self.get_waiting_time = {} + self.get_waiting_time_lock = threading.Lock() self.format_key_pad = 64 self.split_line = '-' * (self.format_key_pad + 16) + self.all_consumed_counter = 0 + self.last_profiling_time = get_current_time() dir_path = os.path.dirname(self.profiler_file_path) if not os.path.exists(dir_path): os.makedirs(dir_path) if os.path.exists(self.profiler_file_path): os.remove(self.profiler_file_path) + csv_files = glob.glob(f"{os.path.dirname(self.profiler_file_path)}/*.csv") + for csv_file in csv_files: + os.remove(csv_file) self.start_profiling_timer() + def increment_all_consumed_counter(self): + self.all_consumed_counter += 1 + + def is_need_trace_profiling(self): + return self.all_consumed_counter > 128 + def start_profiling_timer(self): threading.Timer(self.profiler_interval_secs, self.interval_task).start() def interval_task(self): + self.last_profiling_time = get_current_time() self.output_data('+' * 80) - profiling_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) - self.output_data(f"PROFILING TIME: {profiling_time}") + self.output_data(f"PROFILING TIME: {self.last_profiling_time}") self.output_data('+' * 80) self.dump_host_os_information() self.dump_device_memory() self.dump_memory_statistic() self.dump_op_statistic() if self.profiler_level >= 1: - self.dump_conflict_lock_time() - self.dump_wait_experience_time() + self.dump_get_waiting_time() if self.profiler_level >= 2: self.dump_custom_statistic() self.start_profiling_timer() + def trace_task(self): + time_format = "%Y/%m/%d %H:%M:%S.%f" + begin_time = datetime.strptime(self.last_profiling_time, time_format) + end_time = datetime.strptime(get_current_time(), time_format) + between_time = end_time - begin_time + if between_time.microseconds >= 100000: + self.interval_task() + self.all_consumed_counter = 0 + else: + self.all_consumed_counter = self.all_consumed_counter // 2 + def record_memory_statistic(self, key: str, value: int): """ memory profiling @@ -128,6 +144,8 @@ class TransferDockProfiler: self.memory_statistic[key] = value else: self.memory_statistic[key] += value + if key.startswith('experience_data_value_'): + self.memory_statistic['experience_data_total'] += value def decrement_memory_statistic(self, key: str, value: int): """ @@ -137,13 +155,23 @@ class TransferDockProfiler: if key in self.memory_statistic: if value <= self.memory_statistic[key]: self.memory_statistic[key] -= value + if key.startswith('experience_data_value_'): + self.memory_statistic['experience_data_total'] -= value def output_data(self, data: str): - if self.profiler_file_path is None: - logger.info(data) - else: - with open(self.profiler_file_path, 'a') as f: - f.write(data + '\n') + pass + + def output_csv(self, output_type: str, data: Dict[str, any]): + file_name = f"{self.profiler_file_path[:-4]}_{output_type}.csv" + file_exists = os.path.exists(file_name) + with open(file_name, mode=f"a", newline='', encoding='utf-8') as file: + writer = csv.writer(file) + items = list(data.items()) + items.insert(0, ("time", get_current_time())) + data = dict(items) + if not file_exists: + writer.writerow(list(data.keys())) + writer.writerow(list(data.values())) def dump_device_memory(self): ret = acl.init() @@ -173,14 +201,10 @@ class TransferDockProfiler: def dump_host_os_information(self): process = psutil.Process(os.getpid()) + process_mem_info = process.memory_info() + process_cpu_info = process.cpu_percent(interval=5) memory_info = psutil.virtual_memory() - cpu_percent = psutil.cpu_percent(interval=1, percpu=True) - core_percent = '' - for idx, percent in enumerate(cpu_percent): - core_percent += f"core {idx}: {percent}% " - if (idx + 9) % 8 == 0: - core_percent += '\n' - + cpu_percent = process.cpu_percent(interval=5) self.output_data(f"dump host operation system information:") self.output_data(self.split_line) td_rss_memory = " td process rss memory(MB)".ljust(self.format_key_pad) @@ -190,14 +214,21 @@ class TransferDockProfiler: os_used_memory = " os used memory(GB)".ljust(self.format_key_pad) os_memory_percent = " os memory percent".ljust(self.format_key_pad) os_cpu_total_percent = " os cpu total percent".ljust(self.format_key_pad) - self.output_data(f"{td_rss_memory} : {process.memory_info().rss / (1024 ** 2):.2f} MB") - self.output_data(f"{td_cpu_percent} : {process.cpu_percent(interval=1):.2f}%") + self.output_data(f"{td_rss_memory} : {process_mem_info.rss / (1024 ** 2):.2f} MB") + self.output_data(f"{td_cpu_percent} : {process_cpu_info:.2f}%") self.output_data(f"{os_total_memory} : {memory_info.total / (1024 ** 3):.2f} GB") self.output_data(f"{os_available_memory} : {memory_info.available / (1024 ** 3):.2f} GB") self.output_data(f"{os_used_memory} : {memory_info.used / (1024 ** 3):.2f} GB") self.output_data(f"{os_memory_percent} : {memory_info.percent}%") - self.output_data(f"{os_cpu_total_percent} : {psutil.cpu_percent(interval=1)}%") + self.output_data(f"{os_cpu_total_percent} : {cpu_percent}%") self.output_data(self.split_line) + self.output_csv(output_type="host_os", + data={ + "cpu_percent": f"{cpu_percent:.2f}", + "memory_used": f"{memory_info.used / (1024 ** 3):.2f}", + "process_cpu_percent": f"{process_cpu_info:.2f}", + "process_memory_used": f"{process_mem_info.rss / (1024 ** 2):.2f}", + }) def dump_memory_statistic(self): with self.memory_statistic_lock: @@ -209,14 +240,20 @@ class TransferDockProfiler: for key, value in self.memory_statistic.items(): self.output_data(f" memory {key.ljust(self.format_key_pad - 8)}: {value}") self.output_data(self.split_line) + self.output_csv(output_type="td_memory", + data={"total_memory": f"{total_memory / (1024 ** 2):.2f}"}) def dump_op_statistic(self): self.output_data(f"dump operation statistic:") self.output_data(self.split_line) + data_dict = {} for key, value in self.op_statistic.items(): if key == 'put_experience_columns' or key == 'get_experience_columns': continue self.output_data(f" operation {key.ljust(self.format_key_pad - 11)}: {value}") + data_dict[key] = value + self.op_statistic[key] = 0 + self.output_csv(output_type="op_statistic", data=data_dict) with self.put_experience_lock: for key, value in self.op_statistic['put_experience_columns'].items(): @@ -226,24 +263,37 @@ class TransferDockProfiler: self.output_data(f" get_experience_columns {key.ljust(self.format_key_pad - 26)}: {value}") self.output_data(self.split_line) - def dump_conflict_lock_time(self): + def dump_get_waiting_time(self): self.output_data(f"dump get experience sample ready detail:") self.output_data(self.split_line) - with self.conflict_lock_time_lock: - for key, value in self.conflict_lock_time.items(): + time_format = "%Y/%m/%d %H:%M:%S.%f" + between_time_min = 0.0 + between_time_max = 0.0 + between_time_mean = 0.0 + counter = 0 + with self.get_waiting_time_lock: + for key, value in self.get_waiting_time.items(): for record in value: self.output_data(f" consumer {key.ljust(self.format_key_pad - 10)}: " f"{record['begin']} - {record['end']}") + if record['end'] is None: + continue + counter += 1 + start_time = datetime.strptime(record['begin'], time_format) + end_time = datetime.strptime(record['end'], time_format) + time_difference = end_time - start_time + between_time_min = min(between_time_min, time_difference.microseconds) + between_time_max = max(between_time_max, time_difference.microseconds) + between_time_mean += time_difference.microseconds self.output_data(self.split_line) - - def dump_wait_experience_time(self): - self.output_data(f"dump wait experience detail:") - self.output_data(self.split_line) - with self.wait_experience_time_lock: - for key, value in self.wait_experience_time.items(): - for record in value: - self.output_data(f" wait experience {key.ljust(self.format_key_pad - 17)}: " - f"{record['begin']} - {record['end']}") + if counter > 0: + between_time_mean /= counter + self.output_csv(output_type='wait_experience_time', + data={ + "between_time_min": f"{int(between_time_min / 1000)}", + "between_time_max": f"{int(between_time_max / 1000)}", + "between_time_mean": f"{int(between_time_mean / 1000)}", + }) def do_get_experience_status(self, para: List[str]): if len(para) == 1: @@ -330,42 +380,25 @@ class TransferDockProfiler: if key.startswith('experience_data_value_'): self.memory_statistic[key] = 0 - def record_wait_time_begin(self, single_column: str): - with self.wait_experience_time_lock: - if single_column not in self.wait_experience_time: - self.wait_experience_time[single_column] = [] - self.wait_experience_time[single_column].append({"begin": get_current_time(), "end": None}) - - def record_wait_time_end(self, single_column: str): - with self.wait_experience_time_lock: - if single_column not in self.wait_experience_time: - logger.warning(f"not found record wait time begin") - return - last_record = self.wait_experience_time[single_column][-1] - if 'begin' not in last_record and 'end' not in last_record: - logger.warning(f"not found last record wait time begin and end") - return - last_record['end'] = get_current_time() - def record_wait_lock_time_begin(self, consumer: str, experience_count: int, experience_columns: List[str]): key = f"{consumer}_{experience_count}_[{'/'.join(experience_columns)}]" - with self.conflict_lock_time_lock: - if key not in self.conflict_lock_time: - self.conflict_lock_time[key] = [] - if len(self.conflict_lock_time[key]) > 0: - if (self.conflict_lock_time[key][-1]['begin'] is not None and - self.conflict_lock_time[key][-1]['end'] is None): + with self.get_waiting_time_lock: + if key not in self.get_waiting_time: + self.get_waiting_time[key] = [] + if len(self.get_waiting_time[key]) > 0: + if (self.get_waiting_time[key][-1]['begin'] is not None and + self.get_waiting_time[key][-1]['end'] is None): return - self.conflict_lock_time[key].append({"begin": get_current_time(), "end": None}) + self.get_waiting_time[key].append({"begin": get_current_time(), "end": None}) def record_wait_lock_time_end(self, consumer: str, experience_count: int, experience_columns: List[str]): key = f"{consumer}_{experience_count}_[{'/'.join(experience_columns)}]" - with self.conflict_lock_time_lock: - if key not in self.conflict_lock_time: + with self.get_waiting_time_lock: + if key not in self.get_waiting_time: logger.warning(f"not found record wait lock time begin") - if len(self.conflict_lock_time[key]) > 0: - self.conflict_lock_time[key][-1]['end'] = get_current_time() + if len(self.get_waiting_time[key]) > 0: + self.get_waiting_time[key][-1]['end'] = get_current_time() diff --git a/tests/ut/utils/test_transfer_dock.py b/tests/ut/utils/test_transfer_dock.py index 6f1721a..3c1bf3f 100644 --- a/tests/ut/utils/test_transfer_dock.py +++ b/tests/ut/utils/test_transfer_dock.py @@ -2,10 +2,12 @@ # Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. import random import pytest +from unittest.mock import MagicMock import ray import torch +import mindspeed_rl from mindspeed_rl.trainer.utils import TransferDock, GRPOTransferDock from tests.test_tools.dist_test import DistributedTest from mindspeed_rl.utils.metrics import Metric @@ -20,6 +22,7 @@ def setup_teardown_transfer_dock(request): self.timeout = 10 self.timeout_interval = 2 self.default_experience_columns = ["default"] + mindspeed_rl.trainer.utils.transfer_dock.TransferDockProfiler = MagicMock(return_value=None) self.td = TransferDock(prompts_num=self.prompts_num, n_samples_per_prompt=self.n_samples_per_prompt, experience_columns=self.default_experience_columns, @@ -36,6 +39,7 @@ def setup_teardown_grpo_transfer_dock_function(request): self.n_samples_per_prompt = 1 self.max_len = self.prompts_num * self.n_samples_per_prompt metrics = Metric() + mindspeed_rl.trainer.utils.transfer_dock.TransferDockProfiler = MagicMock(return_value=None) self.td = GRPOTransferDock.remote(prompts_num=self.prompts_num, n_samples_per_prompt=self.n_samples_per_prompt, metrics=metrics) -- Gitee