From acd67c7ae6cf9c984541d89d46a5fe46ec2d3feb Mon Sep 17 00:00:00 2001
From: hehongzhe <935062458@qq.com>
Date: Sun, 17 Aug 2025 16:44:59 +0800
Subject: [PATCH] profiler: support profiling in child threads

---
 test/profiler/test_npu_profiler.py          | 29 +++++++--
 test/torch_npu_schema.json                  | 12 ++++
 torch_npu/csrc/profiler/init.cpp            |  2 +
 torch_npu/csrc/profiler/npu_profiler.cpp    | 28 +++++++++
 torch_npu/csrc/profiler/npu_profiler.h      |  6 +-
 torch_npu/csrc/profiler/profiler_python.cpp | 70 ++++++++++++++++++---
 torch_npu/profiler/profiler.py              | 38 ++++++++++-
 7 files changed, 172 insertions(+), 13 deletions(-)

diff --git a/test/profiler/test_npu_profiler.py b/test/profiler/test_npu_profiler.py
index 035adcf7e00..7183a7229db 100644
--- a/test/profiler/test_npu_profiler.py
+++ b/test/profiler/test_npu_profiler.py
@@ -1,7 +1,7 @@
 import unittest
 import os
 import json
-
+import threading
 import torch

 import torch_npu
@@ -34,15 +34,22 @@ class TrainModel:
         self.criterion = torch.nn.MSELoss()
         self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001)

-    def train_one_step(self):
-        inputs = torch.rand(self.input_shape, requires_grad=True).to(self.device)
+    def train_one_step(self, device_id: int = 0, child_thread: bool = False):
+        device = self.device
+        if child_thread:
+            device = torch.device(f"npu:{device_id}")
+            torch.npu.set_device(device_id)
+            torch_npu.profiler.profile.enable_profiler_in_child_thread()
+        inputs = torch.rand(self.input_shape, requires_grad=True).to(device)
         inputs.register_hook(lambda grad: print("tersor backward hook"))
-        target = torch.rand(self.out_shape).reshape(self.out_shape[0], -1).to(self.device)
+        target = torch.rand(self.out_shape).reshape(self.out_shape[0], -1).to(device)
         output = self.model(inputs)
         loss = self.criterion(output, target)
         self.optimizer.zero_grad()
         loss.backward()
         self.optimizer.step()
+        if child_thread:
+            torch_npu.profiler.profile.disable_profiler_in_child_thread()


 class TestNpuProfiler(TestCase):
@@ -204,6 +211,20 @@ class TestNpuProfiler(TestCase):
             else:
                 os.environ["TASK_QUEUE_ENABLE"] = original_value

+    def test_single_process_multiple_devices_with_child_thread(self):
+        worker_name = self.worker_name
+        with torch_npu.profiler.profile(
+                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(self.results_path, worker_name=worker_name)
+        ) as prof:
+            t = threading.Thread(target=self.model_train.train_one_step, args=(1, True))
+            t.start()
+            for _ in range(self.small_steps):
+                self.model_train.train_one_step()
+            t.join()
+        self.assertEqual(True, self._has_view_result(self.results_path, worker_name, self.TRACE_FILE_NAME))
+        self.assertEqual(True, self._has_view_result(self.results_path, worker_name, self.KERNEL_FILE_NAME))
+        self.assertEqual(True, self._has_view_result(self.results_path, worker_name, self.OPERATOR_FILE_NAME))
+
     def test_ascend_work_path(self):
         PathManager.remove_path_safety(self.results_work_path)
         os.environ["ASCEND_WORK_PATH"] = self.results_work_path
diff --git a/test/torch_npu_schema.json b/test/torch_npu_schema.json
index 719d6a1ea31..30cb3f2c16b 100644
--- a/test/torch_npu_schema.json
+++ b/test/torch_npu_schema.json
@@ -2174,6 +2174,12 @@
     "torch_npu.profiler.profile.step": {
         "signature": "(self)"
     },
+    "torch_npu.profiler.profile.enable_profiler_in_child_thread": {
+        "signature": "(record_shapes: bool = False, profile_memory: bool = False, with_stack: bool = False, with_flops: bool = False, with_modules: bool = False)"
+    },
+    "torch_npu.profiler.profile.disable_profiler_in_child_thread": {
+        "signature": "()"
+    },
     "torch_npu.profiler.schedule": {
         "signature": "(wait: int, active: int, warmup: int = 0, repeat: int = 0, skip_first: int = 0) -> None"
     },
@@ -2291,6 +2297,12 @@
     "torch_npu.profiler.profiler.profile.step": {
         "signature": "(self)"
     },
+    "torch_npu.profiler.profiler.profile.enable_profiler_in_child_thread": {
+        "signature": "(record_shapes: bool = False, profile_memory: bool = False, with_stack: bool = False, with_flops: bool = False, with_modules: bool = False)"
+    },
+    "torch_npu.profiler.profiler.profile.disable_profiler_in_child_thread": {
+        "signature": "()"
+    },
     "torch_npu.profiler.profiler.supported_activities": {
         "signature": "()"
     },
diff --git a/torch_npu/csrc/profiler/init.cpp b/torch_npu/csrc/profiler/init.cpp
index ce73aef7036..7313fcf1b20 100644
--- a/torch_npu/csrc/profiler/init.cpp
+++ b/torch_npu/csrc/profiler/init.cpp
@@ -95,7 +95,9 @@ PyObject* profiler_initExtension(PyObject* _unused, PyObject *unused)
           py::arg("config"),
           py::arg("activities"),
           py::arg("scopes") = std::unordered_set<at::RecordScope>());
+    m.def("_enable_profiler_in_child_thread", &enableProfilerInChildThread, py::arg("config"));
     m.def("_stop_profiler", stopNpuProfiler);
+    m.def("_disable_profiler_in_child_thread", disableProfilerInChildThread);
     m.def("_finalize_profiler", finalizeNpuProfiler);
     m.def("_get_freq", at_npu::native::getFreq);
     m.def("_get_syscnt_enable", at_npu::native::isSyscntEnable);
diff --git a/torch_npu/csrc/profiler/npu_profiler.cpp b/torch_npu/csrc/profiler/npu_profiler.cpp
index 1aae58c4349..dbe33b75e97 100644
--- a/torch_npu/csrc/profiler/npu_profiler.cpp
+++ b/torch_npu/csrc/profiler/npu_profiler.cpp
@@ -316,6 +316,21 @@ void warmupNpuProfiler(const NpuProfilerConfig &config,
     ProfilerMgr::GetInstance()->Warmup(npu_config, cpu_trace);
 }

+void enableProfilerInChildThread(const NpuProfilerConfig &config)
+{
+    std::set<NpuActivityType> activities{NpuActivityType::CPU};
+    auto state = std::make_shared(config, activities);
+    if (c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE) != nullptr) {
+        ASCEND_LOGE("Ascend Pytorch CPU Profiler is already enabled.");
+        return;
+    }
+    c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
+    registerCallback({});
+    if (state->tracePython()) {
+        python_tracer::call(python_tracer::Command::kStartOne);
+    }
+}
+
 void startNpuProfiler(const NpuProfilerConfig &config,
     const std::set<NpuActivityType> &activities,
     const std::unordered_set<at::RecordScope> &scopes)
@@ -341,6 +356,19 @@ void startNpuProfiler(const NpuProfilerConfig &config,
     }
 }

+void disableProfilerInChildThread()
+{
+    auto state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
+    auto state_ptr = static_cast(state.get());
+    if (state_ptr == nullptr) {
+        ASCEND_LOGE("Can't disable Ascend Pytorch CPU Profiler when it's not running.");
+        return;
+    }
+    if (state_ptr->tracePython()) {
+        python_tracer::call(python_tracer::Command::kStopOne);
+    }
+}
+
 void stopNpuProfiler()
 {
     auto state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
diff --git a/torch_npu/csrc/profiler/npu_profiler.h b/torch_npu/csrc/profiler/npu_profiler.h
index 854191dfb76..1d838a7fe29 100644
--- a/torch_npu/csrc/profiler/npu_profiler.h
+++ b/torch_npu/csrc/profiler/npu_profiler.h
@@ -18,7 +18,7 @@ namespace torch_npu {
 namespace profiler {

 namespace python_tracer {
-enum class Command { kStartOne = 0, kStartAll, kStop, kClear };
+enum class Command { kStartOne = 0, kStartAll, kStop, kStopOne, kClear };
 using CallFn = void (*)(Command);
 void registerFunctions(CallFn call);
 } // python_tracer
@@ -126,8 +126,12 @@ void initNpuProfiler(const std::string &path, const std::set<NpuActivityType> &a
 void warmupNpuProfiler(const NpuProfilerConfig &config, const std::set<NpuActivityType> &activities);

+void enableProfilerInChildThread(const NpuProfilerConfig &config);
+
 void startNpuProfiler(const NpuProfilerConfig &config,
     const std::set<NpuActivityType> &activities,
     const std::unordered_set<at::RecordScope> &scops = {});

+void disableProfilerInChildThread();
+
 void stopNpuProfiler();
 void finalizeNpuProfiler();
diff --git a/torch_npu/csrc/profiler/profiler_python.cpp b/torch_npu/csrc/profiler/profiler_python.cpp
index 571fb57bbce..cba33f749d4 100644
--- a/torch_npu/csrc/profiler/profiler_python.cpp
+++ b/torch_npu/csrc/profiler/profiler_python.cpp
@@ -31,6 +31,7 @@ const std::string EXIT_EVENT_DESC = "__torch_npu_profiler_python_tracer_exit";
 const size_t EXIT_EVENT_HASH_ID = c10::get_hash(EXIT_EVENT_DESC); // Special hash key for exit event
 const std::string MODULE_NAME_DELIMITER = "######";
 constexpr size_t TRACE_DUMP_THRESHOLD = 1024 * DEFAULT_BLOCK_SIZE;
+constexpr size_t STACK_MAX_DEPTH = 128;

 using TensorMetadata = torch_npu::toolkit::profiler::TensorMetadata;
 using ModuleParam = torch_npu::toolkit::profiler::ModuleParam;
@@ -217,7 +218,9 @@ private:

     static PythonTracer& singleton();
     void start(size_t max_threads = max_py_threads);
+    void startOne();
     void stop();
+    void stopOne();
     void clear();
     size_t genPyCallHashId(PyFrameObject* frame);
     void recordPyCall(TraceContext* ctx, PyFrameObject* frame);
@@ -301,7 +304,6 @@ void PythonTracer::start(size_t max_threads)
         thread_states.resize(max_threads);
     }

-    const size_t STACK_MAX_DEPTH = 128;
     // Register the tracer in each thread.
     for (const auto thread_state : thread_states) {
         PyThreadState_Swap(thread_state);
@@ -331,6 +333,34 @@ void PythonTracer::start(size_t max_threads)
     }
 }

+void PythonTracer::startOne()
+{
+    if (!active_.load()) {
+        return;
+    }
+    GilAndRestoreThread gil;
+    auto thread_state = gil.initial_thread_state();
+    if (thread_state && thread_state->c_profilefunc == nullptr) {
+        PyThreadState_Swap(thread_state);
+        thread_local_results_.emplace_back(this);
+        auto* ctx = thread_local_results_.back().ctx_;
+
+        std::vector<PyFrameObject*> current_stack;
+        auto frame = PyEval_GetFrame_NPU();
+        size_t depth = 0; // Make sure we can't infinite loop.
+        while (frame != nullptr && depth <= STACK_MAX_DEPTH) {
+            current_stack.emplace_back(frame);
+            frame = PyFrame_GetBack(frame);
+            ++depth;
+        }
+        // record py call before profiler start
+        for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
+            start_py_call_info_[reinterpret_cast(ctx)].emplace_back(genPyCallHashId(*it));
+        }
+        PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
+    }
+}
+
 void PythonTracer::stop()
 {
     TORCH_INTERNAL_ASSERT(active_.load(), "PythonTracer is not running.", PROF_ERROR(ErrCode::INTERNAL));
@@ -360,6 +390,19 @@ void PythonTracer::stop()
     reportParamData();
 }

+void PythonTracer::stopOne()
+{
+    if (!active_.load()) {
+        return;
+    }
+    GilAndRestoreThread gil;
+    auto thread_state = gil.initial_thread_state();
+    if (thread_state && thread_state->c_profilefunc == &PythonTracer::pyProfileFn) {
+        PyThreadState_Swap(thread_state);
+        PyEval_SetProfile(nullptr, nullptr);
+    }
+}
+
 void PythonTracer::clear()
 {
     TORCH_CHECK(!active_.load(), "Cannot clear state while PythonTracer is active.", PROF_ERROR(ErrCode::INTERNAL));
@@ -591,24 +634,33 @@ int PythonTracer::pyProfileFn(
     PyObject* arg)
 {
     auto ctx = reinterpret_cast<TraceContext*>(obj);
-    auto thread_local_result = ctx->thread_local_result_;
+    if (ctx == nullptr) {
+        return 0;
+    }
+    if (ctx->thread_local_result_ == nullptr) {
+        return 0;
+    }
+    auto active_tracer = ctx->thread_local_result_->active_tracer_;
+    if (active_tracer == nullptr || !active_tracer->active_.load(std::memory_order_relaxed)) {
+        return 0;
+    }
     switch (what) {
         case PyTrace_CALL:
-            thread_local_result->active_tracer_->recordPyCall(ctx, frame);
+            active_tracer->recordPyCall(ctx, frame);
             break;

         case PyTrace_C_CALL:
-            thread_local_result->active_tracer_->recordCCall(ctx, frame, arg);
+            active_tracer->recordCCall(ctx, frame, arg);
             break;

         case PyTrace_EXCEPTION:
         case PyTrace_RETURN:
-            thread_local_result->active_tracer_->recordReturn(ctx, frame, TraceTag::kPy_Return);
+            active_tracer->recordReturn(ctx, frame, TraceTag::kPy_Return);
             break;

         case PyTrace_C_EXCEPTION:
         case PyTrace_C_RETURN:
-            thread_local_result->active_tracer_->recordReturn(ctx, frame, TraceTag::kC_Return);
+            active_tracer->recordReturn(ctx, frame, TraceTag::kC_Return);
             break;

         default:
@@ -621,7 +673,7 @@ void PythonTracer::call(Command c)
 {
     switch (c) {
         case Command::kStartOne:
-            PythonTracer::singleton().start(1);
+            PythonTracer::singleton().startOne();
             break;

         case Command::kStartAll:
@@ -632,6 +684,10 @@
             PythonTracer::singleton().stop();
             break;

+        case Command::kStopOne:
+            PythonTracer::singleton().stopOne();
+            break;
+
         case Command::kClear:
             PythonTracer::singleton().clear();
             break;
diff --git a/torch_npu/profiler/profiler.py b/torch_npu/profiler/profiler.py
index 250e528be43..84fc576ff33 100644
--- a/torch_npu/profiler/profiler.py
+++ b/torch_npu/profiler/profiler.py
@@ -5,7 +5,13 @@ from typing import Optional, Iterable, Callable, Any, Union
 import torch.autograd.profiler as prof

 import torch_npu.npu
-from torch_npu._C._profiler import ProfilerActivity
+from torch_npu._C._profiler import (
+    _enable_profiler_in_child_thread,
+    _disable_profiler_in_child_thread,
+    _ExperimentalConfig as C_ExperimentalConfig,
+    ProfilerActivity,
+    NpuProfilerConfig
+)
 from torch_npu.utils._error_code import ErrCode, prof_error

 from .experimental_config import _ExperimentalConfig
@@ -281,6 +287,36 @@ class profile(_KinetoProfile):
             self.step_rec_fn = prof.record_function("ProfilerStep#" + str(self.step_num + self._step_num_offset))
             self.step_rec_fn.__enter__()
 
+    @classmethod
+    @no_exception_func()
+    def enable_profiler_in_child_thread(cls,
+                                        record_shapes: bool = False,
+                                        profile_memory: bool = False,
+                                        with_stack: bool = False,
+                                        with_flops: bool = False,
+                                        with_modules: bool = False):
+        params = {
+            'record_shapes': record_shapes,
+            'profile_memory': profile_memory,
+            'with_stack': with_stack,
+            'with_flops': with_flops,
+            'with_modules': with_modules
+        }
+        for param_name, param_value in params.items():
+            if not isinstance(param_value, bool):
+                print_warn_msg(f"{param_name} in enable_profiler_in_child_thread is not bool, reset it to False.")
+                params[param_name] = False
+        npu_prof_config = NpuProfilerConfig('', params['record_shapes'], params['profile_memory'],
+                                            params['with_stack'], params['with_flops'], params['with_modules'],
+                                            C_ExperimentalConfig())
+        _enable_profiler_in_child_thread(npu_prof_config)
+
+    @classmethod
+    @no_exception_func()
+    def disable_profiler_in_child_thread(cls):
+        torch_npu.npu.synchronize()
+        _disable_profiler_in_child_thread()
+

 @no_exception_func()
 def analyse(profiler_path: str, max_process_number: int = Constant.DEFAULT_PROCESS_NUMBER,
-- 
Gitee
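
Reviewer note: a minimal usage sketch of the new API, distilled from the new test case
test_single_process_multiple_devices_with_child_thread above. The device id (npu:1), the
result directory, the child_step helper, and the toy matmul workload are illustrative
assumptions and are not part of the patch; only the profile context manager and the two
new class methods come from the change itself. The child thread must call
enable_profiler_in_child_thread() after the parent thread has entered the profile context,
and disable_profiler_in_child_thread() before it exits.

    import threading

    import torch
    import torch_npu


    def child_step(device_id=1):
        # Hypothetical worker: bind this thread to its own NPU device, then opt the
        # thread into the CPU-side profiling session owned by the parent thread.
        torch.npu.set_device(device_id)
        torch_npu.profiler.profile.enable_profiler_in_child_thread()
        x = torch.rand(32, 32).to(f"npu:{device_id}")
        y = x @ x  # work recorded for this thread
        # Detach this thread's tracer before the thread exits.
        torch_npu.profiler.profile.disable_profiler_in_child_thread()


    with torch_npu.profiler.profile(
        on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./profiler_result")
    ) as prof:
        t = threading.Thread(target=child_step, args=(1,))
        t.start()
        # Main-thread work is profiled as before.
        z = torch.rand(32, 32).npu() @ torch.rand(32, 32).npu()
        t.join()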