diff --git a/torch_npu/csrc/profiler/init.cpp b/torch_npu/csrc/profiler/init.cpp
index ce73aef70366436f53f5291663a6c3e0b08e339a..5880c8bf4274612252c52aab40200e4962d5d3b8 100644
--- a/torch_npu/csrc/profiler/init.cpp
+++ b/torch_npu/csrc/profiler/init.cpp
@@ -95,7 +95,9 @@ PyObject* profiler_initExtension(PyObject* _unused, PyObject *unused)
         py::arg("config"),
         py::arg("activities"),
         py::arg("scopes") = std::unordered_set<at::RecordScope>());
+    m.def("_start_cpu_profiler", &startCpuProfiler, py::arg("config"));
     m.def("_stop_profiler", stopNpuProfiler);
+    m.def("_stop_cpu_profiler", stopCpuProfiler);
     m.def("_finalize_profiler", finalizeNpuProfiler);
     m.def("_get_freq", at_npu::native::getFreq);
     m.def("_get_syscnt_enable", at_npu::native::isSyscntEnable);
diff --git a/torch_npu/csrc/profiler/npu_profiler.cpp b/torch_npu/csrc/profiler/npu_profiler.cpp
index 295eda9aea1f2a425c21caaf037e4f60713a463e..f805fabbca4a9e3a340af9acb66c460477693370 100644
--- a/torch_npu/csrc/profiler/npu_profiler.cpp
+++ b/torch_npu/csrc/profiler/npu_profiler.cpp
@@ -315,6 +315,21 @@ void warmupNpuProfiler(const NpuProfilerConfig &config,
     ProfilerMgr::GetInstance()->Warmup(npu_config, cpu_trace);
 }
 
+void startCpuProfiler(const NpuProfilerConfig &config)
+{
+    std::set<NpuActivityType> activities{NpuActivityType::CPU};
+    auto state = std::make_shared<NpuProfilerThreadLocalState>(config, activities);
+    if (c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE) != nullptr) {
+        ASCEND_LOGE("Ascend Pytorch CPU Profiler is already enabled.");
+        return;
+    }
+    c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
+    registerCallback({});
+    if (state->tracePython()) {
+        python_tracer::call(python_tracer::Command::kStartOne);
+    }
+}
+
 void startNpuProfiler(const NpuProfilerConfig &config,
     const std::set<NpuActivityType> &activities,
     const std::unordered_set<at::RecordScope> &scopes)
@@ -340,6 +355,19 @@ void startNpuProfiler(const NpuProfilerConfig &config,
     }
 }
 
+void stopCpuProfiler()
+{
+    auto state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
+    auto state_ptr = static_cast<NpuProfilerThreadLocalState*>(state.get());
+    if (state_ptr == nullptr) {
+        ASCEND_LOGE("Can't disable Ascend Pytorch CPU Profiler when it's not running.");
+        return;
+    }
+    if (state_ptr->tracePython()) {
+        python_tracer::call(python_tracer::Command::kStopOne);
+    }
+}
+
 void stopNpuProfiler()
 {
     auto state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
diff --git a/torch_npu/csrc/profiler/npu_profiler.h b/torch_npu/csrc/profiler/npu_profiler.h
index 2127825bc134e0b49178fe00890cdff58011e62c..e6ab8a7244b9c5e8bb3a5e08ec002d4b75f4d943 100644
--- a/torch_npu/csrc/profiler/npu_profiler.h
+++ b/torch_npu/csrc/profiler/npu_profiler.h
@@ -15,7 +15,7 @@ namespace torch_npu {
 namespace profiler {
 
 namespace python_tracer {
-enum class Command { kStartOne = 0, kStartAll, kStop, kClear };
+enum class Command { kStartOne = 0, kStartAll, kStop, kStopOne, kClear };
 using CallFn = void (*)(Command);
 void registerFunctions(CallFn call);
 } // python_tracer
@@ -123,8 +123,12 @@ void initNpuProfiler(const std::string &path, const std::set<NpuActivityType> &a
 
 void warmupNpuProfiler(const NpuProfilerConfig &config, const std::set<NpuActivityType> &activities);
 
+void startCpuProfiler(const NpuProfilerConfig &config);
+
 void startNpuProfiler(const NpuProfilerConfig &config, const std::set<NpuActivityType> &activities, const std::unordered_set<at::RecordScope> &scops = {});
 
+void stopCpuProfiler();
+
 void stopNpuProfiler();
 
 void finalizeNpuProfiler();
diff --git a/torch_npu/csrc/profiler/profiler_python.cpp b/torch_npu/csrc/profiler/profiler_python.cpp
index 571fb57bbce2c18e24ac6cd26aaedcd656968bd1..324ad16b6eaedf3b2ee85d9cfcaf6467a31ce85b 100644
--- a/torch_npu/csrc/profiler/profiler_python.cpp
+++ b/torch_npu/csrc/profiler/profiler_python.cpp
@@ -31,6 +31,7 @@ const std::string EXIT_EVENT_DESC = "__torch_npu_profiler_python_tracer_exit";
 const size_t EXIT_EVENT_HASH_ID = c10::get_hash(EXIT_EVENT_DESC); // Special hash key for exit event
 const std::string MODULE_NAME_DELIMITER = "######";
 constexpr size_t TRACE_DUMP_THRESHOLD = 1024 * DEFAULT_BLOCK_SIZE;
+constexpr size_t STACK_MAX_DEPTH = 128;
 
 using TensorMetadata = torch_npu::toolkit::profiler::TensorMetadata;
 using ModuleParam = torch_npu::toolkit::profiler::ModuleParam;
@@ -217,7 +218,9 @@
 private:
     static PythonTracer& singleton();
     void start(size_t max_threads = max_py_threads);
+    void startOne();
     void stop();
+    void stopOne();
     void clear();
     size_t genPyCallHashId(PyFrameObject* frame);
     void recordPyCall(TraceContext* ctx, PyFrameObject* frame);
@@ -301,7 +304,6 @@ void PythonTracer::start(size_t max_threads)
         thread_states.resize(max_threads);
     }
 
-    const size_t STACK_MAX_DEPTH = 128;
     // Register the tracer in each thread.
     for (const auto thread_state : thread_states) {
         PyThreadState_Swap(thread_state);
@@ -331,6 +333,34 @@
     }
 }
 
+void PythonTracer::startOne()
+{
+    if (!active_.load()) {
+        return;
+    }
+    GilAndRestoreThread gil;
+    auto thread_state = gil.initial_thread_state();
+    if (thread_state && thread_state->c_profilefunc == nullptr) {
+        PyThreadState_Swap(thread_state);
+        thread_local_results_.emplace_back(this);
+        auto* ctx = thread_local_results_.back().ctx_;
+
+        std::vector<PyFrameObject*> current_stack;
+        auto frame = PyEval_GetFrame_NPU();
+        size_t depth = 0;  // Make sure we can't infinite loop.
+        while (frame != nullptr && depth <= STACK_MAX_DEPTH) {
+            current_stack.emplace_back(frame);
+            frame = PyFrame_GetBack(frame);
+            ++depth;
+        }
+        // record py call before profiler start
+        for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
+            start_py_call_info_[reinterpret_cast<uintptr_t>(ctx)].emplace_back(genPyCallHashId(*it));
+        }
+        PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
+    }
+}
+
 void PythonTracer::stop()
 {
     TORCH_INTERNAL_ASSERT(active_.load(), "PythonTracer is not running.", PROF_ERROR(ErrCode::INTERNAL));
@@ -360,6 +390,19 @@
     reportParamData();
 }
 
+void PythonTracer::stopOne()
+{
+    if (!active_.load()) {
+        return;
+    }
+    GilAndRestoreThread gil;
+    auto thread_state = gil.initial_thread_state();
+    if (thread_state && thread_state->c_profilefunc == &PythonTracer::pyProfileFn) {
+        PyThreadState_Swap(thread_state);
+        PyEval_SetProfile(nullptr, nullptr);
+    }
+}
+
 void PythonTracer::clear()
 {
     TORCH_CHECK(!active_.load(), "Cannot clear state while PythonTracer is active.", PROF_ERROR(ErrCode::INTERNAL));
@@ -591,24 +634,27 @@ int PythonTracer::pyProfileFn(
     PyObject* arg)
 {
     auto ctx = reinterpret_cast<TraceContext*>(obj);
-    auto thread_local_result = ctx->thread_local_result_;
+    auto active_tracer = ctx->thread_local_result_->active_tracer_;
+    if (active_tracer == nullptr || !active_tracer->active_.load(std::memory_order_relaxed)) {
+        return 0;
+    }
     switch (what) {
         case PyTrace_CALL:
-            thread_local_result->active_tracer_->recordPyCall(ctx, frame);
+            active_tracer->recordPyCall(ctx, frame);
             break;
 
         case PyTrace_C_CALL:
-            thread_local_result->active_tracer_->recordCCall(ctx, frame, arg);
+            active_tracer->recordCCall(ctx, frame, arg);
             break;
 
         case PyTrace_EXCEPTION:
         case PyTrace_RETURN:
-            thread_local_result->active_tracer_->recordReturn(ctx, frame, TraceTag::kPy_Return);
+            active_tracer->recordReturn(ctx, frame, TraceTag::kPy_Return);
             break;
 
         case PyTrace_C_EXCEPTION:
         case PyTrace_C_RETURN:
-            thread_local_result->active_tracer_->recordReturn(ctx, frame, TraceTag::kC_Return);
+            active_tracer->recordReturn(ctx, frame, TraceTag::kC_Return);
             break;
 
         default:
@@ -621,7 +667,7 @@ void PythonTracer::call(Command c)
 {
     switch (c) {
         case Command::kStartOne:
-            PythonTracer::singleton().start(1);
+            PythonTracer::singleton().startOne();
             break;
 
         case Command::kStartAll:
@@ -632,6 +678,10 @@
             PythonTracer::singleton().stop();
             break;
 
+        case Command::kStopOne:
+            PythonTracer::singleton().stopOne();
+            break;
+
         case Command::kClear:
             PythonTracer::singleton().clear();
             break;
diff --git a/torch_npu/profiler/profiler.py b/torch_npu/profiler/profiler.py
index 250e528be4355c469b327f6d9284d5f938951044..09734f1f360d3b0d938d01b7e626082c5e452393 100644
--- a/torch_npu/profiler/profiler.py
+++ b/torch_npu/profiler/profiler.py
@@ -5,7 +5,13 @@ from typing import Optional, Iterable, Callable, Any, Union
 import torch.autograd.profiler as prof
 
 import torch_npu.npu
-from torch_npu._C._profiler import ProfilerActivity
+from torch_npu._C._profiler import (
+    _start_cpu_profiler,
+    _stop_cpu_profiler,
+    _ExperimentalConfig as C_ExperimentalConfig,
+    ProfilerActivity,
+    NpuProfilerConfig
+)
 from torch_npu.utils._error_code import ErrCode, prof_error
 
 from .experimental_config import _ExperimentalConfig
@@ -281,6 +287,23 @@
         self.step_rec_fn = prof.record_function("ProfilerStep#" + str(self.step_num + self._step_num_offset))
         self.step_rec_fn.__enter__()
 
+    @classmethod
+    @no_exception_func()
+    def start_cpu_profiler(cls,
+                           record_shapes: bool = False,
+                           profile_memory: bool = False,
+                           with_stack: bool = False,
+                           with_flops: bool = False,
+                           with_modules: bool = False):
+        npu_prof_config = NpuProfilerConfig('', record_shapes, profile_memory, with_stack, with_flops, with_modules, C_ExperimentalConfig())
+        _start_cpu_profiler(npu_prof_config)
+
+    @classmethod
+    @no_exception_func()
+    def stop_cpu_profiler(cls):
+        torch_npu.npu.synchronize()
+        _stop_cpu_profiler()
+
 
 @no_exception_func()
 def analyse(profiler_path: str, max_process_number: int = Constant.DEFAULT_PROCESS_NUMBER,
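
Usage sketch for the new entry points, assuming the profile class patched above is importable from torch_npu.profiler (its usual export); workload() below is a hypothetical placeholder, not part of the patch:

    import torch
    import torch_npu
    from torch_npu.profiler import profile

    def workload():
        # Hypothetical stand-in for real model code.
        x = torch.randn(128, 128)
        return torch.matmul(x, x)

    # Begin standalone CPU-only collection; the keyword flags mirror the
    # start_cpu_profiler() parameters added in profiler.py (all default False).
    profile.start_cpu_profiler(record_shapes=True, with_stack=True)
    workload()
    # stop_cpu_profiler() synchronizes the NPU, pops the profiler state and
    # detaches the per-thread Python tracer via Command::kStopOne.
    profile.stop_cpu_profiler()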