diff --git a/tb_plugins/profiling/tb_plugin/fe/src/components/MemoryView.tsx b/tb_plugins/profiling/tb_plugin/fe/src/components/MemoryView.tsx index e3c7599f98152555fc172c7325d2fffecf328144..060a18d1679c3197f986191a456746a4875a66d7 100644 --- a/tb_plugins/profiling/tb_plugin/fe/src/components/MemoryView.tsx +++ b/tb_plugins/profiling/tb_plugin/fe/src/components/MemoryView.tsx @@ -176,6 +176,7 @@ export const MemoryView: React.FC = React.memo((props) => { const getName = React.useCallback((row: any) => row[searchIndex], [ searchIndex ]) + const getNameAscend = (row: any) => row[0] const [searchedTableDataRows] = useSearchDirectly( searchOperatorName, getName, @@ -183,7 +184,7 @@ export const MemoryView: React.FC = React.memo((props) => { ) const [searchedEventsTableDataRows] = useSearchDirectly( searchEventOperatorName, - getName, + deviceTarget === 'Ascend' ? getNameAscend : getName, filterByEventSize( memoryEventsData?.rows[device], filterEventSize[device] ?? [0, Infinity] diff --git a/tb_plugins/profiling/tb_plugin/torch_tb_profiler/plugin.py b/tb_plugins/profiling/tb_plugin/torch_tb_profiler/plugin.py index 592c395c9bc33322b283cac0586a766ac37ec4ba..0150fcdcf1ec2191c69715e877d4d39d608d0b61 100644 --- a/tb_plugins/profiling/tb_plugin/torch_tb_profiler/plugin.py +++ b/tb_plugins/profiling/tb_plugin/torch_tb_profiler/plugin.py @@ -206,7 +206,7 @@ class TorchProfilerPlugin(base_plugin.TBPlugin): group_by = request.args.get('group_by') input_shape = request.args.get('input_shape') if group_by == 'OperationAndInputShape': - return self.respond_as_json(profile.operation_stack_by_name_input[str(op_name)+'###'+str(input_shape)]) + return self.respond_as_json(profile.operation_stack_by_name_input[str(op_name) + '###' + str(input_shape)]) else: return self.respond_as_json(profile.operation_stack_by_name[str(op_name)]) @@ -293,7 +293,7 @@ class TorchProfilerPlugin(base_plugin.TBPlugin): end_ts = int(end_ts) return self.respond_as_json( - profile.get_memory_stats(start_ts=start_ts, end_ts=end_ts, memory_metric=memory_metric), True) + profile.get_memory_stats(start_ts=start_ts, end_ts=end_ts, memory_metric=memory_metric), True) @wrappers.Request.application def memory_curve_route(self, request: werkzeug.Request): @@ -315,10 +315,20 @@ class TorchProfilerPlugin(base_plugin.TBPlugin): memory_metric = request.args.get('memory_metric', 'KB') if profile.device_target == 'Ascend': operator_memory_events = profile.memory_events['operator']['rows'] - start_ts = int(start_ts) if start_ts is not None else 0 - end_ts = int(end_ts) if end_ts is not None else float('inf') + if start_ts is not None: + start_ts = int(start_ts) + if end_ts is not None: + end_ts = int(end_ts) for key in operator_memory_events: - operator_memory_events[key] = [i for i in operator_memory_events[key] if start_ts <= i[2] <= end_ts] + if start_ts is not None and end_ts is not None: + operator_memory_events[key] = [i for i in operator_memory_events[key] if + i[2] and start_ts <= i[2] <= end_ts] + elif start_ts is not None: + operator_memory_events[key] = [i for i in operator_memory_events[key] if + i[2] and start_ts <= i[2]] + elif end_ts is not None: + operator_memory_events[key] = [i for i in operator_memory_events[key] if + i[2] and end_ts >= i[2]] return self.respond_as_json(profile.memory_events, True) else: if start_ts is not None: diff --git a/tb_plugins/profiling/tb_plugin/torch_tb_profiler/profiler/run_generator.py b/tb_plugins/profiling/tb_plugin/torch_tb_profiler/profiler/run_generator.py index 7fd44d6e6e8d61c0e421b0a852ce6cff1319256a..846c6db42940f3e94f1b19203a2d25bb8e5d22af 100644 --- a/tb_plugins/profiling/tb_plugin/torch_tb_profiler/profiler/run_generator.py +++ b/tb_plugins/profiling/tb_plugin/torch_tb_profiler/profiler/run_generator.py @@ -119,8 +119,12 @@ class RunGenerator(object): if len(datas) <= 1: return operator_by_name, operator_by_name_and_input_shapes for ls in datas[1:]: - temp: list = [ls[0], RunGenerator._trans_shape(str(ls[1])), ls[2], float(ls[3]), float(ls[4]), - float(ls[5]), float(ls[6]), float(ls[7]), float(ls[8])] + try: + temp: list = [ls[0], RunGenerator._trans_shape(str(ls[1])), ls[2], float(ls[3]), float(ls[4]), + float(ls[5]), float(ls[6]), float(ls[7]), float(ls[8])] + except (ValueError, IndexError): + logger.error('Data in file "operator_details.csv" has wrong format.') + return operator_by_name, operator_by_name_and_input_shapes operator_by_name[ls[0]].append(temp) key = "{}###{}".format(str(ls[0]), RunGenerator._trans_shape(str(ls[1]))) operator_by_name_and_input_shapes[key].append(temp) @@ -279,7 +283,7 @@ class RunGenerator(object): return temp def _get_memory_event(self, peak_memory_events: dict): - display_columns = ('Operator', 'Size(KB)', 'Allocation Time(us)', 'Release Time(us)', 'Duration(us)') + display_columns = ('Name', 'Size(KB)', 'Allocation Time(us)', 'Release Time(us)', 'Duration(us)') path = self.profile_data.memory_operator_path display_datas = defaultdict(list) devices_type = [] @@ -296,7 +300,7 @@ class RunGenerator(object): if column == 'Device Type': self.device_type_form_idx = idx if column in display_columns: - if column == 'Operator': + if column == 'Name': table['columns'].append({'name': column, 'type': 'string'}) elif column == 'Size(KB)': table['columns'].append({'name': column, 'type': 'number'}) @@ -306,12 +310,11 @@ class RunGenerator(object): for ls in datas[1:]: device_type = ls[self.device_type_form_idx] # convert time metric 'us' to 'ms' - nums = [ls[0], float(ls[1]), round((float(ls[2]) - self.profile_data.start_ts) / 1000, 3)] - # some operators may not have column[3] or column[4] - if ls[3]: - nums.append(round((float(ls[3]) - self.profile_data.start_ts) / 1000, 3)) - if ls[4]: - nums.append(round(float(ls[4]) / 1000, 2)) + # some operators may not have the following columns + nums = [ls[0] if ls[0] else '', abs(float(ls[1])), + round((float(ls[2]) - self.profile_data.start_ts) / 1000, 2) if ls[2] else None, + round((float(ls[3]) - self.profile_data.start_ts) / 1000, 2) if ls[3] else None, + round(float(ls[4]) / 1000, 2) if ls[4] else None] display_datas[device_type].append(nums) table['rows'] = display_datas for name in display_datas: