diff --git a/torch_npu/profiler/dynamic_profile.py b/torch_npu/profiler/dynamic_profile.py index 313fe7d388f078fed4a63ffd668f531136a132a7..5a89a0afd2e692bff680728ae0e9c7f97d282889 100644 --- a/torch_npu/profiler/dynamic_profile.py +++ b/torch_npu/profiler/dynamic_profile.py @@ -79,10 +79,9 @@ class _DynamicProfile: elif self.cur_step - self.RECORD_TIME_STEP == 1: self._step_time = max(self._min_poll_interval, int(time.time() - self._step_record_time)) self._dynamic_monitor.modify_step_time(self._step_time) + if self._step_mstx_range_id: + mstx.range_end(self._step_mstx_range_id) if self.prof: - if self._step_mstx_range_id: - mstx.range_end(self._step_mstx_range_id) - self._step_mstx_range_id = mstx.range_start(f"step {self.cur_step}", current_stream()) self.prof.step() self.step_num -= 1 if 0 == self.step_num: @@ -94,6 +93,8 @@ class _DynamicProfile: self.step_num = self.cfg_ctx.active() + self.cfg_ctx.warmup() self.enable_prof() self.cfg_ctx = None + if self.prof is None: + self._step_mstx_range_id = mstx.range_start(f"step {self.cur_step}", current_stream()) def start(self, config_path: str): if self.prof: