From c9877697a81d755d3579d9dd7581e9bd7853733e Mon Sep 17 00:00:00 2001 From: shibo19 Date: Fri, 21 Jan 2022 18:06:54 +0800 Subject: [PATCH 1/6] =?UTF-8?q?=E5=A2=9E=E5=8A=A0profiler=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=EF=BC=9Astep=201?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/__init__.py | 1 + torch_npu/csrc/InitNpuBindings.cpp | 3 +- torch_npu/csrc/profiler/init.cpp | 120 ++++++++++ torch_npu/csrc/profiler/init.h | 26 +++ torch_npu/csrc/profiler/profiler.h | 22 ++ torch_npu/csrc/profiler/profiler_npu.cpp | 93 ++++++++ torch_npu/csrc/profiler/utils.cpp | 268 +++++++++++++++++++++++ torch_npu/csrc/profiler/utils.h | 37 ++++ 8 files changed, 569 insertions(+), 1 deletion(-) create mode 100644 torch_npu/csrc/profiler/init.cpp create mode 100644 torch_npu/csrc/profiler/init.h create mode 100644 torch_npu/csrc/profiler/profiler.h create mode 100644 torch_npu/csrc/profiler/profiler_npu.cpp create mode 100644 torch_npu/csrc/profiler/utils.cpp create mode 100644 torch_npu/csrc/profiler/utils.h diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index ddd56469c2..8a9fa1cc12 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -39,6 +39,7 @@ def _apply_patches(): monkey_patches = [ ["npu", torch_npu.npu], ["npu.amp", torch_npu.npu.amp], + ["autograd.profiler", torch_npu.npu.profiler], ["distributed", torch_npu.distributed], ["distributed.distributed_c10d", torch_npu.distributed.distributed_c10d], ["nn.parallel.distributed._get_default_group", torch_npu.distributed.distributed_c10d._get_default_group] diff --git a/torch_npu/csrc/InitNpuBindings.cpp b/torch_npu/csrc/InitNpuBindings.cpp index df5fca05f5..9dd12dd179 100644 --- a/torch_npu/csrc/InitNpuBindings.cpp +++ b/torch_npu/csrc/InitNpuBindings.cpp @@ -24,7 +24,7 @@ #include "torch_npu/csrc/framework/allocator/THNPUCachingHostAllocator.h" #include "torch_npu/csrc/npu/Event.h" #include "torch_npu/csrc/distributed/Init.h" - +#include "torch_npu/csrc/profiler/init.h" PyObject* module; @@ -90,6 +90,7 @@ PyObject* initModule(){ AddPyMethodDefs(methods, TorchNpuMethods); AddPyMethodDefs(methods, THNPModule_get_methods()); + AddPyMethodDefs(methods, torch_npu::profiler::profiler_functions()); AddPyMethodDefs(methods, torch_npu::distributed::python_functions()); static struct PyModuleDef torchnpu_module = { PyModuleDef_HEAD_INIT, diff --git a/torch_npu/csrc/profiler/init.cpp b/torch_npu/csrc/profiler/init.cpp new file mode 100644 index 0000000000..bc079af3a0 --- /dev/null +++ b/torch_npu/csrc/profiler/init.cpp @@ -0,0 +1,120 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
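A note on the registration pattern visible in the InitNpuBindings.cpp hunk above: each component (npu, distributed, and now the profiler) exposes a null-terminated PyMethodDef table, and AddPyMethodDefs appends that table to the module's combined method list before the module is created. The helper below is only a sketch of what that merge presumably looks like; AddPyMethodDefs itself is not part of this patch, and the std::vector<PyMethodDef> signature is an assumption made for illustration.

#include <Python.h>
#include <vector>

// Sketch only: append a null-terminated PyMethodDef table (such as the one
// returned by torch_npu::profiler::profiler_functions()) onto a growing
// method list, keeping a single terminating sentinel entry at the end.
static void AddPyMethodDefsSketch(std::vector<PyMethodDef>& dst, PyMethodDef* extra) {
  if (!dst.empty()) {
    dst.pop_back();  // drop the existing {nullptr, ...} sentinel
  }
  for (PyMethodDef* def = extra; def->ml_name != nullptr; ++def) {
    dst.push_back(*def);
  }
  dst.push_back({nullptr, nullptr, 0, nullptr});  // re-add the sentinel
}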
+ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "torch_npu/csrc/profiler/profiler.h" + +namespace torch_npu { +namespace profiler{ + +PyObject* profiler_initExtension(PyObject* _unused, PyObject *unused) { + + auto torch_npu_C_module = THPObjectPtr(PyImport_ImportModule("torch_npu._C")); + if (!torch_npu_C_module) + return nullptr; + auto torch_npu_C_m = py::handle(torch_npu_C_module).cast(); + auto m = torch_npu_C_m.def_submodule("_profiler", "_profiler bindings"); + + py::enum_(m, "ProfilerState") + .value("Disabled", ProfilerState::Disabled) + .value("CPU", ProfilerState::CPU) + .value("CUDA", ProfilerState::CUDA) + .value("NPU", ProfilerState::NPU) + .value("NVTX", ProfilerState::NVTX) + .value("KINETO", ProfilerState::KINETO); + + py::class_(m, "ProfilerConfig") + .def(py::init()); + + py::class_(m, "ProfilerEvent") + .def("kind", &LegacyEvent::kindStr) + .def("name", [](const LegacyEvent& e) { return e.name(); }) + .def("thread_id", &LegacyEvent::threadId) + .def("fwd_thread_id", &LegacyEvent::fwdThreadId) + .def("device", &LegacyEvent::device) + .def("cpu_elapsed_us", &LegacyEvent::cpuElapsedUs) + .def("cuda_elapsed_us", &LegacyEvent::cudaElapsedUs) + .def("npu_elapsed_us", &LegacyEvent::npuElapsedUs) + .def("npu_destropy_event", &LegacyEvent::npu_destropy_event) + .def("has_cuda", &LegacyEvent::hasCuda) + .def("has_npu", &LegacyEvent::hasNpu) + .def("shapes", &LegacyEvent::shapes) + .def("cpu_memory_usage", &LegacyEvent::cpuMemoryUsage) + .def("cuda_memory_usage", &LegacyEvent::cudaMemoryUsage) + .def("npu_memory_usage", &LegacyEvent::npuMemoryUsage) + .def("handle", &LegacyEvent::handle) + .def("node_id", &LegacyEvent::nodeId) + .def("is_remote", &LegacyEvent::isRemote) + .def("sequence_nr", &LegacyEvent::sequenceNr) + .def("stack", &LegacyEvent::stack) + .def("scope", &LegacyEvent::scope) + .def("correlation_id", &LegacyEvent::correlationId) + .def("start_us", &LegacyEvent::cpuUs) + .def("flops", &LegacyEvent::flops); + + m.def("_enable_profiler_legacy", enableProfilerLegacy); + py::class_(m, "_ProfilerDisableOptions") + .def(py::init()); + m.def( + "_disable_profiler_legacy", + disableProfilerLegacy, + py::arg("profiler_disable_options") = ProfilerDisableOptions()); + m.def("_profiler_enabled", profilerEnabled); + m.def("_enable_record_function", [](bool enable) { + at::enableRecordFunction(enable); + }); + m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) { + auto cb = at::RecordFunctionCallback(nullptr) + .needsInputs(true) + .samplingProb(sampling_prob); + if (is_global) { + at::addGlobalCallback(cb); + } else { + at::addThreadLocalCallback(cb); + } + }); + m.def("_clear_callbacks", []() { + at::clearCallbacks(); + }); + + Py_RETURN_TRUE; +} + +// autograd methods on torch._C +static PyMethodDef TorchProfilerMethods[] = { // NOLINT + {"_profiler_init", profiler_initExtension, METH_NOARGS, nullptr}, + {nullptr, nullptr, 0, nullptr} +}; + + +PyMethodDef* profiler_functions() { + return TorchProfilerMethods; +} + +} +} \ No newline at end of file diff --git a/torch_npu/csrc/profiler/init.h b/torch_npu/csrc/profiler/init.h new file mode 100644 index 0000000000..5ded508bf6 --- /dev/null +++ b/torch_npu/csrc/profiler/init.h @@ -0,0 +1,26 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. 
+// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PROFILER_INIT_INC +#define PROFILER_INIT_INC + +namespace torch_npu { +namespace profiler{ +PyMethodDef* profiler_functions(); +} +} + +#endif // PROFILER_INIT_INC \ No newline at end of file diff --git a/torch_npu/csrc/profiler/profiler.h b/torch_npu/csrc/profiler/profiler.h new file mode 100644 index 0000000000..a9e3e4f684 --- /dev/null +++ b/torch_npu/csrc/profiler/profiler.h @@ -0,0 +1,22 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PROFILER_INC +#define PROFILER_INC + +#include "torch_npu/csrc/profiler/profiler_legacy.h" + +#endif // PROFILER_INC \ No newline at end of file diff --git a/torch_npu/csrc/profiler/profiler_npu.cpp b/torch_npu/csrc/profiler/profiler_npu.cpp new file mode 100644 index 0000000000..05cfd101e3 --- /dev/null +++ b/torch_npu/csrc/profiler/profiler_npu.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
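The entry points bound in init.cpp above (_enable_profiler_legacy, _disable_profiler_legacy, the ProfilerConfig and ProfilerEvent classes) wrap the C++ API declared in profiler_legacy.h, which is added in the second patch of this series. For orientation, the snippet below is a minimal usage sketch based only on those declared signatures; it is illustrative and not code from the patch.

#include <iostream>
#include "torch_npu/csrc/profiler/profiler.h"

// Usage sketch, assuming the signatures declared in profiler_legacy.h:
// ProfilerConfig(state, report_input_shapes, profile_memory, with_stack,
//                with_flops, use_npu_simple), enableProfilerLegacy(),
// disableProfilerLegacy() -> thread_event_lists (one event vector per thread).
void profile_region_sketch() {
  using namespace torch_npu::profiler;

  ProfilerConfig cfg(ProfilerState::NPU,
                     /*report_input_shapes=*/true,
                     /*profile_memory=*/false,
                     /*with_stack=*/false,
                     /*with_flops=*/true,
                     /*use_npu_simple=*/false);
  enableProfilerLegacy(cfg);

  // ... run the operators to be profiled here ...

  thread_event_lists lists = disableProfilerLegacy();
  for (const auto& per_thread : lists) {
    for (const auto& evt : per_thread) {
      std::cout << evt.kindStr() << " " << evt.name()
                << " thread=" << evt.threadId() << "\n";
    }
  }
}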
+ + +#include +#include +#include +#include "third_party/acl/inc/acl/acl_rt.h" +#include "torch_npu/csrc/profiler/profiler.h" +#include + +namespace torch_npu { +namespace profiler{ + +namespace { + +static inline void npuCheck(aclError result, const char * file, int line) { + if(result != ACL_ERROR_NONE) { + std::stringstream ss; + ss << file << ":" << line << ": " << ", aclError id:" << result << "."; + throw std::runtime_error(ss.str()); + } +} +#define TORCH_NPU_CHECK(result) npuCheck(result,__FILE__,__LINE__); + +struct NPUMethods : public DeviceStubs { + void npu_destropy_event(aclrtEvent event) const override { + aclrtEventStatus status = ACL_EVENT_STATUS_RESERVED; + TORCH_NPU_CHECK(aclrtQueryEvent(event, &status)); + if (status == ACL_EVENT_STATUS_COMPLETE) { + TORCH_NPU_CHECK(aclrtDestroyEvent(event)); + } else { + std::cout << "Warning! NPU destroy event error, status is not completed." << std::endl; + } + } + void record(int* device, aclrtEvent* event1, int64_t* cpu_ns) const override { + TORCH_NPU_CHECK(aclrtGetDevice(device)); + TORCH_NPU_CHECK(c10::npu::acl::AclrtCreateEventWithFlag(event1, ACL_EVENT_TIME_LINE)); + auto stream = c10::npu::getCurrentNPUStream(); + *cpu_ns = getTime(); + TORCH_NPU_CHECK(aclrtRecordEvent(*event1, stream)); + } + float elapsed(const aclrtEvent& event1, const aclrtEvent& event2) const override { + TORCH_NPU_CHECK(aclrtSynchronizeEvent(event1)); + TORCH_NPU_CHECK(aclrtSynchronizeEvent(event2)); + float ms; + TORCH_NPU_CHECK(aclrtEventElapsedTime(&ms, event1, event2)); + return ms*1000.0; + } + void onEachDevice(std::function op) const override { + c10::npu::OptionalNPUGuard device_guard; + int dev = -1; + auto ret = aclrtGetDevice(&dev); + if (ret != ACL_ERROR_NONE) { + dev = 0; + } + device_guard.set_index(dev); + op(dev); + } + + void synchronize() const override { + c10::npu::npuSynchronizeDevice(); + } + bool enabled() const override { + return true; + } + +}; + +struct RegisterNPUMethods { + RegisterNPUMethods() { + static NPUMethods methods; + registerDeviceMethods(&methods); + } +}; +RegisterNPUMethods reg; + +} // namespaces +} // namespace profiler +} // namespace torch_npu \ No newline at end of file diff --git a/torch_npu/csrc/profiler/utils.cpp b/torch_npu/csrc/profiler/utils.cpp new file mode 100644 index 0000000000..f30b037914 --- /dev/null +++ b/torch_npu/csrc/profiler/utils.cpp @@ -0,0 +1,268 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
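profiler_npu.cpp above follows the stub-registration idiom of the upstream CUDA profiler: a default DeviceStubs whose methods fail, a pointer holding the active implementation, and a file-static RegisterNPUMethods object whose constructor calls registerDeviceMethods() at library load time. The stripped-down sketch below illustrates that idiom with made-up names; it is not part of the patch.

#include <stdexcept>

// Illustrative names only; the real types in this patch are DeviceStubs,
// NPUMethods, and RegisterNPUMethods.
struct ClockStubs {
  virtual long long now() const { throw std::runtime_error("backend not enabled"); }
  virtual ~ClockStubs() = default;
};

// Holds the active implementation; starts at a fallback that always throws.
inline const ClockStubs*& active_clock_stubs() {
  static const ClockStubs default_stubs{};
  static const ClockStubs* current = &default_stubs;
  return current;
}

struct NpuClockStubs : ClockStubs {
  long long now() const override { return 0; }  // stand-in for a real device clock
};

// Constructing this object swaps the active pointer, mirroring
// RegisterNPUMethods { ... registerDeviceMethods(&methods); } above.
struct RegisterNpuClockStubs {
  RegisterNpuClockStubs() {
    static NpuClockStubs stubs;
    active_clock_stubs() = &stubs;
  }
};
static RegisterNpuClockStubs auto_register;  // runs when the library is loaded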
+ +#include "torch_npu/csrc/profiler/utils.h" + +namespace torch_npu { +namespace profiler{ + +static constexpr auto kConv2dStride = 3; +static constexpr auto kConv2dPadding = 4; +static constexpr auto kConv2dDilation = 5; +static constexpr auto kConv2dGroups = 6; + +// List of supported operators +static constexpr auto kConv2dOp = "aten::conv2d"; +static constexpr auto kGemmOp = "aten::mm"; +static constexpr auto kMulOp = "aten::mul"; +static constexpr auto kAddOp = "aten::add"; + +static constexpr auto kInputSize = "input_size"; +static constexpr auto kWeightSize = "weight_size"; +static constexpr auto kGroups = "groups"; +static constexpr auto kPadding = "padding"; +static constexpr auto kStride = "stride"; +static constexpr auto kDilation = "dilation"; +static constexpr auto kMatSize = "mat_size"; +static constexpr auto kMat1Size = "mat1_size"; +static constexpr auto kMat2Size = "mat2_size"; + +static bool validateInput(const std::string &op_name, size_t min_size, + const std::vector& inputs, + const std::vector& should_be_tensor) { + std::stringstream ss; + if (inputs.size() < min_size) { + ss << "Failed to save extra arguments for flops compuation of op " + << op_name + << ", min size: " << min_size + << ", actual size: " << inputs.size(); + TORCH_WARN(ss.str()); + return false; + } + for (auto index : should_be_tensor) { + if (!inputs[index].isTensor()) { + ss << "Failed to save extra arguments for flops compuation of op " + << op_name + << ", input[" << index + << "] must be a tensor."; + TORCH_WARN(ss.str()); + return false; + } + } + return true; +} + +std::unordered_map saveExtraArgs(const at::RecordFunction& fn) { + // for specific types of fn, return the saved extra args for computing flops + std::unordered_map map; + std::vector inputs = fn.inputs(); + std::string fname(fn.name().str()); + + if (inputs.empty()) { + // Input shape is unavailable, return empty map + return map; + } + + if (fname == kConv2dOp) { + std::vector tensors{0, 1}; + bool check = validateInput(fname, kConv2dGroups + 1, inputs, tensors); + if (!check) { + return map; + } + + at::Tensor input = inputs[0].toTensor(); + at::Tensor weight = inputs[1].toTensor(); + if (weight.sizes().size() != 4) { + TORCH_WARN("Failed to compute flops for op aten::conv2d because it requires a 4D kernel tensor."); + return map; + } + map[kInputSize] = at::IValue(input.sizes()); + map[kWeightSize] = at::IValue(weight.sizes()); + map[kStride] = inputs[kConv2dStride]; + map[kPadding] = inputs[kConv2dPadding]; + map[kDilation] = inputs[kConv2dDilation]; + map[kGroups] = inputs[kConv2dGroups]; + } else if (fname == kGemmOp) { + std::vector tensors{0, 1}; + bool check = validateInput(fname, 2, inputs, tensors); + if (!check) { + return map; + } + + at::Tensor left = inputs[0].toTensor(); + at::Tensor right = inputs[1].toTensor(); + map[kMat1Size] = at::IValue(left.sizes()); + map[kMat2Size] = at::IValue(right.sizes()); + } else if (fname == kMulOp) { + std::vector tensors{0}; + bool check = validateInput(fname, 1, inputs, tensors); + if (!check) { + return map; + } + + at::Tensor mat = inputs[0].toTensor(); + map[kMatSize] = at::IValue(mat.sizes()); + } else if (fname == kAddOp) { + std::vector tensors{0}; + bool check = validateInput(fname, 1, inputs, tensors); + if (!check) { + return map; + } + + at::Tensor mat = inputs[0].toTensor(); + map[kMatSize] = at::IValue(mat.sizes()); + } + + return map; +} + +uint64_t computeFlops(const std::string &op_name, const std::unordered_map &extra_args) { + if (op_name == kConv2dOp) { + if 
(extra_args.find(kInputSize) == extra_args.end() + || extra_args.find(kWeightSize) == extra_args.end() + || extra_args.find(kGroups) == extra_args.end() + || extra_args.find(kPadding) == extra_args.end() + || extra_args.find(kStride) == extra_args.end() + || extra_args.find(kDilation) == extra_args.end()) { + TORCH_WARN("Calculating flops for aten::conv2d requires groups, padding, stride, dilation, input_size, and weight_size in saved arguments."); + return 0; + } + auto input_sizes_ref = extra_args.at(kInputSize); + auto kernel_sizes_ref = extra_args.at(kWeightSize); + auto groups_ref = extra_args.at(kGroups); + auto padding_ref = extra_args.at(kPadding); + auto stride_ref = extra_args.at(kStride); + auto dilation_ref = extra_args.at(kDilation); + if (!input_sizes_ref.isIntList() || !kernel_sizes_ref.isIntList()) { + TORCH_WARN("Failed to compute flops for op aten::conv2d because it requires input and weight tensor sizes."); + return 0; + } + if (!padding_ref.isIntList() || !stride_ref.isIntList() || !dilation_ref.isIntList()) { + TORCH_WARN("Failed to compute flops for op aten::conv2d because it requires padding, stride, and dilation values."); + return 0; + } + + const std::vector input_sizes = input_sizes_ref.toIntVector(); + const std::vector kernel_sizes = kernel_sizes_ref.toIntVector(); + const uint64_t groups = groups_ref.toInt(); + const std::vector padding = padding_ref.toIntVector(); + const std::vector stride = stride_ref.toIntVector(); + const std::vector dilation = dilation_ref.toIntVector(); + if (input_sizes.size() != 4 || kernel_sizes.size() != 4) { + TORCH_WARN("Failed to compute flops for op aten::conv2d because both input and weight must be size 4."); + return 0; + } + if (!groups) { + TORCH_WARN("Failed to compute flops for op aten::conv2d because group size must not be 0."); + return 0; + } + if (padding.size() != 2 || dilation.size() != 2) { + TORCH_WARN("Failed to compute flops for op aten::conv2d because both padding and dilation must be size 2."); + return 0; + } + if (stride.size() != 2 || (stride[0] * stride[1] == 0)) { + TORCH_WARN("Failed to compute flops for op aten::conv2d because stride must be size 2 and cannot be 0."); + return 0; + } + // format of the input is defined in torch.nn.quantized.functional.conv2d() + uint64_t minibatch = 0, in_channels = 0, input_h = 0, input_w = 0; + uint64_t out_channels = 0, kernel_h = 0, kernel_w = 0; + const uint64_t conv2d_multiply_factor = 2; + std::tie(minibatch, in_channels, input_h, input_w) = std::make_tuple(input_sizes[0], input_sizes[1], + input_sizes[2], input_sizes[3]); + std::tie(out_channels, std::ignore, kernel_h, kernel_w) = std::make_tuple(kernel_sizes[0], kernel_sizes[1], + kernel_sizes[2], kernel_sizes[3]); + uint64_t output_h = (input_h + 2 * padding[0] - dilation[0] * (kernel_h - 1) - 1) / stride[0] + 1; + uint64_t output_w = (input_w + 2 * padding[1] - dilation[1] * (kernel_w - 1) - 1) / stride[1] + 1; + if (groups == 0) { + TORCH_CHECK(false, "groups can not be 0."); + } + return conv2d_multiply_factor * minibatch * output_h * output_w * + kernel_h * kernel_w * in_channels * out_channels / groups; + } else if (op_name == kGemmOp) { + if (extra_args.find(kMat1Size) == extra_args.end() + || extra_args.find(kMat2Size) == extra_args.end()) { + TORCH_WARN("Calculating flops for aten::mm requires mat1_size and mat2_size in saved arguments."); + return 0; + } + auto mat1_sizes_ref = extra_args.at(kMat1Size); + auto mat2_sizes_ref = extra_args.at(kMat2Size); + if (!mat1_sizes_ref.isIntList() || 
!mat2_sizes_ref.isIntList()) { + TORCH_WARN("Failed to compute flops for op aten::mm because it requires mat1_size and mat2_size to be IntList."); + return 0; + } + + std::vector mat1_size = mat1_sizes_ref.toIntVector(); + std::vector mat2_size = mat2_sizes_ref.toIntVector(); + if (mat1_size.size() == 0) { + return 0; + } else { + int64_t overlap_dim = mat1_size.back(); + const uint64_t gemm_multiply_factor = 2; + uint64_t flops = 1; + for(int64_t dim : mat1_size) { + flops *= dim; + } + if (overlap_dim == 0) { + TORCH_CHECK(false, "overlap_dim can not be 0."); + } + flops /= overlap_dim; + for(int64_t dim : mat2_size) { + flops *= dim; + } + flops *= gemm_multiply_factor; + return flops; + } + } else if (op_name == kMulOp) { + if (extra_args.find(kMatSize) == extra_args.end()) { + TORCH_WARN("Calculating flops for aten::mul.Tensor requires mat_size in saved arguments."); + return 0; + } + auto mat_sizes = extra_args.at(kMatSize); + if (!mat_sizes.isIntList()) { + TORCH_WARN("Failed to compute flops for op aten::mul because it requires mat_size to be IntList."); + return 0; + } + + std::vector mat_size = mat_sizes.toIntVector(); + uint64_t flops = 1; + for(int64_t dim : mat_size) { + flops *= dim; + } + return flops; + } else if (op_name == kAddOp) { + if (extra_args.find(kMatSize) == extra_args.end()) { + TORCH_WARN("Calculating flops for aten::add.Tensor requires mat_size in saved arguments."); + return 0; + } + auto mat_sizes = extra_args.at(kMatSize); + if (!mat_sizes.isIntList()) { + TORCH_WARN("Failed to compute flops for op aten::add because it requires mat_size to be IntList."); + return 0; + } + + std::vector mat_size = mat_sizes.toIntVector(); + uint64_t flops = 1; + for(int64_t dim : mat_size) { + flops *= dim; + } + return flops; + } + return 0; +} + +} +} \ No newline at end of file diff --git a/torch_npu/csrc/profiler/utils.h b/torch_npu/csrc/profiler/utils.h new file mode 100644 index 0000000000..dc6ddced03 --- /dev/null +++ b/torch_npu/csrc/profiler/utils.h @@ -0,0 +1,37 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
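To make the arithmetic in computeFlops above concrete, the standalone check below evaluates the aten::conv2d and aten::mm cases using exactly the patch's formulas: 2 * N * H_out * W_out * kh * kw * C_in * C_out / groups for conv2d, and 2 * M * K * N for an (M x K) by (K x N) matmul. The shapes are hypothetical and chosen only for illustration.

#include <cstdint>
#include <iostream>

int main() {
  // aten::conv2d with input NCHW = (32, 64, 56, 56), weight = (128, 64, 3, 3),
  // stride = dilation = 1, padding = 1, groups = 1.
  const uint64_t N = 32, C_in = 64, H = 56, W = 56;
  const uint64_t C_out = 128, kh = 3, kw = 3;
  const uint64_t stride = 1, padding = 1, dilation = 1, groups = 1;
  // Same output-size formula as the patch.
  const uint64_t H_out = (H + 2 * padding - dilation * (kh - 1) - 1) / stride + 1;  // 56
  const uint64_t W_out = (W + 2 * padding - dilation * (kw - 1) - 1) / stride + 1;  // 56
  const uint64_t conv_flops =
      2ULL * N * H_out * W_out * kh * kw * C_in * C_out / groups;
  std::cout << "conv2d flops: " << conv_flops << "\n";

  // aten::mm: (M x K) @ (K x N2) costs 2 * M * K * N2 flops.
  const uint64_t M = 1024, K = 768, N2 = 512;
  const uint64_t mm_flops = 2ULL * M * K * N2;
  std::cout << "mm flops: " << mm_flops << "\n";
  return 0;
}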
+ +#ifndef PROFILER_NPU_INC +#define PROFILER_NPU_INC + +#include +#include +#include +#include +#include + +namespace torch_npu { +namespace profiler{ + +std::unordered_map saveExtraArgs(const at::RecordFunction& fn); + +uint64_t computeFlops(const std::string &op_name, + const std::unordered_map &extra_args); + +} +} + +#endif // PROFILER_NPU_INC \ No newline at end of file -- Gitee From db2345227108c0dee821b5f8c0a7048503010362 Mon Sep 17 00:00:00 2001 From: shibo19 Date: Fri, 21 Jan 2022 18:10:58 +0800 Subject: [PATCH 2/6] =?UTF-8?q?=E5=A2=9E=E5=8A=A0profiler=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=EF=BC=9Astep=202?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/profiler/profiler_legacy.cpp | 467 +++++++++++++++ torch_npu/csrc/profiler/profiler_legacy.h | 596 ++++++++++++++++++++ torch_npu/npu/__init__.py | 5 +- 3 files changed, 1066 insertions(+), 2 deletions(-) create mode 100644 torch_npu/csrc/profiler/profiler_legacy.cpp create mode 100644 torch_npu/csrc/profiler/profiler_legacy.h diff --git a/torch_npu/csrc/profiler/profiler_legacy.cpp b/torch_npu/csrc/profiler/profiler_legacy.cpp new file mode 100644 index 0000000000..a40574b6e3 --- /dev/null +++ b/torch_npu/csrc/profiler/profiler_legacy.cpp @@ -0,0 +1,467 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/profiler/profiler.h" +#include +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +namespace torch_npu { +namespace profiler{ + +std::vector callstackStr(const std::vector& cs) { + std::vector cs_str; + cs_str.reserve(cs.size()); + for (const auto& entry : cs) { + std::stringstream loc; + loc << entry.filename << "(" << entry.line << "): " << entry.funcname; + cs_str.push_back(loc.str()); + } + return cs_str; +} + +// We decompose the profiler logic into the following components: +// +// ThreadLocalDebugInfo: +// +// ThreadLocalDebugInfo is a thread local mapping from slots into +// the debug information structs. +// ThreadLocalDebugInfo is automatically propagated across thread +// boundaries, including the cases of: +// - launching async jobs with at::launch +// - executing JIT continuations +// - moving from the forward threads into autograd (backward) threads +// +// Entries in ThreadLocalDebugInfo are managed by DebugInfoGuard +// which can be used to add or overwrite an entry in the thread local +// mapping. A corresponding entry is removed when the guard is destroyed, +// potentially revealing the previously set value for the same slot. +// +// For the async tasks, slots previuosly set in the main thread before +// launching of an async task are shared and visible in the async task. 
+// +// On the other hand, any adding or overwriting of the mapping by the +// async task is not visible to the main thread and any modification +// (including removal of the entries) in the main thread is not visible +// to the async task if it happends after launching the task. +// +// We use ThreadLocalDebugInfo (slot PROFILER_STATE) to store profiler config, +// as well as a list of events that happen during profiling. +// An instance of ThreadLocalDebugInfo is created each time we enter +// profiler (i.e. enter profiling context manager/call enableConfig) and +// uniquely identifies a profiling run. +// +// We automatically propagate ThreadLocalDebugInfo into async tasks, +// as well as across JIT continuations and autograd thread, so all +// the operations that happen between profiling start and end +// (not necessarily within the same thread) are recorded. +// Unless the profiling slot is overwritten as in the case of nested +// profiling ranges (in this case events for the subrange are handled +// by the nested profiler) +// +// When we exit a profiling range (either by exiting profiling context +// manager or by calling disableProfiler), we remove the previously set +// profiling entry for the given thread local mapping, and consolidate +// events in the profiling result +// +// +// ThreadLocalState: +// +// ThreadLocalState takes a 'snapshot' of thread local variables +// using provided getters. It is used together with ThreadLocalStateGuard +// to transfer the snapshot across thread boundary and set the thread local +// values as in the parent task. +// +// Profiler uses ThreadLocalState to propagate profiler's thread local state. +// ThreadLocalState also automatically propagates profiler callbacks. +// +// +// at::RecordFunction and observers +// +// Profiler uses observers mechanism to add a pair of thread local callbacks +// that are executed on a number of predetermined ranges, including: +// - c10/ATen ops +// - TorchScript functions/methods +// - user defined named ranges (see `record_function` python context manager) +// +// Profiler setups a pair of callbacks that record profiling events and save +// them into the thread local profiler struct (ThreadLocalDebugInfo, +// PROFILER_STATE slot) +// +// +// Thus, the overall logic is: +// +// enableProfiler: +// - checks that profiler is not enabled (otherwise throws) +// - pushes new ThreadLocalDebugInfo (slot PROFILER_STATE) as the profiler +// config for the current thread +// - pushes profiling callbacks for the current thread +// +// disableProfiler: +// - pops PROFILER_STATE slot from the current ThreadLocalDebugInfo and +// consolidates events +// - removes profiling callbacks +// +// ThreadLocalState: +// - propagates ThreadLocalDebugInfo across threads +// - propagates profiler callbacks across threads +// +// Profiler callbacks: +// - get the current profiling state (PROFILER slot in ThreadLocalDebugInfo) +// - save profiling events into the profiling state +// + +namespace { +const DeviceStubs default_stubs; +constexpr const DeviceStubs* default_stubs_addr = &default_stubs; +// Constant initialization, so it is guaranteed to be initialized before +// static initialization calls which may invoke registerCUDAMethods +inline const DeviceStubs*& device_stubs() { + static const DeviceStubs* stubs_ = default_stubs_addr; + return stubs_; +} +} + +// Profiler state +const ProfilerConfig& ProfilerThreadLocalState::config() const { + return config_; +} + +thread_event_lists ProfilerThreadLocalState::consolidate() { + 
std::lock_guard g(state_mutex_); + thread_event_lists result; + for (auto& kv : event_lists_map_) { + auto& list = kv.second; + result.emplace_back(list->consolidate()); + } + // Consolidate remote events if applicable as well. + if (remoteProfiledEvents_) { + result.insert( + result.end(), + std::make_move_iterator(remoteProfiledEvents_->begin()), + std::make_move_iterator(remoteProfiledEvents_->end())); + } + return result; +} + +void ProfilerThreadLocalState::mark(std::string name, bool include_device) { + if (config_.state == ProfilerState::Disabled) { + return; + } + if (config_.state == ProfilerState::NVTX) { + device_stubs()->nvtxMarkA(name.c_str()); + } else { + LegacyEvent evt( + EventKind::Mark, + at::StringView(std::move(name)), + at::RecordFunction::currentThreadId(), + include_device && (config_.state == ProfilerState::CUDA || config_.state == ProfilerState::NPU), + config_.state); + evt.setNodeId(at::RecordFunction::getDefaultNodeId()); + getEventList().record(std::move(evt)); + } +} + +void ProfilerThreadLocalState::setOrAddRemoteProfiledEvents( + std::vector&& remoteProfiledEvents) { + // Lock to serialize access from multiple callback threads. + std::lock_guard guard(state_mutex_); + if (remoteProfiledEvents_) { + (*remoteProfiledEvents_).emplace_back(remoteProfiledEvents); + } else { + remoteProfiledEvents_ = {std::move(remoteProfiledEvents)}; + } +} + +void ProfilerThreadLocalState::pushRange( + const at::RecordFunction& fn, + const bool record_device, + const char* msg, + std::vector>&& shapes) { + if (config_.state == ProfilerState::Disabled) { + return; + } + if (config_.state == ProfilerState::NVTX) { + device_stubs()->nvtxRangePushA(getNvtxStr( + fn.name(), msg, fn.seqNr(), shapes).c_str()); + } else { + LegacyEvent evt( + EventKind::PushRange, + fn.name(), + at::RecordFunction::currentThreadId(), + record_device, + config_.state, + fn.handle(), + std::move(shapes), + at::RecordFunction::getDefaultNodeId()); + evt.setSequenceNr(fn.seqNr()); + evt.setFwdThreadId(fn.forwardThreadId()); + evt.setScope((uint8_t)fn.scope()); + if (config_.with_flops) { + evt.setExtraArgs(saveExtraArgs(fn)); + evt.setFlops(computeFlops(std::string(fn.name().str()), evt.extraArgs())); + } + getEventList().record(std::move(evt)); + } +} + +void ProfilerThreadLocalState::popRange(const at::RecordFunction& fn, const bool record_device) { + if (config_.state == ProfilerState::Disabled) { + return; + } + if (config_.state == ProfilerState::NVTX) { + device_stubs()->nvtxRangePop(); + } else { + // In some cases RecordFunction (and popRange) may be + // called on a different thread than pushRange + // As a convention, we put the async pop on the original + // thread and save current thread id in pop event + LegacyEvent evt( + EventKind::PopRange, + at::StringView(""), + at::RecordFunction::currentThreadId(), + record_device, + config_.state, + fn.handle()); + evt.setNodeId(at::RecordFunction::getDefaultNodeId()); + getEventList(fn.threadId()).record(std::move(evt)); + } +} + +void ProfilerThreadLocalState::reportMemoryUsage( + void* /* unused */, + int64_t alloc_size, + c10::Device device) { + if (config_.profile_memory && config_.state != ProfilerState::Disabled) { + uint64_t thread_id = at::RecordFunction::currentThreadId(); + LegacyEvent evt( + EventKind::MemoryAlloc, + at::StringView(""), + thread_id, + config_.state == ProfilerState::CUDA || config_.state == ProfilerState::NPU, + config_.state); + evt.updateMemoryStats(alloc_size, device); + getEventList(thread_id).record(std::move(evt)); + 
} +} + +bool ProfilerThreadLocalState::memoryProfilingEnabled() const { + return config_.profile_memory; +} + +std::string ProfilerThreadLocalState::getNvtxStr( + const at::StringView& name, + const char* msg, + int64_t sequence_nr, + const std::vector>& shapes) const { + return name.str(); +} + +RangeEventList& ProfilerThreadLocalState::getEventList(int64_t thread_id) { + if (thread_id < 0) { + thread_id = at::RecordFunction::currentThreadId(); + } + RangeEventList* list_ptr = nullptr; + std::lock_guard guard(state_mutex_); + auto it = event_lists_map_.find(thread_id); + if (it != event_lists_map_.end()) { + list_ptr = it->second.get(); + } else { + auto event_list = std::make_shared(); + event_lists_map_[thread_id] = event_list; + list_ptr = event_list.get(); + } + return *list_ptr; +} + +std::vector> inputSizes(const at::RecordFunction& fn) { + std::vector> sizes; + sizes.reserve(fn.inputs().size()); + for (const c10::IValue& input : fn.inputs()) { + if (!input.isTensor()) { + sizes.emplace_back(); + continue; + } + const at::Tensor& tensor = input.toTensor(); + if (tensor.defined()) { + sizes.push_back(input.toTensor().sizes().vec()); + } else { + sizes.emplace_back(); + } + } + return sizes; +} + +namespace { + +enum EventIValueIdx { + KIND = 0, + NAME, + THREAD_ID, + HANDLE, + NODE_ID, + CPU_MEM_USAGE, + CPU_NS, + CUDA_RECORDED, + CUDA_MEM_USAGE, + CUDA_DEVICE, + CUDA_US, + SHAPES, + NUM_EVENT_IVALUE_IDX // must be last in list +}; + +enum ProfilerIValueIdx { + STATE = 0, + REPORT_INPUT_SHAPES, + PROFILE_MEMORY, + NUM_PROFILER_CFG_IVALUE_IDX // must be last in list +}; + +const std::unordered_set disable_cuda_profiling = { + "aten::view", + "aten::t", + "aten::transpose", + "aten::stride", + "aten::empty", + "aten::empty_like", + "aten::empty_strided", + "aten::as_strided", + "aten::expand", + "aten::resize_", + "aten::squeeze", + "aten::unsqueeze", + "aten::slice", + "aten::_unsafe_view", + "aten::size" +}; + +ProfilerThreadLocalState* getProfilerTLSState() { + return static_cast( + c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE)); +} + +void pushProfilingCallbacksLegacy() { + auto state_ptr = getProfilerTLSState(); + TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); + auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback( + [](const at::RecordFunction& fn) -> std::unique_ptr { + auto state_ptr = getProfilerTLSState(); + if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) { + return nullptr; + } + bool record_cuda = + state_ptr->config().state == ProfilerState::CUDA; + bool record_npu = + state_ptr->config().state == ProfilerState::NPU; + if (record_cuda && disable_cuda_profiling.find(fn.name().str()) != disable_cuda_profiling.end()) { + record_cuda = false; + } + + auto* msg = (fn.seqNr() >= 0) ? 
", seq = " : ""; + if (state_ptr->config().report_input_shapes) { + auto sizes = inputSizes(fn); + state_ptr->pushRange(fn, record_cuda || record_npu, msg, std::move(sizes)); + } else { + state_ptr->pushRange(fn, record_cuda || record_npu, msg); + } + + return nullptr; + }, + [](const at::RecordFunction& fn, at::ObserverContext*) { + auto state_ptr = getProfilerTLSState(); + if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) { + return; + } + bool record_cuda = + state_ptr->config().state == ProfilerState::CUDA; + bool record_npu = + state_ptr->config().state == ProfilerState::NPU; + if (record_cuda && disable_cuda_profiling.find(fn.name().str()) != disable_cuda_profiling.end()) { + record_cuda = false; + } + state_ptr->popRange(fn, record_cuda || record_npu); + }) + .needsInputs(state_ptr->config().report_input_shapes) + .needsIds(true)); + state_ptr->setCallbackHandle(handle); +} + +const int kCUDAWarmupStart = 5; +const int kNPUWarmupStart = 5; + +} // namespace + +void registerDeviceMethods(DeviceStubs* stubs) { + device_stubs() = stubs; +} + +at::IValue ProfilerConfig::toIValue() const { + c10::impl::GenericList eventIValueList(at::AnyType::get()); + eventIValueList.reserve(NUM_PROFILER_CFG_IVALUE_IDX); + eventIValueList.emplace_back(static_cast(state)); + eventIValueList.emplace_back(report_input_shapes); + eventIValueList.emplace_back(profile_memory); + return eventIValueList; +} + +ProfilerConfig ProfilerConfig::fromIValue( + const at::IValue& profilerConfigIValue) { + TORCH_INTERNAL_ASSERT( + profilerConfigIValue.isList(), + "Expected IValue to contain type c10::impl::GenericList"); + auto ivalues = profilerConfigIValue.toList(); + TORCH_INTERNAL_ASSERT( + ivalues.size() == NUM_PROFILER_CFG_IVALUE_IDX, + c10::str( + "Expected exactly ", + NUM_PROFILER_CFG_IVALUE_IDX, + " ivalues to resconstruct ProfilerConfig.")); + return ProfilerConfig( + static_cast(ivalues.get(ProfilerIValueIdx::STATE).toInt()), + ivalues.get(ProfilerIValueIdx::REPORT_INPUT_SHAPES).toBool(), + ivalues.get(ProfilerIValueIdx::PROFILE_MEMORY).toBool()); +} + +ProfilerConfig getProfilerConfig() { + auto state_ptr = getProfilerTLSState(); + TORCH_CHECK( + state_ptr, + "Tried to access profiler config, but profiler is not enabled!"); + return state_ptr->config(); +} + +bool profilerEnabled() { + auto state_ptr = getProfilerTLSState(); + return state_ptr && state_ptr->config().state != ProfilerState::Disabled; +} diff --git a/torch_npu/csrc/profiler/profiler_legacy.h b/torch_npu/csrc/profiler/profiler_legacy.h new file mode 100644 index 0000000000..97ab9ed211 --- /dev/null +++ b/torch_npu/csrc/profiler/profiler_legacy.h @@ -0,0 +1,596 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef PROFILER_LEGACY_INC +#define PROFILER_LEGACY_INC + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef _WIN32 +#include +#endif +#if defined(C10_IOS) && defined(C10_MOBILE) +#include // for gettimeofday() +#endif + +#include +#include +#include +#include "torch_npu/csrc/profiler/utils.h" + +namespace torch_npu { +namespace profiler{ + +enum class ProfilerState { + Disabled = 0, + CPU, // CPU-only profiling + CUDA, // CPU + CUDA events + NPU, // CPU + NPU events + NVTX, // only emit NVTX markers + KINETO, // use libkineto + NUM_PROFILER_STATES, // must be the last one +}; + +struct DeviceStubs { + + virtual void record(int* device, aclrtEvent* event, int64_t* cpu_ns) const { + fail(); + } + virtual float elapsed(const aclrtEvent& event1, const aclrtEvent& event2) const { + fail(); + return 0.f; + } + virtual void nvtxMarkA(const char* name) const { + fail(); + } + virtual void nvtxRangePushA(const char* name) const { + fail(); + } + virtual void nvtxRangePop() const { + fail(); + } + virtual void npu_destropy_event(aclrtEvent event) const { + fail(); + return; + } + virtual bool enabled() const { + return false; + } + virtual void onEachDevice(std::function op) const { + fail(); + } + virtual void synchronize() const { + fail(); + } + virtual ~DeviceStubs(); + +private: + void fail() const { + AT_ERROR("Device used in profiler but not enabled."); + } +}; + + void registerDeviceMethods(DeviceStubs* stubs); + +inline int64_t getTime() { +#if defined(C10_IOS) && defined(C10_MOBILE) +// clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS can't rely on +// CLOCK_REALTIME, as it is defined no matter if clock_gettime is implemented or not + struct timeval now; + gettimeofday(&now, NULL); + return static_cast(now.tv_sec) * 1000000000 + static_cast(now.tv_usec) * 1000; +#elif defined(_WIN32) || defined(__MACH__) + using namespace std::chrono; + using clock = std::conditional::type; + return duration_cast(clock::now().time_since_epoch()).count(); +#else + // clock_gettime is *much* faster than std::chrono implementation on Linux + struct timespec t{}; + clock_gettime(CLOCK_MONOTONIC, &t); + return static_cast(t.tv_sec) * 1000000000 + static_cast(t.tv_nsec); +#endif +} + +enum class EventKind : uint16_t { + Mark, + PushRange, + PopRange, + MemoryAlloc, +}; + +// To be deprecated, once we switch to Kineto profiling +struct LegacyEvent { + LegacyEvent( + EventKind kind, + at::StringView name, + uint16_t thread_id, + bool record_device, + ProfilerState state = ProfilerState::CPU, + at::RecordFunctionHandle handle = 0, + std::vector>&& shapes = {}, + int node_id = -1) + : name_(std::move(name)), + kind_(kind), + thread_id_(thread_id), + handle_(handle), + shapes_(shapes), + node_id_(node_id), + state_(state) { + record(record_device); + } + + // Constructor to be used in conjunction with LegacyEvent::fromIValue. 
+ LegacyEvent( + EventKind kind, + at::StringView name, + uint16_t thread_id, + at::RecordFunctionHandle handle, + std::vector>&& shapes, + int node_id, + bool is_remote, + int64_t cpu_memory_usage, + int64_t cpu_ns, + bool cuda_recorded, + int64_t cuda_memory_usage = 0, + int device = -1, + double cuda_us = -1) + : cpu_ns_(cpu_ns), + name_(std::move(name)), + kind_(kind), + thread_id_(thread_id), + handle_(handle), + shapes_(shapes), + cpu_memory_usage_(cpu_memory_usage), + cuda_memory_usage_(cuda_memory_usage), + device_(device), + node_id_(node_id), + is_remote_(is_remote), + cuda_us_(cuda_us) { + // Sanity check values that were deserialized + TORCH_INTERNAL_ASSERT(cpu_ns_ > 0); + if (cuda_recorded) { + TORCH_INTERNAL_ASSERT(device_ >= 0); + TORCH_INTERNAL_ASSERT(cuda_us_ >= 0); + } + } + + // Returns IValues corresponding to event structure, to be used for + // serialization. + at::IValue toIValue() const; + + // Reconstructs an event from IValues given by toIValue. + static LegacyEvent fromIValue(const at::IValue& eventIValue); + + void record(bool record_device); + + std::string kindStr() const { + switch (kind_) { + case EventKind::Mark: return "mark"; + case EventKind::PushRange: return "push"; + case EventKind::PopRange: return "pop"; + case EventKind::MemoryAlloc: return "memory_alloc"; + } + throw std::runtime_error("unknown event kind"); + } + + const char* name() const { + return name_.str(); + } + + uint64_t threadId() const { + return thread_id_; + } + + std::vector> shapes() const { + return shapes_; + } + + double cpuElapsedUs(const LegacyEvent& e) const { + return (e.cpu_ns_ - cpu_ns_)/(1000.0); + } + + void setCpuUs(int64_t cpu_us) { + cpu_ns_ = cpu_us * 1000.0; + } + + double cpuUs() const { + return cpu_ns_ / (1000.0); + } + + double cudaElapsedUs(const LegacyEvent& e) const; + + bool hasCuda() const { + return cuda_event != nullptr || (isRemote() && device_ != -1); + } + + double npuElapsedUs(const LegacyEvent& e) const; + + void npu_destropy_event(); + + bool hasNpu() const { + return npu_event != nullptr || (state_ == ProfilerState::NPU && device_ != -1); + } + + int device() const { + return device_; + } + + void updateMemoryStats(int64_t alloc_size, c10::Device device) { + if (device.type() == c10::DeviceType::CUDA || + device.type() == c10::DeviceType::HIP) { + cuda_memory_usage_ = alloc_size; + } else if (device.type() == c10::DeviceType::CPU || + device.type() == c10::DeviceType::MKLDNN || + device.type() == c10::DeviceType::IDEEP) { + cpu_memory_usage_ = alloc_size; + } else if (device.type() == c10::DeviceType::NPU) { + npu_memory_usage_ = alloc_size; + }else { + LOG(WARNING) << "Unsupported memory profiling device: " << device; + } + } + + int64_t cpuMemoryUsage() const { + return cpu_memory_usage_; + } + + int64_t cudaMemoryUsage() const { + return cuda_memory_usage_; + } + + int64_t npuMemoryUsage() const { + return npu_memory_usage_; + } + + at::RecordFunctionHandle handle() const { + return handle_; + } + + // Node ID corresponding to this event. + int nodeId( ) const { + return node_id_; + } + + // Set Node ID on this event. 
+ void setNodeId(int node_id) { + node_id_ = node_id; + } + + void setName(at::StringView newName_) { + name_ = std::move(newName_); + } + + bool isRemote() const { + return is_remote_; + } + + void setCudaUs(int64_t cuda_us) { + cuda_us_ = cuda_us; + } + + void setSequenceNr(int64_t sequence_nr) { + sequence_nr_ = sequence_nr; + } + + int64_t sequenceNr() const { + return sequence_nr_; + } + + void setCorrelationId(uint64_t correlation_id) { + correlation_id_ = correlation_id; + } + + uint64_t correlationId() const { + return correlation_id_; + } + + const std::vector& stack() const { + return stack_; + } + + void setStack(const std::vector& stack) { + stack_ = stack; + } + + uint64_t fwdThreadId() const { + return fwd_thread_id_; + } + + void setFwdThreadId(uint64_t fwd_thread_id) { + fwd_thread_id_ = fwd_thread_id; + } + + uint8_t scope() const { + return scope_; + } + + void setScope(uint8_t scope) { + scope_ = scope; + } + + const std::unordered_map& extraArgs() const { + return extra_args_; + } + + void setExtraArgs(std::unordered_map&& save_args) { + extra_args_ = std::move(save_args); + } + + uint64_t flops() { + return flops_; + } + + void setFlops(uint64_t flops) { + flops_ = flops; + } + +private: + // signed to allow for negative intervals, initialized for safety. + int64_t cpu_ns_ = 0; + ProfilerState state_; + at::StringView name_; + EventKind kind_; + uint64_t thread_id_; + uint64_t fwd_thread_id_; + at::RecordFunctionHandle handle_ {0}; + std::vector> shapes_; + int64_t cpu_memory_usage_ = 0; + int64_t cuda_memory_usage_ = 0; + int64_t npu_memory_usage_ = 0; + int device_ = -1; + aclrtEvent cuda_event = nullptr; + aclrtEvent npu_event = nullptr; + int node_id_ = 0; + bool is_remote_ = false; + int64_t cuda_us_ = -1; + int64_t sequence_nr_ = -1; + + std::vector stack_; + uint8_t scope_; + uint64_t correlation_id_; + // Extra arguments for computing op flops + std::unordered_map extra_args_; + uint64_t flops_ = 0; +}; + +// a linked-list of fixed sized vectors, to avoid +// a std::vector resize from taking a large amount of time inside +// a profiling event +struct RangeEventList { + RangeEventList() { + events_.reserve(kReservedCapacity); + } + + template + void record(Args&&... args) { + std::lock_guard guard(mutex_); + events_.emplace_back(std::forward(args)...); + } + + std::vector consolidate() { + std::lock_guard lock(mutex_); + std::vector result; + result.insert( + result.begin(), + std::make_move_iterator(events_.begin()), + std::make_move_iterator(events_.end())); + events_.erase(events_.begin(), events_.end()); + return result; + } + + size_t size() { + std::lock_guard lock(mutex_); + return events_.size(); + } + +private: + // This mutex is used to serialize access when different threads are writing + // to the same instance of RangeEventList. 
+ std::mutex mutex_; + std::vector events_; + + static const size_t kReservedCapacity = 1024; +}; + +struct ProfilerConfig { + ProfilerConfig( + ProfilerState state, + bool report_input_shapes = false, + bool profile_memory = false, + bool with_stack = false, + bool with_flops = false, + bool use_npu_simple = false) + : state(state), + report_input_shapes(report_input_shapes), + profile_memory(profile_memory), + with_stack(with_stack), + with_flops(with_flops), + use_npu_simple(use_npu_simple) { init_npu_simple(); } + ~ProfilerConfig() = default; + ProfilerState state; + bool report_input_shapes; + bool profile_memory; + bool with_stack; + bool with_flops; + bool use_npu_simple; + + void init_npu_simple() { + if (state == ProfilerState::NPU) { + // at::DisableRecordFunction::use_npu_simple = use_npu_simple; + return; + } + } + // Returns IValues corresponding to ProfilerConfig struct, to be used for + // serialization. + at::IValue toIValue() const; + + // Reconstructs a ProfilerConfig from IValues given by toIValue. + static ProfilerConfig fromIValue(const at::IValue& profilerConfigIValue); +}; + +// A struct to control settings of disableProfiler options. +struct ProfilerDisableOptions { + ProfilerDisableOptions() = default; + ProfilerDisableOptions(bool shouldCleanupTLSState, bool shouldConsolidate) + : cleanupTLSState(shouldCleanupTLSState), + consolidate(shouldConsolidate) {} + // Whether we should clean up profiler states that are thread local, such as + // ThreadLocalDebugInfo and thread local RecordFunction callbacks. + bool cleanupTLSState = true; + // Whether we should consolidate all currently recorded profiled events. If + // false, will not consolidate and other threads can continue to write to the + // event lists. + bool consolidate = true; +}; + +// NOTE: profiler mode is thread local, with automatic propagation +// across thread boundary (e.g. at::launch tasks) + void enableProfilerLegacy(const ProfilerConfig&); +using thread_event_lists = std::vector>; + thread_event_lists disableProfilerLegacy(c10::optional profilerDisableOptions = c10::nullopt); + +// adds profiledEvents to the current thread local recorded events. Each event +// will be marked with node ID given by fromNodeId. + void addEventList(std::vector&& profiledEvents); +// Returns if the profiler is currently enabled in the current thread. + bool profilerEnabled(); +// Retrieve the thread_local ProfilerConfig. + ProfilerConfig getProfilerConfig(); +// Writes profiled events to a stream. 
+ void writeProfilerEventsToStream(std::ostream& out, const std::vector& events); + +struct RecordProfile { + RecordProfile(std::ostream& out); + RecordProfile(const std::string& filename); + + ~RecordProfile(); +private: + void init(); + std::unique_ptr file_; + std::ostream& out_; + void processEvents(const std::vector& events); +}; + +// A guard that enables the profiler, taking in an optional callback to process +// the results +struct TLSProfilerGuard { + explicit TLSProfilerGuard( + const ProfilerConfig& cfg, + c10::optional> + resultCallback = c10::nullopt, + c10::optional profilerDisableOptions = + c10::nullopt) + : cb_(std::move(resultCallback)), + profilerDisableOptions_(std::move(profilerDisableOptions)) { + enableProfilerLegacy(cfg); + } + ~TLSProfilerGuard() { + thread_event_lists event_lists = disableProfilerLegacy(profilerDisableOptions_); + if (cb_) { + try { + (*cb_)(event_lists); + } catch (const std::exception& e) { + LOG(ERROR) << "Got error processing profiler events: " << e.what(); + } + } + } + +private: + c10::optional> cb_; + const c10::optional profilerDisableOptions_; +}; + +struct FileLineFunc { + std::string filename; + size_t line; + std::string funcname; +}; + + std::vector callstackStr(const std::vector& cs); + std::vector> inputSizes(const at::RecordFunction& fn); + +struct ProfilerThreadLocalState : public c10::MemoryReportingInfoBase { + explicit ProfilerThreadLocalState(const ProfilerConfig& config) + : config_(config), remoteProfiledEvents_{c10::nullopt} {} + ~ProfilerThreadLocalState() override = default; + + const ProfilerConfig& config() const; + + thread_event_lists consolidate(); + + void mark(std::string name, bool include_device = true); + + void setOrAddRemoteProfiledEvents( + std::vector&& remoteProfiledEvents); + + void pushRange( + const at::RecordFunction& fn, + const bool record_cuda, + const char* msg = "", + std::vector>&& shapes = {}); + + void popRange(const at::RecordFunction& fn, const bool record_cuda); + + void setCallbackHandle(at::CallbackHandle handle) { + handle_ = handle; + } + + at::CallbackHandle callbackHandle() const { + return handle_; + } + + bool hasCallbackHandle() { + return handle_ > 0; + } + + void reportMemoryUsage( + void* /* unused */, + int64_t alloc_size, + c10::Device device) override; + + bool memoryProfilingEnabled() const override; + +protected: + std::string getNvtxStr( + const at::StringView& name, + const char* msg, + int64_t sequence_nr, + const std::vector>& shapes) const; + + RangeEventList& getEventList(int64_t thread_id = -1); + + std::mutex state_mutex_; + std::unordered_map> + event_lists_map_; + + ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled); + at::CallbackHandle handle_ = 0; + c10::optional>> remoteProfiledEvents_; +}; + +} +} +#endif // PROFILER_LEGACY_INC \ No newline at end of file diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py index 7b033a169c..b97832386e 100644 --- a/torch_npu/npu/__init__.py +++ b/torch_npu/npu/__init__.py @@ -26,7 +26,7 @@ __all__ = [ "reset_peak_memory_stats", "reset_max_memory_allocated", "reset_max_memory_cached", "memory_allocated", "max_memory_allocated", "memory_reserved", "max_memory_reserved", "memory_cached", "max_memory_cached", "memory_snapshot", "memory_summary", - "Stream", "Event" + "Stream", "Event", "profiler" ] @@ -42,4 +42,5 @@ from .memory import (_free_mutex, caching_allocator_alloc, caching_allocator_del reset_max_memory_allocated, reset_max_memory_cached, memory_allocated, max_memory_allocated, memory_reserved, 
max_memory_reserved, memory_cached, max_memory_cached, memory_snapshot, memory_summary) -from .streams import Stream, Event \ No newline at end of file +from .streams import Stream, Event +from . import profiler \ No newline at end of file -- Gitee From 19f07f257654fe52bdfca0d295e19ba6a97926f1 Mon Sep 17 00:00:00 2001 From: shibo19 Date: Fri, 21 Jan 2022 18:12:52 +0800 Subject: [PATCH 3/6] =?UTF-8?q?=E5=A2=9E=E5=8A=A0profiler=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=EF=BC=9Astep=203?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/profiler/profiler_legacy.cpp | 278 +++++++++++++ torch_npu/npu/profiler.py | 407 ++++++++++++++++++++ 2 files changed, 685 insertions(+) create mode 100644 torch_npu/npu/profiler.py diff --git a/torch_npu/csrc/profiler/profiler_legacy.cpp b/torch_npu/csrc/profiler/profiler_legacy.cpp index a40574b6e3..b044456669 100644 --- a/torch_npu/csrc/profiler/profiler_legacy.cpp +++ b/torch_npu/csrc/profiler/profiler_legacy.cpp @@ -465,3 +465,281 @@ bool profilerEnabled() { auto state_ptr = getProfilerTLSState(); return state_ptr && state_ptr->config().state != ProfilerState::Disabled; } + +void enableProfilerLegacy(const ProfilerConfig& new_config) { + TORCH_CHECK(new_config.state != ProfilerState::NVTX || device_stubs()->enabled(), + "Can't use NVTX profiler - PyTorch was compiled without CUDA"); + + TORCH_CHECK(new_config.state != ProfilerState::KINETO); + + auto state_ptr = getProfilerTLSState(); + TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread"); + auto state = std::make_shared(new_config); + c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state); + + pushProfilingCallbacksLegacy(); + + if (new_config.state == ProfilerState::CUDA) { + // event recording appears to have some startup overhead, so we need to + // to generate some dummy events first before recording synchronization events + for (int idx = 0; idx < kCUDAWarmupStart; ++idx) { + device_stubs()->onEachDevice([state](int /* unused */) { + state->mark("__cuda_startup"); + device_stubs()->synchronize(); + }); + } + + // cuda events must be on the same device, so we need a start event recorded + // for each gpu. we then use this event to synchronize time on the GPU + // with the CPU clock. + device_stubs()->onEachDevice([state](int d) { + state->mark("__cuda_start_event"); + }); + } + else if (new_config.state == ProfilerState::NPU) { + // event recording appears to have some startup overhead, so we need to + // to generate some dummy events first before recording synchronization events + for (int idx = 0; idx < kNPUWarmupStart; ++idx) { + device_stubs()->onEachDevice([state](int /* unused */) { + state->mark("__npu_startup"); + device_stubs()->synchronize(); + }); + } + + // npu events must be on the same device, so we need a start event recorded + // for each npu. we then use this event to synchronize time on the NPU + // with the CPU clock. + device_stubs()->onEachDevice([state](int d) { + state->mark("__npu_start_event"); + }); + } + state->mark("__start_profile", false); +} + +thread_event_lists disableProfilerLegacy(c10::optional profilerDisableOptions) { + auto cleanupTLSState = profilerDisableOptions ? profilerDisableOptions->cleanupTLSState : true; + auto consolidate = profilerDisableOptions ? 
profilerDisableOptions->consolidate : true; + // all the DebugInfoBase objects are scope based and supposed to use DebugInfoGuard + std::shared_ptr state; + if (cleanupTLSState) { + state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE); + } else { + state = c10::ThreadLocalDebugInfo::_peek(c10::DebugInfoKind::PROFILER_STATE); + } + + auto state_ptr = static_cast(state.get()); + TORCH_CHECK(state_ptr && state_ptr->config().state != ProfilerState::Disabled, + "Can't disable profiler when it's not running"); + + if (cleanupTLSState) { + at::removeCallback(state_ptr->callbackHandle()); + } + + if (!consolidate || state_ptr->config().state == ProfilerState::NVTX) { + return thread_event_lists(); + } + + state_ptr->mark("__stop_profile"); + // Note that this will erase the underlying events. + return state_ptr->consolidate(); +} + +void addEventList(std::vector&& profiledEvents) { + auto state_ptr = getProfilerTLSState(); + TORCH_CHECK(state_ptr, "Profiler must be enabled."); + state_ptr->setOrAddRemoteProfiledEvents(std::move(profiledEvents)); +} + +void LegacyEvent::record(bool record_device) { + if (record_device && state_ == ProfilerState::CUDA) { + device_stubs()->record(&device_, &cuda_event, &cpu_ns_); + return; + } else if (c10::ObservedOperators::EnableNpuOp && record_device && state_ == ProfilerState::NPU) { + device_stubs()->record(&device_, &npu_event, &cpu_ns_); + return; + } + cpu_ns_ = getTime(); +} + +LegacyEvent LegacyEvent::fromIValue(const at::IValue& eventIValue) { + TORCH_INTERNAL_ASSERT( + eventIValue.isList(), + "Expected IValue to contain type c10::impl::GenericList"); + auto ivalues = eventIValue.toList(); + TORCH_INTERNAL_ASSERT( + ivalues.size() >= NUM_EVENT_IVALUE_IDX, + "Expected at least ", + NUM_EVENT_IVALUE_IDX, + " elements to reconstruct LegacyEvent."); + + // Reconstruct input shapes from ivalues. 
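For reference, the serialized form round-tripped by `toIValue`/`fromIValue` is a flat positional list of event fields (kind, name, thread id, handle, node id, memory counters, timestamps, device fields) with the input shapes appended as a nested list of integer lists. A minimal pure-Python sketch of that round trip; the field order and names below are illustrative, not the exact `EventIValueIdx` layout:

```python
# Standalone sketch: encode/decode an event record with nested input shapes.
# The real code stores these as c10 IValue lists; plain Python lists stand in here.

def encode_event(kind, name, thread_id, handle, shapes):
    # shapes is a list of per-input shape lists, e.g. [[32, 128], [128, 64]]
    return [kind, name, thread_id, handle, [list(s) for s in shapes]]

def decode_event(record):
    if len(record) < 5:
        raise ValueError("expected at least 5 fields to reconstruct the event")
    kind, name, thread_id, handle, shapes = record[:5]
    return {
        "kind": kind,
        "name": name,
        "thread_id": thread_id,
        "handle": handle,
        "shapes": [list(s) for s in shapes],
    }

rec = encode_event(0, "aten::mm", 140210, 1.0, [[32, 128], [128, 64]])
assert decode_event(rec)["shapes"] == [[32, 128], [128, 64]]
```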
+ auto shapeListIValue = ivalues.get(EventIValueIdx::SHAPES); + TORCH_INTERNAL_ASSERT( + shapeListIValue.isList(), + "Expected profiler shapes IValue to contain type c10::impl::GenericList."); + + auto shapeList = shapeListIValue.toList(); + std::vector> shapes; + shapes.reserve(shapeList.size()); + for (size_t i = 0 ; i < shapeList.size(); ++i) { + std::vector s; + auto shapeIValue = shapeList.get(i); + TORCH_INTERNAL_ASSERT( + shapeIValue.isList(), + "Expected each profiler shape element to contain shapes of type c10::impl::GenericList.") + auto curShapesList = shapeIValue.toList(); + s.reserve(curShapesList.size()); + for (size_t j = 0; j < curShapesList.size(); ++j) { + s.emplace_back(curShapesList.get(j).toInt()); + } + shapes.emplace_back(s); + } + + LegacyEvent evt( + static_cast( + ivalues.get(EventIValueIdx::KIND).toInt()), // EventKind + at::StringView(ivalues.get(EventIValueIdx::NAME).toStringRef()), // name + ivalues.get(EventIValueIdx::THREAD_ID).toInt(), // thread_id + static_cast( + ivalues.get(EventIValueIdx::HANDLE).toDouble()), // handle + std::move(shapes), // input shapes + ivalues.get(EventIValueIdx::NODE_ID).toInt(), // node id + true, // is remote + ivalues.get(EventIValueIdx::CPU_MEM_USAGE).toInt(), // cpu_mem_usage + ivalues.get(EventIValueIdx::CPU_NS).toInt(), // cpu_ns + ivalues.get(EventIValueIdx::CUDA_RECORDED).toBool(), // was cuda recorded + ivalues.get(EventIValueIdx::CUDA_MEM_USAGE).toInt(), // cuda memory usage + ivalues.get(EventIValueIdx::CUDA_DEVICE).toInt(), // device + ivalues.get(EventIValueIdx::CUDA_US).toInt() // cuda_us + ); + return evt; +} + +at::IValue LegacyEvent::toIValue() const { + c10::impl::GenericList eventIValueList(at::AnyType::get()); + eventIValueList.reserve(NUM_EVENT_IVALUE_IDX); + eventIValueList.emplace_back(static_cast(kind_)); + eventIValueList.emplace_back(std::string(name_.str())); + eventIValueList.emplace_back(static_cast(thread_id_)); + eventIValueList.emplace_back(static_cast(handle_)); + eventIValueList.emplace_back(node_id_); + eventIValueList.emplace_back(cpu_memory_usage_); + eventIValueList.emplace_back(cpu_ns_); + // CUDA event information + bool cuda_profiling_enabled = hasCuda(); + eventIValueList.emplace_back(cuda_profiling_enabled); + eventIValueList.emplace_back(static_cast(cuda_memory_usage_)); + eventIValueList.emplace_back(device_); + eventIValueList.emplace_back(cuda_us_); + // Shapes + c10::impl::GenericList shapesList = + c10::impl::GenericList(at::ListType::create(at::IntType::get())); + shapesList.reserve(shapes_.size()); + for (const auto& shape : shapes_) { + c10::impl::GenericList s = c10::impl::GenericList(at::IntType::get()); + s.reserve(shape.size()); + for (const auto& k : shape) { + s.emplace_back(k); + } + shapesList.emplace_back(s); + } + eventIValueList.emplace_back(shapesList); + return at::IValue(eventIValueList); +} + +double LegacyEvent::npuElapsedUs(const LegacyEvent& e) const { + TORCH_CHECK(e.hasNpu() && hasNpu(), "Events were not recorded for NPU"); + TORCH_CHECK( + e.device() == device(), + c10::str( + "Events are not on the same device: ", e.device(), " vs ", device())); + return device_stubs()->elapsed(npu_event, e.npu_event); +} + +void LegacyEvent::npu_destropy_event() { + if (!hasNpu()) { + throw std::logic_error("Events were not recorded for NPU"); + } + device_stubs()->npu_destropy_event(npu_event); +} + +double LegacyEvent::cudaElapsedUs(const LegacyEvent& e) const { + return 0.0; +} + +DeviceStubs::~DeviceStubs() = default; + +void writeProfilerEventsToStream(std::ostream& 
out, const std::vector& events) { + TORCH_CHECK(out, "Could not open file"); + LegacyEvent* profiler_start = nullptr; + for (LegacyEvent* e : events) { + if (0 == strcmp(e->name(), "__start_profile")) { + profiler_start = e; + break; + } + } + TORCH_CHECK(profiler_start, "Could not find __start_profile mark"); + + struct PairHash { + size_t operator()(std::pair p) const + noexcept { + return std::hash()(p.first) ^ std::hash()(p.second); + } + }; + std::unordered_map, LegacyEvent*, PairHash> events_map; + out << "[\n"; + bool first = true; + for (LegacyEvent* evt : events) { + if (evt->kindStr() == "push") { + events_map[std::make_pair(evt->handle(), evt->nodeId())] = evt; + } else if (evt->kindStr() == "pop") { + if (!first) { + out << ",\n"; + } + first = false; + auto it = events_map.find(std::make_pair(evt->handle(), evt->nodeId())); + TORCH_CHECK(it != events_map.end(), "Unmatched pop event"); + LegacyEvent* evt_start = it->second; + events_map.erase(it); + } + } + out << "]\n"; +} + +RecordProfile::RecordProfile(std::ostream& out) +: out_(out) { + init(); +} + +RecordProfile::RecordProfile(const std::string& filename) +: file_(new std::ofstream(filename)), out_(*file_) { + init(); +} + +void RecordProfile::init() { + enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU)); +} + +RecordProfile::~RecordProfile() { + try { + thread_event_lists event_lists = disableProfilerLegacy(); + std::vector events; + for (auto& l : event_lists) { + for (auto& e : l) { + events.push_back(&e); + } + } + processEvents(events); + } catch (const std::exception& e) { + LOG(ERROR) << e.what() << std::endl; + } catch (...) { + LOG(ERROR) << "Unknown error" << std::endl; + } +} + +void RecordProfile::processEvents(const std::vector& events) { + writeProfilerEventsToStream(out_, events); +} + +} +} \ No newline at end of file diff --git a/torch_npu/npu/profiler.py b/torch_npu/npu/profiler.py new file mode 100644 index 0000000000..eead621553 --- /dev/null +++ b/torch_npu/npu/profiler.py @@ -0,0 +1,407 @@ + # Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +from typing import Any +from collections import defaultdict, namedtuple +from operator import attrgetter +from typing import Dict, List, Tuple, Optional +import math +from enum import Enum +import torch_npu + +try: + from contextlib import ContextDecorator +except ImportError: + import functools + + class ContextDecorator(object): # type: ignore[no-redef] + + def __enter__(self): + raise NotImplementedError + + def __exit__(self, exc_type, exc_val, exc_tb): + raise NotImplementedError + + def __call__(self, func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapped +__all__ = ["profile", "record_function", ] + +if not torch_npu._C._profiler_init(): + raise RuntimeError("proflier initialization failed") + +class DeviceType(Enum): + CPU = 0, + CUDA = 1, # CUDA. 
+ MKLDNN = 2, # Reserved for explicit MKLDNN + OPENGL = 3, # OpenGL + OPENCL = 4, # OpenCL + IDEEP = 5, # IDEEP. + HIP = 6, # AMD HIP + FPGA = 7, # FPGA + MSNPU = 8, # MSNPU + XLA = 9, # XLA / TPU + Vulkan = 10, # Vulkan + Metal = 11, # Metal + XPU = 12, # XPU + NPU = 13, # NPU +class EventList(list): + """A list of Events (for pretty printing)""" + def __init__(self, *args, **kwargs): + use_cuda = False + use_npu = kwargs.pop('use_npu', True) and torch_npu.npu.is_available() + profile_memory = kwargs.pop('profile_memory', False) + with_flops = kwargs.pop('with_flops', False) + super(EventList, self).__init__(*args, **kwargs) + self._use_cuda = use_cuda + self._use_npu = use_npu + self._profile_memory = profile_memory + self._tree_built = False + self._with_flops = with_flops + assert not (self._use_cuda and self._use_npu), "use_cuda and use_npu can't be True simultaneously." + + def _build_tree(self): + self._populate_cpu_children() + self._remove_dup_nodes() + self._set_backward_stacktraces() + self._tree_built = True + + def __str__(self): + return self.table() + + def _remove_dup_nodes(self): + while True: + to_delete = [] + for idx in range(len(self)): + if (self[idx].cpu_parent is not None and + self[idx].cpu_parent.name == self[idx].name and + len(self[idx].cpu_parent.cpu_children) == 1): + self[idx].cpu_parent.cpu_children = self[idx].cpu_children + self[idx].cpu_parent.kernels = self[idx].kernels # lift kernels up + for ch in self[idx].cpu_children: + ch.cpu_parent = self[idx].cpu_parent + to_delete.append(idx) + if len(to_delete) == 0: + break + new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete] + self.clear() + self.extend(new_evts) + + def _populate_cpu_children(self): + """Populates child events into each underlying FunctionEvent object. + One event is a child of another if [s1, e1) is inside [s2, e2). Where + s1 and e1 would be start and end of the child event's interval. And + s2 and e2 start and end of the parent event's interval + + Example: In event list [[0, 10], [1, 3], [3, 4]] would have make [0, 10] + be a parent of two other intervals. + + If for any reason two intervals intersect only partially, this function + will not record a parent child relationship between then. + """ + + # Some events can be async (i.e. start and end on different threads), + # since it's generally undefined how to attribute children ranges to + # async ranges, we do not use them when calculating nested ranges and stats + sync_events = [evt for evt in self if not evt.is_async and evt.device_type == DeviceType.CPU] + events = sorted( + sync_events, + key=attrgetter("thread"), + ) + # Group by both thread and node_id, so that events that happen to have + # the same thread_id but are from different nodes aren't incorrectly + # grouped together. + threads = itertools.groupby( + events, key=lambda event: (event.thread, event.node_id) + ) + + # For each thread we keep a stack of current nested parents. + # We maintain the invariant that each interval is a subset of all other + # intervals lower in the stack. + # + # First we sort the intervals by their start time. Then we iterate over them. + # Every time we see a new interval we remove several parents from + # the top until we restore the invariant. Then parent child relationship + # if recorded if the stack is not empty. 
+ # Finally we add new interval to the list + # + # Algorithm has O(N * log(N)) complexity where N is number of + # intervals + for thread_id, thread_events in threads: + thread_events_ = sorted( + thread_events, + key=lambda event: [event.time_range.start, -event.time_range.end], + ) + current_events: List[FunctionEvent] = [] + cur_end = 0 + for event in thread_events_: + while len(current_events) > 0: + parent = current_events[-1] + if event.time_range.start >= parent.time_range.end or \ + event.time_range.end > parent.time_range.end: + # this can't be a parent + current_events.pop() + else: + parent.append_cpu_child(event) + assert ( + event.cpu_parent is None + ), "There is already a CPU parent event for {}".format( + event.key + ) + event.set_cpu_parent(parent) + break + + current_events.append(event) + + def _set_backward_stacktraces(self): + def bw_parent(evt): + if evt is None: + return None + elif evt.scope == 1: # BACKWARD_FUNCTION + return evt + else: + return bw_parent(evt.cpu_parent) + + fwd_stacks = {} + for evt in self: + if bw_parent(evt) is None and evt.stack is not None: + t = (evt.sequence_nr, evt.thread) + if t not in fwd_stacks: + fwd_stacks[t] = evt.stack + + for evt in self: + p = bw_parent(evt) + if p is not None: + assert p.fwd_thread is not None + t = (p.sequence_nr, p.fwd_thread) + if t in fwd_stacks: + evt.stack = fwd_stacks[t] + else: + evt.stack = [] + + @property + def self_cpu_time_total(self): + return sum([event.self_cpu_time_total for event in self]) + + def table(self, sort_by=None, row_limit=100, max_src_column_width=75, header=None, top_level_events_only=False): + """Prints an EventList as a nicely formatted table. + + Args: + sort_by (str, optional): Attribute used to sort entries. By default + they are printed in the same order as they were registered. + Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``, + ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``, + ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``. + top_level_events_only(bool, optional): Boolean flag to determine the + selection of events to display. If true, the profiler will only + display events at top level like top-level invocation of python + `lstm`, python `add` or other functions, nested events like low-level + cpu/cuda ops events are omitted for profiler result readability. + + Returns: + A string containing the table. + """ + return build_table( + self, + sort_by=sort_by, + row_limit=row_limit, + max_src_column_width=max_src_column_width, + header=header, + profile_memory=self._profile_memory, + with_flops=self._with_flops, + top_level_events_only=top_level_events_only, + use_cuda=self._use_cuda, + use_npu=self._use_npu) + + def export_chrome_trace(self, path): + """Exports an EventList as a Chrome tracing tools file. + + The checkpoint can be later loaded and inspected under ``chrome://tracing`` URL. + + Args: + path (str): Path where the trace will be written. + """ + import os + with open(path, 'w') as f: + chrome_events = [] + next_id = 0 + # Use file IO over using json.dump since JSON dumping is very slow and + # this technique is proven to give a 4x speedup. 
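The writer below emits Chrome tracing records by hand for speed: one complete event (`"ph": "X"`, with `ts`/`dur` in microseconds) per CPU op, and, when kernels are attached, a flow-start/flow-finish pair (`"ph": "s"` / `"ph": "f"` sharing an `id` and `cat`) that draws the arrow from the CPU launch to the device kernel. A small sketch of the same record shapes using `json.dumps`; the event names and timestamps are made up for illustration:

```python
import json

# One complete event on the "CPU functions" track, plus a flow arrow to an NPU kernel.
records = [
    {"name": "aten::mm", "ph": "X", "ts": 1200, "dur": 85,
     "tid": 1, "pid": "CPU functions", "args": {}},
    {"name": "aten::mm", "ph": "s", "ts": 1200, "tid": 1,
     "pid": "CPU functions", "id": 0, "cat": "cpu_to_npu", "args": {}},
    {"name": "MatMulKernel", "ph": "f", "ts": 1230, "tid": 0,
     "pid": "NPU functions", "id": 0, "cat": "cpu_to_npu", "args": {}},
    {"name": "MatMulKernel", "ph": "X", "ts": 1230, "dur": 60,
     "tid": 0, "pid": "NPU functions", "args": {}},
]
print(json.dumps(records, indent=2))  # loadable under chrome://tracing
```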
+ f.write("[") + for evt in self: + if evt.trace_name is None: + continue + f.write( + '{"name": "%s", ' + '"ph": "X", ' + '"ts": %s, ' + '"dur": %s, ' + '"tid": %s, ' + '"pid": "CPU functions", ' + '"args": {}}, ' + % ( + evt.trace_name, + evt.time_range.start, + evt.time_range.elapsed_us(), + evt.thread + if not evt.is_remote + else f'" node_id:{evt.node_id}, thread_id:{evt.thread} "', + ) + ) + if self._use_cuda: + for k in evt.kernels: + # 's' and 'f' draw Flow arrows from + # the CPU launch to the GPU kernel + f.write('{"name": "%s", ' + '"ph": "s", ' + '"ts": %s, ' + '"tid": %s, ' + '"pid": "CPU functions", ' + '"id": %s, ' + '"cat": "cpu_to_cuda", ' + '"args": {}}, ' % (evt.trace_name, evt.time_range.start, + evt.thread, next_id)) + f.write('{"name": "%s", ' + '"ph": "f", ' + '"ts": %s, ' + '"tid": %s, ' + '"pid": "CUDA functions", ' + '"id": %s, ' + '"cat": "cpu_to_cuda", ' + '"args": {}}, ' % (k.name, k.interval.start, k.device, next_id)) + f.write('{"name": "%s", ' + '"ph": "X", ' + '"ts": %s, ' + '"dur": %s, ' + '"tid": %s, ' + '"pid": "CUDA functions", ' + '"args": {}}, ' % (k.name, k.interval.start, + k.interval.elapsed_us(), k.device)) + next_id += 1 + elif self._use_npu: + for k in evt.kernels: + # 's' and 'f' draw Flow arrows from + # the CPU launch to the GPU kernel + f.write('{"name": "%s", ' + '"ph": "s", ' + '"ts": %s, ' + '"tid": %s, ' + '"pid": "CPU functions", ' + '"id": %s, ' + '"cat": "cpu_to_npu", ' + '"args": {}}, ' % (evt.trace_name, evt.time_range.start, + evt.thread, next_id)) + f.write('{"name": "%s", ' + '"ph": "f", ' + '"ts": %s, ' + '"tid": %s, ' + '"pid": "NPU functions", ' + '"id": %s, ' + '"cat": "cpu_to_npu", ' + '"args": {}}, ' % (k.name, k.interval.start, k.device, next_id)) + f.write('{"name": "%s", ' + '"ph": "X", ' + '"ts": %s, ' + '"dur": %s, ' + '"tid": %s, ' + '"pid": "NPU functions", ' + '"args": {}}, ' % (k.name, k.interval.start, + k.interval.elapsed_us(), k.device)) + next_id += 1 + + # remove trailing whitespace and comma + f.seek(f.tell() - 2, os.SEEK_SET) + f.truncate() + f.write("]") + + def supported_export_stacks_metrics(self): + if self._use_npu: + return ["self_cpu_time_total", "self_npu_time_total"] + return ["self_cpu_time_total", "self_cuda_time_total"] + + def export_stacks(self, path: str, metric: str): + if metric not in self.supported_export_stacks_metrics(): + raise ValueError("metric should be one of: " + str(self.supported_export_stacks_metrics())) + translate_table = str.maketrans(" ;\t\n", "____") + with open(path, 'w') as f: + for evt in self: + if evt.stack and len(evt.stack) > 0: + metric_value = getattr(evt, metric) + if int(metric_value) > 0: + stack_str = "" + for entry in reversed(evt.stack): + stack_str += entry.translate(translate_table) + stack_str += ";" + stack_str = stack_str[:-1] + " " + str(int(metric_value)) + f.write(stack_str + "\n") + + def key_averages(self, group_by_input_shapes=False, group_by_stack_n=0): + """Averages all function events over their keys. + + Args: + group_by_input_shapes: group entries by + (event name, input shapes) rather than just event name. + This is useful to see which input shapes contribute to the runtime + the most and may help with size-specific optimizations or + choosing the best candidates for quantization (aka fitting a roof line) + + group_by_stack_n: group by top n stack trace entries + + Returns: + An EventList containing FunctionEventAvg objects. 
+ """ + assert self._tree_built + stats: Dict[Tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg) + + def get_key(event, group_by_input_shapes, group_by_stack_n) -> Tuple[str, ...]: + key = [str(event.key), str(event.node_id), str(event.device_type), str(event.is_legacy)] + if group_by_input_shapes: + key.append(str(event.input_shapes)) + if group_by_stack_n > 0: + key += event.stack[:group_by_stack_n] + return tuple(key) + for evt in self: + stats[get_key(evt, group_by_input_shapes, group_by_stack_n)].add(evt, self._use_cuda, self._use_npu) + + avg_list = EventList( + stats.values(), + use_cuda=self._use_cuda, + use_npu=self._use_npu, + profile_memory=self._profile_memory, + with_flops=self._with_flops) + for evt in avg_list: + evt.stack = evt.stack[:group_by_stack_n] + if not group_by_input_shapes: + evt.input_shapes = "" + return avg_list + + def total_average(self): + """Averages all events. + + Returns: + A FunctionEventAvg object. + """ + total_stat = FunctionEventAvg() + for evt in self: + total_stat += evt + total_stat.key = None + total_stat.key = 'Total' + return total_stat -- Gitee From cbdf6fccf268c1d976faf6b18b01e3ef5182a129 Mon Sep 17 00:00:00 2001 From: shibo19 Date: Fri, 21 Jan 2022 18:14:39 +0800 Subject: [PATCH 4/6] =?UTF-8?q?=E5=A2=9E=E5=8A=A0profiler=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=EF=BC=9Astep=204?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/npu/profiler.py | 1331 ++++++++++++++++++++++++++++++++++++- 1 file changed, 1330 insertions(+), 1 deletion(-) diff --git a/torch_npu/npu/profiler.py b/torch_npu/npu/profiler.py index eead621553..0dbf9de39a 100644 --- a/torch_npu/npu/profiler.py +++ b/torch_npu/npu/profiler.py @@ -90,7 +90,7 @@ class EventList(list): def _remove_dup_nodes(self): while True: to_delete = [] - for idx in range(len(self)): + for idx, _ in enumerate(self): if (self[idx].cpu_parent is not None and self[idx].cpu_parent.name == self[idx].name and len(self[idx].cpu_parent.cpu_children) == 1): @@ -405,3 +405,1332 @@ class EventList(list): total_stat.key = None total_stat.key = 'Total' return total_stat + +class profile(object): + """Context manager that manages autograd profiler state and holds a summary of results. + Under the hood it just records events of functions being executed in C++ and + exposes those events to Python. You can wrap any code into it and it will + only report runtime of PyTorch functions. + Note: profiler is thread local and is automatically propagated into the async tasks + + Args: + enabled (bool, optional): Setting this to False makes this context manager a no-op. + + use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API. + Adds approximately 4us of overhead to each tensor operation. + + use_npu (bool, optional): Enables timing of NPU events as well using the aclEvent API. + Adds approximately 4us of overhead to each tensor operation. + + record_shapes (bool, optional): If shapes recording is set, information + about input dimensions will be collected. This allows one to see which + dimensions have been used under the hood and further group by them + using prof.key_averages(group_by_input_shape=True). Please note that + shape recording might skew your profiling data. It is recommended to + use separate runs with and without shape recording to validate the timing. + Most likely the skew will be negligible for bottom most events (in a case + of nested function calls). 
But for higher level functions the total + self cpu time might be artificially increased because of the shape + collection. + + with_flops (bool, optional): If with_flops is set, the profiler will estimate + the FLOPS (floating pointer operations per second) value using the operator's input shape + and total time. This allows one to estimate the hardware performance. Currently, + this option only works for the matrix multiplication and 2D convolution operators. + + profile_memory (bool, optional): track tensor memory allocation/deallocation. + + with_stack (bool, optional): record source information (file and line number) for the ops. + + use_kineto (bool, optional): experimental, enable profiling with Kineto profiler. + + use_cpu (bool, optional): profile CPU events; setting to ``False`` requires + ``use_kineto=True`` and can be used to lower the overhead for GPU-only profiling. + + .. warning: + Enabling memory profiling or source attribution incurs additional profiler + overhead + + .. warning: + This context managers should not be called recursively, i.e. no nested + instances are allowed + + .. warning: + Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_), + one cannot use the profiler with ``use_cuda = True`` to benchmark + DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading, + please use ``use_cuda = False`` or ``num_workers = 0``. + + Example: + >>> x = torch.randn((1, 1), requires_grad=True) + >>> with torch.autograd.profiler.profile() as prof: + >>> for _ in range(100): # any normal python code, really! + >>> y = x ** 2 + >> y.backward() + >>> # NOTE: some columns were removed for brevity + >>> print(prof.key_averages().table(sort_by="self_cpu_time_total")) + ----------------------------------- --------------- --------------- --------------- + Name Self CPU total CPU time avg Number of Calls + ----------------------------------- --------------- --------------- --------------- + mul 32.048ms 32.048ms 200 + pow 27.041ms 27.041ms 200 + PowBackward0 9.727ms 55.483ms 100 + torch::autograd::AccumulateGrad 9.148ms 9.148ms 100 + torch::autograd::GraphRoot 691.816us 691.816us 100 + ----------------------------------- --------------- --------------- --------------- + + """ + def __init__( + self, + enabled=True, + *, + use_cuda=False, + use_npu=False, + record_shapes=False, + with_flops=False, + profile_memory=False, + with_stack=False, + use_kineto=False, + use_cpu=True, + use_npu_simple=False): + self.enabled: bool = enabled + if not self.enabled: + return + self.use_cuda = use_cuda + self.use_npu = use_npu + self.use_npu_simple = use_npu_simple + self.function_events = None + self.entered = False + self.record_shapes = record_shapes + self.with_flops = with_flops + self.record_shapes |= self.with_flops + self.profile_memory = profile_memory + self.with_stack = with_stack + self.use_cpu = use_cpu + self.kineto_results = None + if not self.use_cpu: + assert use_kineto, \ + "Device-only events supported only with Kineto (use_kineto=True)" + + self.profiler_kind = None + self.kineto_activities = set() + if self.use_cuda: + # legacy CUDA mode + self.profiler_kind = torch_npu._C._profiler.ProfilerState.CUDA + elif self.use_npu: + self.profiler_kind = torch_npu._C._profiler.ProfilerState.NPU + else: + self.profiler_kind = torch_npu._C._profiler.ProfilerState.CPU + + + def config(self): + assert self.profiler_kind is not None + return torch_npu._C._profiler.ProfilerConfig( + self.profiler_kind, + self.record_shapes, + self.profile_memory, + 
self.with_stack, + self.with_flops, + self.use_npu_simple) + + def __enter__(self): + if not self.enabled: + return + if self.entered: + raise RuntimeError("profiler context manager is not reentrant") + self.entered = True + torch_npu._C._profiler._enable_profiler_legacy(self.config()) + return self + + def _prepare_kineto_trace(self): + assert self.kineto_activities + self.entered = True + torch_npu._C._profiler._prepare_profiler(self.config(), self.kineto_activities) + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self.enabled: + return + records = torch_npu._C._profiler._disable_profiler_legacy() + parsed_results = parse_legacy_records(records) + self.function_events = EventList( + parsed_results, + use_cuda=self.use_cuda, + use_npu=self.use_npu, + profile_memory=self.profile_memory, + with_flops=self.with_flops) + self.function_events._build_tree() + return False + + def __repr__(self): + if self.function_events is None: + return '' + return repr(self.function_events) + + def __str__(self): + if self.function_events is None: + return '' + return str(self.function_events) + + def _check_finish(self): + if self.function_events is None: + raise RuntimeError("can't export a trace that didn't finish running") + + def table(self, sort_by=None, row_limit=100, max_src_column_width=75, header=None, top_level_events_only=False): + self._check_finish() + assert self.function_events is not None + return self.function_events.table( + sort_by=sort_by, row_limit=row_limit, max_src_column_width=max_src_column_width, header=header, + top_level_events_only=top_level_events_only + ) + table.__doc__ = EventList.table.__doc__ + + def export_chrome_trace(self, path): + self._check_finish() + if self.kineto_results is not None: + self.kineto_results.save(path) + else: + assert self.function_events is not None + return self.function_events.export_chrome_trace(path) + export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__ + + def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): + self._check_finish() + assert self.function_events is not None, "Expected profiling results" + assert self.with_stack, "export_stacks() requires with_stack=True" + return self.function_events.export_stacks(path, metric) + + def key_averages(self, group_by_input_shape=False, group_by_stack_n=0): + self._check_finish() + assert self.function_events is not None, "Expected profiling results" + return self.function_events.key_averages(group_by_input_shape, group_by_stack_n) + key_averages.__doc__ = EventList.key_averages.__doc__ + + def total_average(self): + self._check_finish() + assert self.function_events is not None, "Expected profiling results" + return self.function_events.total_average() + total_average.__doc__ = EventList.total_average.__doc__ + + @property + def self_cpu_time_total(self): + """ Returns total time spent on CPU obtained as a sum of + all self times across all the events. + """ + self._check_finish() + assert self.function_events is not None + return self.function_events.self_cpu_time_total + + +class record_function(ContextDecorator): + """Context manager/function decorator that adds a label to a block of + Python code (or function) when running autograd profiler. It is + useful when tracing the code profile. + + Args: + name (str): Label assigned to the block of code. + node_id (int): ID of node, for distributed profiling. Unset in + non-distributed cases. 
+ + Example: + >>> x = torch.randn((1, 1), requires_grad=True) + >>> with torch.autograd.profiler.profile() as prof: + ... y = x ** 2 + ... with torch.autograd.profiler.record_function("label-z"): # label the block + ... z = y ** 3 + ... y.backward() + ... + >>> # NOTE: some columns were removed for brevity + >>> print(prof.key_averages().table(sort_by="self_cpu_time_total")) + ----------------------------------- --------------- --------------- --------------- + Name Self CPU total % CPU time avg Number of Calls + ----------------------------------- --------------- --------------- --------------- + pow 60.77% 47.470us 3 + mul 21.73% 25.465us 2 + PowBackward0 12.03% 121.891us 1 + torch::autograd::AccumulateGrad 2.70% 6.324us 1 + label-z 2.13% 12.421us 1 + torch::autograd::GraphRoot 0.64% 1.503us 1 + ----------------------------------- --------------- --------------- --------------- + Self CPU time total: 234.344us + CUDA time total: 0.000us + + """ + def __init__(self, name: str): + self.name: str = name + # Whether or not we should run record function's end callbacks when exiting. + self.run_callbacks_on_exit: bool = True + # Stores underlying RecordFunction as a tensor. TODO: move to custom + # class (https://github.com/pytorch/pytorch/issues/35026). + self.handle = None + + def __enter__(self): + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): + if self.run_callbacks_on_exit: + pass + + def _call_end_callbacks_on_future(self, fut): + """ + _call_end_callbacks_on_future is meant to be used for profiling async + calls that return a future. Calling this function will extend recording + beyond this scope, until the future is satisfied. It is useful for profiling + the end to end time of asynchronous calls. This function should only be called + once to attach the callback onto the future, and will throw if called multiple + times. + + Args: + fut: (torch._C.Future): future for which to schedule + callback for. + + Returns: + A future that completes with the value of the passed in future when + the profiling callbacks have ran. + + """ + # Throw if we have already attached a callback onto the future. + if not self.run_callbacks_on_exit: + raise RuntimeError("_call_end_callbacks_on_future can only be called once.") + + # We are scheduling to run this RecordFunction's end callbacks when the + # passed in future completes, so don't run end callbacks on exit. + self.run_callbacks_on_exit = False + profiled_future = None + return profiled_future + + + +def load_nvprof(path): + """Opens an nvprof trace file and parses autograd annotations. 
+ + Args: + path (str): path to nvprof trace + """ + return EventList(parse_nvprof_trace(path)) + + +################################################################################ +# FunctionEvent + +def format_time(time_us): + """Defines how to format time in FunctionEvent""" + US_IN_SECOND = 1000.0 * 1000.0 + US_IN_MS = 1000.0 + if time_us >= US_IN_SECOND: + return '{:.3f}s'.format(time_us / US_IN_SECOND) + if time_us >= US_IN_MS: + return '{:.3f}ms'.format(time_us / US_IN_MS) + return '{:.3f}us'.format(time_us) + + +def format_time_share(time_us, total_time_us): + """Defines how to format time in FunctionEvent""" + if total_time_us == 0: + assert time_us == 0, "Expected time_us == 0 but got {}".format(time_us) + return "NaN" + return '{:.2f}%'.format(time_us * 100.0 / total_time_us) + +def format_memory(nbytes): + """Returns a formatted memory size string""" + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB + if (abs(nbytes) >= GB): + return '{:.2f} Gb'.format(nbytes * 1.0 / GB) + elif (abs(nbytes) >= MB): + return '{:.2f} Mb'.format(nbytes * 1.0 / MB) + elif (abs(nbytes) >= KB): + return '{:.2f} Kb'.format(nbytes * 1.0 / KB) + else: + return str(nbytes) + ' b' + +def attr_formatter(name): + return property(lambda self: format_time(getattr(self, name))) + + +class FormattedTimesMixin(object): + """Helpers for FunctionEvent and FunctionEventAvg. + + The subclass should define `*_time_total` and `count` attributes. + """ + cpu_time_str = attr_formatter('cpu_time') + cuda_time_str = attr_formatter('cuda_time') + npu_time_str = attr_formatter('npu_time') + cpu_time_total_str = attr_formatter('cpu_time_total') + cuda_time_total_str = attr_formatter('cuda_time_total') + npu_time_total_str = attr_formatter('npu_time_total') + self_cpu_time_total_str = attr_formatter('self_cpu_time_total') + self_cuda_time_total_str = attr_formatter('self_cuda_time_total') + self_npu_time_total_str = attr_formatter('self_npu_time_total') + + @property + def cpu_time(self): + return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count # type: ignore + + @property + def cuda_time(self): + return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count # type: ignore + + @property + def npu_time(self): + return 0.0 if self.count == 0 else 1.0 * self.npu_time_total / self.count # type: ignore + +class Interval(object): + def __init__(self, start, end): + self.start = start + self.end = end + + def elapsed_us(self): + return self.end - self.start + + +Kernel = namedtuple('Kernel', ['name', 'device', 'interval']) + + +class FunctionEvent(FormattedTimesMixin): + """Profiling information about a single function.""" + def __init__( + self, id_event, name, thread, start_us, end_us, fwd_thread=None, input_shapes=None, + stack=None, scope=0, cpu_memory_usage=0, cuda_memory_usage=0, npu_memory_usage=0, + is_async=False, is_remote=False, sequence_nr=-1, node_id=-1, device_type=DeviceType.CPU, + device_index=0, is_legacy=False, flops=None, trace_name=None): + self.id: int = id_event + self.node_id: int = node_id + self.name: str = name + self.trace_name: str = trace_name + self.time_range: Interval = Interval(start_us, end_us) + self.thread: int = thread + self.fwd_thread: Optional[int] = fwd_thread + self.kernels: List[Kernel] = [] + self.count: int = 1 + self.cpu_children: List[FunctionEvent] = [] + self.cpu_parent: Optional[FunctionEvent] = None + self.input_shapes: Tuple[int, ...] 
= input_shapes + self.stack: List = stack + self.scope: int = scope + self.cpu_memory_usage: int = cpu_memory_usage + self.cuda_memory_usage: int = cuda_memory_usage + self.npu_memory_usage: int = npu_memory_usage + self.is_async: bool = is_async + self.is_remote: bool = is_remote + self.sequence_nr: int = sequence_nr + self.device_type: DeviceType = device_type + self.device_index: int = device_index + self.is_legacy: bool = is_legacy + self.flops: Optional[float] = flops + + def append_kernel(self, name, device, start, end): + assert self.device_type == DeviceType.CPU + self.kernels.append(Kernel(name, device, Interval(start, end))) + + def append_cpu_child(self, child): + """Append a CPU child of type FunctionEvent. + + One is supposed to append only direct children to the event to have + correct self cpu time being reported. + """ + assert(self.device_type == DeviceType.CPU) + assert(isinstance(child, FunctionEvent)) + assert(child.device_type == DeviceType.CPU) + self.cpu_children.append(child) + + def set_cpu_parent(self, parent): + """Set the immediate CPU parent of type FunctionEvent + + One profiling FunctionEvent should have only one CPU parent such that + the child's range interval is completely inside the parent's. We use + this connection to determine the event is from top-level op or not. + """ + assert(self.device_type == DeviceType.CPU) + assert(isinstance(parent, FunctionEvent)) + assert(parent.device_type == DeviceType.CPU) + self.cpu_parent = parent + + # Note: async events don't have children, are not used when computing 'self' + # metrics of other events, have only total cpu time + @property + def self_cpu_memory_usage(self): + if self.is_async or self.device_type != DeviceType.CPU: + return 0 + return self.cpu_memory_usage - sum( + [child.cpu_memory_usage for child in self.cpu_children] + ) + + @property + def self_cuda_memory_usage(self): + if self.is_async or self.device_type != DeviceType.CPU: + return 0 + return self.cuda_memory_usage - sum( + [child.cuda_memory_usage for child in self.cpu_children] + ) + + @property + def self_npu_memory_usage(self): + if self.is_async or self.device_type != DeviceType.CPU: + return 0 + return self.npu_memory_usage - sum( + [child.npu_memory_usage for child in self.cpu_children] + ) + + @property + def self_cpu_time_total(self): + if self.is_async or self.device_type != DeviceType.CPU: + return 0 + return self.cpu_time_total - sum( + [child.cpu_time_total for child in self.cpu_children] + ) + + @property + def cuda_time_total(self): + if self.is_async: + return 0 + if self.device_type == DeviceType.CPU: + if not self.is_legacy: + # account for the kernels in the children ops + return (sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) + + sum(ch.cuda_time_total for ch in self.cpu_children)) + else: + # each legacy cpu events has a single (fake) kernel + return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) + else: + assert self.device_type == DeviceType.CUDA + return self.time_range.elapsed_us() + + @property + def self_cuda_time_total(self): + if self.is_async: + return 0 + if self.device_type == DeviceType.CPU: + return self.cuda_time_total - \ + sum([child.cuda_time_total for child in self.cpu_children]) + else: + assert(self.device_type == DeviceType.CUDA) + return self.cuda_time_total + + @property + def npu_time_total(self): + if self.is_async: + return 0 + if self.device_type == DeviceType.CPU: + assert self.is_legacy, "profiling with NPU only support for legacy." 
+ # each legacy cpu events has a single (fake) kernel + return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) + else: + assert self.device_type == DeviceType.NPU + return self.time_range.elapsed_us() + + @property + def self_npu_time_total(self): + if self.is_async: + return 0 + if self.device_type == DeviceType.CPU: + return self.npu_time_total - \ + sum([child.npu_time_total for child in self.cpu_children]) + else: + assert(self.device_type == DeviceType.NPU) + return self.npu_time_total + + @property + def cpu_time_total(self): + if self.device_type == DeviceType.CPU: + return self.time_range.elapsed_us() + else: + return 0 + + @property + def key(self): + return self.name + + def __repr__(self): + return ( + ''.format( + self.id, + self.name, + self.device_type, + self.node_id, + self.cpu_time_str, + self.time_range.start, + self.time_range.end, + str([child.id for child in self.cpu_children]), + self.cuda_time_str, + self.npu_time_str, + self.name, + self.thread, + str(self.input_shapes), + self.cpu_memory_usage, + self.cuda_memory_usage, + self.npu_memory_usage, + self.is_async, + self.is_remote, + self.sequence_nr, + self.is_legacy, + ) + ) + + +class FunctionEventAvg(FormattedTimesMixin): + """Used to average stats over multiple FunctionEvent objects.""" + def __init__(self): + self.key: Optional[str] = None + self.count: int = 0 + self.node_id: int = 0 + self.is_async: bool = False + self.is_remote: bool = False + self.use_cuda = True + self.use_npu = False + self.cpu_time_total: int = 0 + self.cuda_time_total: int = 0 + self.npu_time_total: int = 0 + self.self_cpu_time_total: int = 0 + self.self_cuda_time_total: int = 0 + self.self_npu_time_total: int = 0 + self.input_shapes: Optional[List[List[int]]] = None + self.stack: Optional[List] = None + self.scope: Optional[int] = None + self.cpu_memory_usage: int = 0 + self.cuda_memory_usage: int = 0 + self.npu_memory_usage: int = 0 + self.self_cpu_memory_usage: int = 0 + self.self_cuda_memory_usage: int = 0 + self.self_npu_memory_usage: int = 0 + self.cpu_children: Optional[List[FunctionEvent]] = None + self.cpu_parent: Optional[FunctionEvent] = None + self.device_type: DeviceType = DeviceType.CPU + self.is_legacy: bool = False + self.flops: float = 0.0 + + def add(self, other, use_cuda=True, use_npu=False): + self.use_cuda = use_cuda + self.use_npu = use_npu + if self.key is None: + # First function being recorded as part of FunctionEventAvg, propagate + # fields. 
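The accumulation below is the core of `key_averages`: the first event seen for a key donates its descriptive fields, every later event just adds its totals and count, and averages fall out as total divided by count at print time. A standalone sketch of that pattern with plain dicts instead of `FunctionEvent` objects (field names are illustrative):

```python
from collections import defaultdict

def key_averages(events):
    # events: iterable of dicts with "name", "cpu_time_total", "self_cpu_time_total"
    stats = defaultdict(lambda: {"count": 0, "cpu_time_total": 0, "self_cpu_time_total": 0})
    for evt in events:
        entry = stats[evt["name"]]
        entry["count"] += 1
        entry["cpu_time_total"] += evt["cpu_time_total"]
        entry["self_cpu_time_total"] += evt["self_cpu_time_total"]
    return {
        name: {**entry, "cpu_time_avg": entry["cpu_time_total"] / entry["count"]}
        for name, entry in stats.items()
    }

evts = [
    {"name": "aten::mm", "cpu_time_total": 120, "self_cpu_time_total": 100},
    {"name": "aten::mm", "cpu_time_total": 80, "self_cpu_time_total": 60},
]
print(key_averages(evts)["aten::mm"]["cpu_time_avg"])  # 100.0
```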
+ self.key = other.key + self.node_id = other.node_id + self.is_async = other.is_async + self.is_remote = other.is_remote + self.cpu_parent = other.cpu_parent + self.cpu_children = other.cpu_children + + self.input_shapes = other.input_shapes + self.stack = other.stack + self.scope = other.scope + self.device_type = other.device_type + self.is_legacy = other.is_legacy + + assert isinstance(other, (FunctionEvent, FunctionEventAvg)) + assert other.key == self.key + self.cpu_time_total += other.cpu_time_total + self.self_cpu_time_total += other.self_cpu_time_total + self.cpu_memory_usage += other.cpu_memory_usage + self.self_cpu_memory_usage += other.self_cpu_memory_usage + self.count += other.count + if self.use_cuda: + self.cuda_time_total += other.cuda_time_total + self.self_cuda_time_total += other.self_cuda_time_total + self.cuda_memory_usage += other.cuda_memory_usage + self.self_cuda_memory_usage += other.self_cuda_memory_usage + elif self.use_npu: + self.npu_time_total += other.npu_time_total + self.self_npu_time_total += other.self_npu_time_total + self.npu_memory_usage += other.npu_memory_usage + self.self_npu_memory_usage += other.self_npu_memory_usage + if self.flops is None: + self.flops = other.flops + elif other.flops is not None: + self.flops += other.flops + return self + + def __iadd__(self, other): + return self.add(other) + + def __repr__(self): + if self.use_npu: + return ( + ''.format( + self.key, + self.self_cpu_time_total_str, + self.cpu_time_str, + self.self_npu_time_total_str, + self.npu_time_str, + str(self.input_shapes), + self.cpu_memory_usage, + self.npu_memory_usage, + ) + ) + else: + return ( + ''.format( + self.key, + self.self_cpu_time_total_str, + self.cpu_time_str, + self.self_cuda_time_total_str, + self.cuda_time_str, + str(self.input_shapes), + self.cpu_memory_usage, + self.cuda_memory_usage, + ) + ) + + +################################################################################ +# Utilities + +class StringTable(defaultdict): + def __missing__(self, key): + # manage cases like 't' (demangled to 'unsigned short') separately, + # for now simply check the length to avoid unexpected results for + # the short sequences + self[key] = key + return self[key] + +def filter_stack_entry(entry): + filtered_entries = [ + ("autograd/__init__", "_make_grads"), + ("autograd/__init__", "backward"), + ("torch/tensor", "backward"), + ("_internal/common_utils", "prof_callable"), + ("_internal/common_utils", "prof_func_call"), + ("_internal/common_utils", "prof_meth_call"), + ] + return all([not (f[0] in entry and f[1] in entry) for f in filtered_entries]) + +def filter_name(name): + # ignoring the following utility ops + filtered_out_names = [ + "profiler::_record_function_enter", + "profiler::_record_function_exit", + "aten::is_leaf", + "aten::output_nr", + "aten::_version", + ] + return name in filtered_out_names + +# Demangles and optionally rewrites the provided event name, +# with_wildcard - whether to replace certain numbered event names +# with a wildcard name to aggregate them together in the profiler table +# output +def rewrite_name(name, with_wildcard=False): + string_table = StringTable() + name = string_table[name] + if with_wildcard: + if name.startswith("ProfilerStep#"): + name = "ProfilerStep*" + return name + +# Parsing of kineto profiler events +def parse_kineto_results(result): + # result.events() has most of the events - PyTorch op-level and device-level events + # result.legacy_events() has events not yet ported to kineto + # (e.g. 
start/stop marks, tensor memory allocator events) + + # First, find __start_profile mark to get the absolute time of the start of the trace; + # save memory allocation records + start_record = None + mem_records = [] + for record in itertools.chain(*result.legacy_events()): + if record.kind() == 'mark' and record.name() == '__start_profile': + assert start_record is None + start_record = record + if record.kind() == 'memory_alloc': + mem_records.append([record, False]) + assert start_record is not None, "Invalid profiler output, __start_profile is missing" + + # Create and return FunctionEvent list + function_events = [] + cuda_corr_map: Dict[int, List[FunctionEvent]] = {} + for kineto_event in result.events(): + if filter_name(kineto_event.name()): + continue + rel_start_us = kineto_event.start_us() - start_record.start_us() + rel_end_us = rel_start_us + kineto_event.duration_us() + abs_end_us = kineto_event.start_us() + kineto_event.duration_us() + + cpu_memory_usage = 0 + cuda_memory_usage = 0 + if kineto_event.device_type() == DeviceType.CPU: + # find the corresponding memory allocation events + for mem_record in mem_records: + if (mem_record[0].start_us() >= kineto_event.start_us() and + mem_record[0].start_us() <= abs_end_us): + cpu_memory_usage += mem_record[0].cpu_memory_usage() + cuda_memory_usage += mem_record[0].cuda_memory_usage() + mem_record[1] = True + + is_async = kineto_event.start_thread_id() != kineto_event.end_thread_id() + fe = FunctionEvent( + id=kineto_event.correlation_id(), + name=rewrite_name(name=kineto_event.name(), with_wildcard=True), + trace_name=rewrite_name(name=kineto_event.name(), with_wildcard=False), + thread=kineto_event.start_thread_id(), + start_us=rel_start_us, + end_us=rel_end_us, + fwd_thread=kineto_event.fwd_thread_id(), + input_shapes=kineto_event.shapes(), + stack=[entry for entry in kineto_event.stack() if filter_stack_entry(entry)], + scope=kineto_event.scope(), + cpu_memory_usage=cpu_memory_usage, + cuda_memory_usage=cuda_memory_usage, + is_async=is_async, + sequence_nr=kineto_event.sequence_nr(), + device_type=kineto_event.device_type(), + device_index=kineto_event.device_index(), + flops=kineto_event.flops(), + ) + function_events.append(fe) + corr_id = kineto_event.linked_correlation_id() + if corr_id > 0: + if corr_id not in cuda_corr_map: + cuda_corr_map[corr_id] = [] + cuda_corr_map[corr_id].append(fe) + + # associate CUDA kernels and CUDA runtime (CPU) with CPU events + for fe in function_events: + if (fe.device_type == DeviceType.CPU and not fe.is_async and + fe.id in cuda_corr_map): + for f_evt in cuda_corr_map[fe.id]: + if f_evt.device_type == DeviceType.CUDA: + fe.append_kernel( + f_evt.name, + f_evt.device_index, + f_evt.time_range.start, + f_evt.time_range.end) + elif f_evt.device_type == DeviceType.CPU: + # make sure that 'thread' of a CPU Kineto (e.g. 
CUDA Runtime) event is associated + # with the 'thread' of the corresponding linked PyTorch event to properly track + # parents and children + f_evt.thread = fe.thread + + + # output top-level memory events + for mem_record in mem_records: + if not mem_record[1]: + fe = FunctionEvent( + id=mem_record[0].handle(), + name="[memory]", + trace_name=None, # not outputting in the trace + thread=mem_record[0].thread_id(), + start_us=mem_record[0].start_us(), + end_us=mem_record[0].start_us(), # no duration + fwd_thread=mem_record[0].fwd_thread_id(), + input_shapes=[], + stack=[], + scope=mem_record[0].scope(), + cpu_memory_usage=mem_record[0].cpu_memory_usage(), + cuda_memory_usage=mem_record[0].cuda_memory_usage(), + is_async=False, + sequence_nr=-1, + device_type=DeviceType.CPU, + device_index=0, + ) + function_events.append(fe) + + function_events.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end]) + return function_events + +# Parsing of legacy profiler events +def parse_legacy_records(thread_records): + def get_record_key(record): + """ + Returns a tuple to be used by parse_legacy_records for correlating start and + end records. + """ + return (record.handle(), record.node_id()) + + next_id = 0 + start_record = None + cuda_records = {} + npu_records = {} + functions = [] + record_stack = [] + profiler_type = DeviceType.CPU + + # cuda start events and the overall profiler start event don't happen + # at exactly the same time because we need to record an event on each device + # and each record takes ~4us. So we adjust here by the difference + # adding the difference in CPU time between the profiler start event + # and the CPU time of the cuda start event for the device + def adjusted_time(device_record, device_records_map): + assert device_record.device() != -1 + assert start_record is not None + if device_record.has_cuda(): + cuda_time_0 = device_records_map[(device_record.node_id(), device_record.device())] + return cuda_time_0.cuda_elapsed_us(device_record) + start_record.cpu_elapsed_us(cuda_time_0) + elif device_record.has_npu(): + npu_time_0 = device_records_map[(device_record.node_id(), device_record.device())] + return npu_time_0.npu_elapsed_us(device_record) + start_record.cpu_elapsed_us(npu_time_0) + + # '__start_profile' is not guaranteed to be first, so we must find it here + for record in itertools.chain(*thread_records): + name = record.name() + if start_record is None and name == '__start_profile': + start_record = record + elif '__cuda_start_event' in name: + # N.B.: Each CUDA device has its own __cuda_start_event. + assert record.device() != -1 + # key for cuda_records is (node_id, device) in case of multiple nodes + # having the same device + cuda_records[(record.node_id(), record.device())] = record + profiler_type = DeviceType.CUDA + elif '__npu_start_event' in name: + # N.B.: Each NPU device has its own __npu_start_event. 
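The adjustment described above reduces to one line of arithmetic: a device timestamp is mapped onto the profiler's CPU timeline as (elapsed device time since that device's start event) + (CPU time between `__start_profile` and that device's start event). A standalone numeric sketch of the same bookkeeping; the offsets and elapsed times are made up, whereas the real code obtains both terms from event elapsed-time queries:

```python
# Per-device start events are recorded a few microseconds after __start_profile.
cpu_offset_us = {0: 4.0, 1: 8.5}  # CPU time from __start_profile to each device's start event

def to_cpu_timeline(device, elapsed_since_device_start_us):
    """Map a device-clock measurement onto the profiler's CPU timeline (microseconds)."""
    return elapsed_since_device_start_us + cpu_offset_us[device]

# A kernel that ran 100us..160us after device 1's start event:
start = to_cpu_timeline(1, 100.0)
end = to_cpu_timeline(1, 160.0)
print(start, end)  # 108.5 168.5
```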
+ assert record.device() != -1 + # key for npu_records is (node_id, device) in case of multiple nodes + # having the same device + npu_records[(record.node_id(), record.device())] = record + profiler_type = DeviceType.NPU + + assert start_record is not None and not start_record.is_remote() + + for thread_record_list in thread_records: + # accumulated memory allocations per handle + cpu_memory_allocs = {} + cuda_memory_allocs = {} + npu_memory_allocs = {} + # ranges per handle + range_starts = {} + + filtered_handles = set() + prev_record = None + for record in thread_record_list: + record_key = get_record_key(record) + if (filter_name(record.name()) or + record_key in filtered_handles): + filtered_handles.add(record_key) + continue + + if record.kind() == 'push': + # workaround to reduce double logging from operator + # wrappers and redispatch + if prev_record is not None: + duplicate = ( + prev_record.name() == record.name() + and prev_record.kind() == record.kind() + and prev_record.node_id() == record.node_id() + ) + if duplicate: + filtered_handles.add(record_key) + continue + + range_starts[record_key] = record + cpu_memory_allocs[record_key] = 0 + cuda_memory_allocs[record_key] = 0 + npu_memory_allocs[record_key] = 0 + elif record.kind() == 'pop': + assert ( + record_key in range_starts + ), """Expected record with key {} to exist in range_starts. + This means that the pop event did not have a corresponding push.""".format( + record_key + ) + + start = range_starts[record_key] + + cpu_memory_usage = cpu_memory_allocs[record_key] + cuda_memory_usage = cuda_memory_allocs[record_key] + npu_memory_usage = npu_memory_allocs[record_key] + is_async = start.thread_id() != record.thread_id() + is_remote_event = record.is_remote() + start_flops = start.flops() + + fe = FunctionEvent( + id=record.handle(), + node_id=record.node_id(), + name=rewrite_name(name=start.name(), with_wildcard=True), + trace_name=rewrite_name(name=start.name(), with_wildcard=False), + thread=start.thread_id(), + start_us=start_record.cpu_elapsed_us(start), + end_us=start_record.cpu_elapsed_us(record), + fwd_thread=start.fwd_thread_id(), + input_shapes=start.shapes(), + stack=[entry for entry in start.stack() if filter_stack_entry(entry)], + scope=start.scope(), + cpu_memory_usage=cpu_memory_usage, + cuda_memory_usage=cuda_memory_usage, + npu_memory_usage=npu_memory_usage, + is_async=is_async, + is_remote=is_remote_event, + sequence_nr=start.sequence_nr(), + device_type=DeviceType.CPU, + is_legacy=True, + flops=start_flops, + ) + # note: async events have only cpu total time + if not is_async and start.has_cuda(): + cuda_start = adjusted_time(start, cuda_records) + cuda_end = adjusted_time(record, cuda_records) + if (cuda_end - cuda_start) > 0: + fe.append_kernel( + start.name(), + start.device(), + cuda_start, + cuda_end) + elif not is_async and start.has_npu(): + npu_start = adjusted_time(start, npu_records) + npu_end = adjusted_time(record, npu_records) + if (npu_end - npu_start) > 0: + fe.append_kernel( + start.name(), + start.device(), + npu_start, + npu_end) + functions.append(fe) + del range_starts[record_key] + del cpu_memory_allocs[record_key] + del cuda_memory_allocs[record_key] + del npu_memory_allocs[record_key] + elif record.kind() == 'memory_alloc': + num_open_handles_cpu = len(cpu_memory_allocs) + num_open_handles_cuda = len(cuda_memory_allocs) + assert num_open_handles_cpu == num_open_handles_cuda + for handle in cpu_memory_allocs.keys(): + cpu_memory_allocs[handle] += record.cpu_memory_usage() + for 
handle in cuda_memory_allocs.keys(): + cuda_memory_allocs[handle] += record.cuda_memory_usage() + for handle in npu_memory_allocs.keys(): + npu_memory_allocs[handle] += record.npu_memory_usage() + if num_open_handles_cpu == 0: + # output event as a top-level memory event + fe = FunctionEvent( + id=0, + name="[memory]", + trace_name=None, + thread=0, + start_us=0, + end_us=0, + stack=[], + cpu_memory_usage=record.cpu_memory_usage(), + cuda_memory_usage=record.cuda_memory_usage(), + npu_memory_usage=record.npu_memory_usage(), + is_legacy=True, + ) + functions.append(fe) + prev_record = record + + # Sort functions by start time then by end time ascending. + # This ensures that--in the case of nested events which + # have the same start time (which may happen due to the + # granularity of the given clock tick)--we always show + # the outermost nested call first. This adds stability + # in how FunctionEvents appear + functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end]) + if profiler_type == DeviceType.NPU: + for record in itertools.chain(*thread_records): + if record.has_npu(): + record.npu_destropy_event() + return functions + + +################################################################################ +# CUDA checkpoints + +class EnforceUnique(object): + """Raises an error if a key is seen more than once.""" + def __init__(self): + self.seen = set() + + def see(self, *key): + if key in self.seen: + raise RuntimeError('duplicate key: ' + str(key)) + self.seen.add(key) + + +################################################################################ +# Pretty printer + + +def build_table( + events, + sort_by=None, + header=None, + row_limit=100, + max_src_column_width=75, + with_flops=False, + profile_memory=False, + top_level_events_only=False, + use_cuda=False, + use_npu=False): + """Prints a summary of events (which can be a list of FunctionEvent or FunctionEventAvg).""" + if len(events) == 0: + return "" + + has_cuda_time = any([event.self_cuda_time_total > 0 for event in events]) and use_cuda + has_cuda_mem = any([event.self_cuda_memory_usage > 0 for event in events]) and use_cuda + has_npu_time = any([event.self_npu_time_total > 0 for event in events]) and use_npu + has_npu_mem = any([event.self_npu_memory_usage > 0 for event in events]) and use_npu + has_input_shapes = any( + [(event.input_shapes is not None and len(event.input_shapes) > 0) for event in events]) + + if sort_by is not None: + events = EventList(sorted( + events, key=lambda evt: getattr(evt, sort_by), reverse=True + ), use_cuda=has_cuda_time, use_npu=has_npu_time, profile_memory=profile_memory, with_flops=with_flops) + + MAX_NAME_COLUMN_WIDTH = 55 + name_column_width = max([len(evt.key) for evt in events]) + 4 + name_column_width = min(name_column_width, MAX_NAME_COLUMN_WIDTH) + + DEFAULT_COLUMN_WIDTH = 12 + shapes_column_width = max([len(str(evt.input_shapes)) for evt in events]) + 4 + shapes_column_width = min(shapes_column_width, 45) + + flops_column_width = DEFAULT_COLUMN_WIDTH + + src_column_width = None + stacks = [] + for evt in events: + if evt.stack is not None and len(evt.stack) > 0: + stacks.append(evt.stack) + has_stack = len(stacks) > 0 + if has_stack: + src_column_width = max([max([len(entry) for entry in stack]) for stack in stacks]) + 4 + src_column_width = min(src_column_width, max_src_column_width) + + headers = [ + 'Name', + 'Self CPU %', + 'Self CPU', + 'CPU total %', + 'CPU total', + 'CPU time avg', + ] + if has_cuda_time: + headers.extend([ + 'Self CUDA', + 'Self CUDA 
%', + 'CUDA total', + 'CUDA time avg', + ]) + if has_npu_time: + headers.extend([ + 'Self NPU', + 'Self NPU %', + 'NPU total', + 'NPU time avg', + ]) + if profile_memory: + headers.extend([ + 'CPU Mem', + 'Self CPU Mem', + ]) + if has_cuda_mem: + headers.extend([ + 'CUDA Mem', + 'Self CUDA Mem', + ]) + if has_npu_mem: + headers.extend([ + 'NPU Mem', + 'Self NPU Mem', + ]) + headers.append( + '# of Calls' + ) + # Only append Node ID if any event has a valid (>= 0) Node ID + append_node_id = any([evt.node_id != -1 for evt in events]) + if append_node_id: + headers.append('Node ID') + + # Have to use a list because nonlocal is Py3 only... + SPACING_SIZE = 2 + row_format_lst = [""] + header_sep_lst = [""] + line_length_lst = [-SPACING_SIZE] + MAX_STACK_ENTRY = 5 + + def add_column(padding, text_dir='>'): + row_format_lst[0] += '{: ' + text_dir + str(padding) + '}' + (' ' * SPACING_SIZE) + header_sep_lst[0] += '-' * padding + (' ' * SPACING_SIZE) + line_length_lst[0] += padding + SPACING_SIZE + + def auto_scale_flops(flops): + flop_headers = [ + 'FLOPS', + 'KFLOPS', + 'MFLOPS', + 'GFLOPS', + 'TFLOPS', + 'PFLOPS', + ] + assert flops > 0 + log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1))) + assert log_flops >= 0 and log_flops < len(flop_headers) + return (pow(10, (math.floor(log_flops) * -3.0)), flop_headers[int(log_flops)]) + + add_column(name_column_width) + for _ in headers[1:]: + add_column(DEFAULT_COLUMN_WIDTH) + + if has_input_shapes: + headers.append('Input Shapes') + add_column(shapes_column_width) + + if has_stack: + headers.append('Source Location') + add_column(src_column_width, text_dir='<') + + if with_flops: + # Auto-scaling of flops header + US_IN_SECOND = 1000.0 * 1000.0 # cpu_time_total is in us + raw_flops = [] + for evt in events: + if evt.flops > 0: + if evt.cuda_time_total != 0: + evt.flops = float(evt.flops) / evt.cuda_time_total * US_IN_SECOND + else: + evt.flops = float(evt.flops) / evt.cpu_time_total * US_IN_SECOND + raw_flops.append(evt.flops) + if len(raw_flops) != 0: + (flops_scale, flops_header) = auto_scale_flops(min(raw_flops)) + headers.append(flops_header) + add_column(flops_column_width) + else: + with_flops = False # can't find any valid flops + + row_format = row_format_lst[0] + header_sep = header_sep_lst[0] + line_length = line_length_lst[0] + add_column = None # type: ignore + + # Have to use a list because nonlocal is Py3 only... + result = [] + + def append(s): + result.append(s) + result.append('\n') # Yes, newline after the end as well + + sum_self_cpu_time_total = sum([event.self_cpu_time_total for event in events]) + sum_self_cuda_time_total = 0 + sum_self_npu_time_total = 0 + for evt in events: + if evt.device_type == DeviceType.CPU: + # in legacy profiler, kernel info is stored in cpu events + if evt.is_legacy: + sum_self_cuda_time_total += evt.self_cuda_time_total + sum_self_npu_time_total += evt.self_npu_time_total + elif evt.device_type == DeviceType.CUDA: + # in kineto profiler, there're events with the correct device type (e.g. CUDA) + sum_self_cuda_time_total += evt.self_cuda_time_total + elif evt.device_type == DeviceType.NPU: + # in kineto profiler, there're events with the correct device type (e.g. 
CUDA) + sum_self_npu_time_total += evt.self_npu_time_total + + # Actual printing + if header is not None: + append('=' * line_length) + append(header) + if top_level_events_only: + append('=' * line_length) + append('This report only display top-level ops statistics') + append(header_sep) + append(row_format.format(*headers)) + + append(header_sep) + + def trim_path(path, src_column_width): + if len(path) > src_column_width: + offset = len(path) - src_column_width + path = path[offset:] + if len(path) > 3: + path = "..." + path[3:] + return path + + event_limit = 0 + for evt in events: + if event_limit == row_limit: + break + if top_level_events_only and evt.cpu_parent is not None: + continue + else: + event_limit += 1 + name = evt.key + if len(name) >= MAX_NAME_COLUMN_WIDTH - 3: + name = name[:(MAX_NAME_COLUMN_WIDTH - 3)] + "..." + row_values = [ + name, + # Self CPU total %, 0 for async events. + format_time_share(evt.self_cpu_time_total, + sum_self_cpu_time_total), + evt.self_cpu_time_total_str, # Self CPU total + # CPU total %, 0 for async events. + format_time_share(evt.cpu_time_total, sum_self_cpu_time_total) if not evt.is_async else 0, + evt.cpu_time_total_str, # CPU total + evt.cpu_time_str, # CPU time avg + ] + if has_cuda_time: + row_values.extend([ + evt.self_cuda_time_total_str, + # CUDA time total % + format_time_share(evt.self_cuda_time_total, sum_self_cuda_time_total), + evt.cuda_time_total_str, + evt.cuda_time_str, # Cuda time avg + ]) + if has_npu_time: + row_values.extend([ + evt.self_npu_time_total_str, + # NPU time total % + format_time_share(evt.self_npu_time_total, sum_self_npu_time_total), + evt.npu_time_total_str, + evt.npu_time_str, # Npu time avg + ]) + if profile_memory: + row_values.extend([ + # CPU Mem Total + format_memory(evt.cpu_memory_usage), + # Self CPU Mem Total + format_memory(evt.self_cpu_memory_usage), + ]) + if has_cuda_mem: + row_values.extend([ + # CUDA Mem Total + format_memory(evt.cuda_memory_usage), + # Self CUDA Mem Total + format_memory(evt.self_cuda_memory_usage), + ]) + if has_npu_mem: + row_values.extend([ + # NPU Mem Total + format_memory(evt.npu_memory_usage), + # Self NPU Mem Total + format_memory(evt.self_npu_memory_usage), + ]) + row_values.append( + evt.count, # Number of calls + ) + + if append_node_id: + row_values.append(evt.node_id) + if has_input_shapes: + row_values.append(str(evt.input_shapes)[:shapes_column_width]) + if with_flops: + if evt.flops <= 0.0: + row_values.append("--") + else: + row_values.append('{0:8.3f}'.format(evt.flops * flops_scale)) + if has_stack: + src_field = "" + if len(evt.stack) > 0: + src_field = trim_path(evt.stack[0], src_column_width) + row_values.append(src_field) + append(row_format.format(*row_values)) + + if has_stack: + empty_headers = [""] * (len(headers) - 1) + for entry in evt.stack[1:MAX_STACK_ENTRY]: + append(row_format.format(*(empty_headers + [trim_path(entry, src_column_width)]))) + empty_headers.append("") + append(row_format.format(*empty_headers)) + + append(header_sep) + append("Self CPU time total: {}".format(format_time(sum_self_cpu_time_total))) + if has_cuda_time: + append("Self CUDA time total: {}".format(format_time(sum_self_cuda_time_total))) + if has_npu_time: + append("Self NPU time total: {}".format(format_time(sum_self_npu_time_total))) + return ''.join(result) -- Gitee From e637770a11a50344cf5c9bc48dcd5a3c872d2398 Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Fri, 28 Jan 2022 10:42:22 +0800 Subject: [PATCH 5/6] =?UTF-8?q?min=E7=AE=97=E5=AD=90=E7=A7=BB=E6=A4=8D?= 
=?UTF-8?q?=EF=BC=88=E5=90=ABamin=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_min.py | 519 +++++++++++++++++++++++ torch_npu/csrc/aten/ops/MinKernelNpu.cpp | 185 ++++++++ 2 files changed, 704 insertions(+) create mode 100644 test/test_network_ops/test_min.py create mode 100644 torch_npu/csrc/aten/ops/MinKernelNpu.cpp diff --git a/test/test_network_ops/test_min.py b/test/test_network_ops/test_min.py new file mode 100644 index 0000000000..bc6554df4a --- /dev/null +++ b/test/test_network_ops/test_min.py @@ -0,0 +1,519 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + + +class TestMin(TestCase): + def cpu_op_exec(self, input1): + output = torch.min(input1) + output = output.numpy() + return output + + def npu_op_exec(self, input1): + output = torch.min(input1) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_other_exec(self, input1, input2): + output = torch.min(input1, input2) + output = output.numpy() + return output + + def npu_op_other_exec(self, input1, input2): + input1 = input1.to("npu") + input2 = input2.to("npu") + output = torch.min(input1, input2) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_other_exec_out(self, input1, input2, out): + torch.min(input1, input2, out=out) + output = out.to("cpu") + output = output.numpy() + return output + + def cpu_op_dim_exec(self, input1, dim, keepdim): + output1, output2 = torch.min(input1, dim, keepdim) + output1 = output1.numpy() + output2 = output2.int().numpy() # 这里需要将索引从64位转32位 便于拿去与npu的对比 + return output1, output2 + + def npu_op_dim_exec(self, input1, dim, keepdim): + input1 = input1.to("npu") + output1, output2 = torch.min(input1, dim, keepdim) + output1 = output1.to("cpu") + output2 = output2.to("cpu") + output1 = output1.numpy() + output2 = output2.numpy() + return output1, output2 + + def cpu_op_dim_exec_out(self, input1, dim, keepdim): + out = torch.tensor(0).to(input1.dtype) + indices = torch.tensor(0).to(torch.long) + torch.min(input1, dim=dim, keepdim=keepdim, out=(out,indices)) + out = out.numpy() + indices = indices.numpy() + return out,indices + + def npu_op_dim_exec_out(self, input1, dim, keepdim): + out = torch.tensor(0).to(input1.dtype).npu() + indices = torch.tensor(0).to(torch.long).npu() + torch.min(input1, dim=dim, keepdim=keepdim, out=(out,indices)) + out = out.to("cpu").numpy() + indices = indices.to("cpu").numpy() + return out,indices + + def cpu_op_amin_exec(self, input1, dim, keepdim): + output = torch.amin(input1, dim, keepdim) + output = output.numpy() + return output + + def npu_op_amin_exec(self, input1, dim, keepdim): + output = 
torch.amin(input1, dim, keepdim) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_amin_exec_out(self, input1, dim, keepdim, out): + torch.amin(input1, dim, keepdim, out=out) + output = out.to("cpu") + output = output.numpy() + return output + + def min_result(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input1) + npu_output = self.npu_op_exec(npu_input1) + cpu_output = cpu_output.astype(npu_output.dtype) + + self.assertRtolEqual(cpu_output, npu_output) + + def min_result_dim(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output_dim, cpu_output_indices = self.cpu_op_dim_exec(cpu_input1, item[1], item[2]) + npu_output_dim, npu_output_indices = self.cpu_op_dim_exec(cpu_input1, item[1], item[2]) + cpu_output_dim = cpu_output_dim.astype(npu_output_dim.dtype) + self.assertRtolEqual(cpu_output_dim, npu_output_dim) + + def min_result_other(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100) + cpu_input2, npu_input2 = create_common_tensor(item[0], 0, 10) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + cpu_input2 = cpu_input2.to(torch.float32) + cpu_output_other = self.cpu_op_other_exec(cpu_input1, cpu_input2) + npu_output_other = self.npu_op_other_exec(npu_input1, npu_input2) + cpu_output_other = cpu_output_other.astype(npu_output_other.dtype) + + self.assertRtolEqual(cpu_output_other, npu_output_other) + + def min_out_result_other(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -100, 100) + cpu_input2, npu_input2 = create_common_tensor(item[0], -100, 100) + cpu_input3, npu_input3 = create_common_tensor(item[0], -100, 100) + cpu_input4, npu_input4 = create_common_tensor(item[1], -100, 100) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + if cpu_input2.dtype == torch.float16: + cpu_input2 = cpu_input2.to(torch.float32) + cpu_output = self.cpu_op_other_exec(cpu_input1, cpu_input2) + npu_output_out1 = self.npu_op_other_exec(npu_input1, npu_input2) + npu_output_out2 = self.npu_op_other_exec_out(npu_input1, npu_input2, npu_input4) + cpu_output = cpu_output.astype(npu_output_out1.dtype) + self.assertRtolEqual(cpu_output, npu_output_out1) + self.assertRtolEqual(cpu_output, npu_output_out2) + cpu_out_dim, cpu_out_indices = self.cpu_op_dim_exec_out(cpu_input1, dim=0, keepdim=True) + npu_out_dim, npu_out_indices = self.npu_op_dim_exec_out(npu_input1, dim=0, keepdim=True) + npu_output_dim, npu_output_indices = self.npu_op_dim_exec(npu_input1, dim=0, keepdim=True) + cpu_out_dim = cpu_out_dim.astype(npu_out_dim.dtype) + if cpu_out_dim.dtype != np.float16: + self.assertRtolEqual(npu_out_dim, cpu_out_dim) + self.assertRtolEqual(npu_out_indices, cpu_out_indices) + else: + self.assertRtolEqual(npu_out_dim, npu_output_dim) + self.assertRtolEqual(npu_out_indices, npu_output_indices) + + def min_name_result_other(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100) + cpu_input1.names = item[0][3] + npu_input1.names = item[0][3] + if cpu_input1.dtype == torch.float16: + cpu_input1 = 
cpu_input1.to(torch.float32) + cpu_output_dim, cpu_output_indices = self.cpu_op_dim_exec(cpu_input1, item[1], item[2]) + npu_output_dim, npu_output_indices = self.npu_op_dim_exec(cpu_input1, item[1], item[2]) + + if npu_output_dim.dtype != np.float16: + self.assertRtolEqual(npu_output_dim, cpu_output_dim) + self.assertRtolEqual(npu_output_indices.astype(np.int32), cpu_output_indices.astype(np.int32)) + else: + self.assertRtolEqual( npu_output_dim, cpu_output_dim.astype(np.float16)) + self.assertRtolEqual(npu_output_indices.astype(np.int32), cpu_output_indices.astype(np.int32)) + + def min_name_out_result_other(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100) + cpu_input1.names = item[0][3] + npu_input1.names = item[0][3] + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output_dim, cpu_output_indices = self.cpu_op_dim_exec_out(cpu_input1, item[1], item[2]) + npu_output_dim, npu_output_indices = self.npu_op_dim_exec_out(npu_input1, item[1], item[2]) + + if npu_output_dim.dtype != np.float16: + self.assertRtolEqual(npu_output_dim, cpu_output_dim) + self.assertRtolEqual(npu_output_indices.astype(np.int32), cpu_output_indices.astype(np.int32)) + else: + self.assertRtolEqual( npu_output_dim, cpu_output_dim.astype(np.float16)) + self.assertRtolEqual(npu_output_indices.astype(np.int32), cpu_output_indices.astype(np.int32)) + + def amin_result(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100) + cpu_input2, npu_input2 = create_common_tensor(item[0], 0, 100) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output_amin = self.cpu_op_amin_exec(cpu_input1, item[1], item[2]) + npu_output_amin = self.npu_op_amin_exec(npu_input1, item[1], item[2]) + npu_output_amin_out = self.npu_op_amin_exec_out(npu_input1, item[1], item[2], npu_input2) + cpu_output_amin = cpu_output_amin.astype(npu_output_amin.dtype) + self.assertRtolEqual(cpu_output_amin, npu_output_amin) + self.assertRtolEqual(cpu_output_amin, npu_output_amin_out) + + def test_min_out_result(self, device): + shape_format = [ + [[np.float16, 0, [128, 116, 14, 14]], [np.float16, 0, [256, 116, 1, 1]]], + [[np.float16, 0, [128, 58, 28, 28]], [np.float16, 0, [58, 58, 1, 1]]], + [[np.float16, 0, [128, 3, 224, 224]], [np.float16, 0, [3, 3, 3]]], + [[np.float16, 0, [128, 116, 14, 14]], [np.float16, 0, [128, 116, 14, 14]]], + [[np.float32, 0, [256, 128, 7, 7]], [np.float32, 0, [128, 256, 3, 3]]], + [[np.float32, 0, [256, 3, 224, 224]], [np.float32, 0, [3, 3, 7, 7]]], + [[np.float32, 0, [2, 3, 3, 3]], [np.float32, 0, [3, 1, 3]]], + [[np.float32, 0, [128, 232, 7, 7]], [np.float32, 0, [128, 232, 7, 7]]], + ] + self.min_out_result_other(shape_format) + + def test_min_shape_format_fp16_1d(self, device): + format_list = [0, 3] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list + ] + self.min_result(shape_format) + + def test_min_shape_format_fp32_1d(self, device): + format_list = [0, 3] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in + keepdim_list + ] + self.min_result(shape_format) + + def test_min_shape_format_fp16_2d(self, device): + format_list = [0, 3] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256]], np.random.randint(0, 1), j] for i in format_list for 
j in + keepdim_list + ] + self.min_result(shape_format) + + def test_min_shape_format_fp32_2d(self, device): + format_list = [0, 3] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256]], np.random.randint(0, 1), j] for i in format_list for j in + keepdim_list + ] + self.min_result(shape_format) + + def test_min_shape_format_fp16_3d(self, device): + format_list = [0, 3, 29] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256, 64]], np.random.randint(0, 1), j] for i in format_list for j in + keepdim_list + ] + self.min_result(shape_format) + + def test_min_shape_format_fp32_3d(self, device): + format_list = [0, 3, 29] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256, 64]], np.random.randint(0, 1), j] for i in format_list for j in + keepdim_list + ] + self.min_result(shape_format) + + def test_min_shape_format_fp16_4d(self, device): + format_list = [0, 4, 3, 29] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256, 64, 34]], np.random.randint(0, 1), j] for i in format_list for j in + keepdim_list + ] + self.min_result(shape_format) + + def test_min_shape_format_fp32_4d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256, 64, 34]], np.random.randint(0, 1), j] for i in format_list for j in + keepdim_list + ] + self.min_result(shape_format) + + def test_min_dim_shape_format_fp16_1d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list + ] + self.min_result_dim(shape_format) + + def test_min_dim_shape_format_fp32_1d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in + keepdim_list + ] + self.min_result_dim(shape_format) + + def test_min_dim_shape_format_fp16_2d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256]], np.random.randint(0, 2), j] for i in format_list for j in + keepdim_list + ] + self.min_result_dim(shape_format) + + def test_min_dim_shape_format_fp32_2d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256]], np.random.randint(0, 2), j] for i in format_list for j in + keepdim_list + ] + self.min_result_dim(shape_format) + + def test_min_dim_shape_format_fp16_3d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256, 64]], np.random.randint(0, 3), j] for i in format_list for j in + keepdim_list + ] + self.min_result_dim(shape_format) + + def test_min_dim_shape_format_fp32_3d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256, 64]], np.random.randint(0, 3), j] for i in format_list for j in + keepdim_list + ] + self.min_result_dim(shape_format) + + def test_min_dim_shape_format_fp16_4d(self, device): + format_list = [0, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256, 64, 34]], np.random.randint(0, 4), j] for i in format_list for j + in keepdim_list + ] + self.min_result_dim(shape_format) + + def test_min_dim_shape_format_fp32_4d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256, 64, 34]], 
np.random.randint(0, 4), j] for i in format_list for j + in keepdim_list + ] + self.min_result_dim(shape_format) + + def test_min_other_shape_format_fp16_1d(self, device): + format_list = [0, 3, 4] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list + ] + self.min_result_other(shape_format) + + def test_min_other_shape_format_fp32_1d(self, device): + format_list = [0, 3, 4] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in + keepdim_list + ] + self.min_result_other(shape_format) + + def test_min_other_shape_format_fp16_2d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256]], np.random.randint(0, 2), j] for i in format_list for j in + keepdim_list + ] + self.min_result_other(shape_format) + + def test_min_other_shape_format_fp32_2d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256]], np.random.randint(0, 2), j] for i in format_list for j in + keepdim_list + ] + self.min_result_other(shape_format) + + def test_min_other_shape_format_fp16_3d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256, 64]], np.random.randint(0, 3), j] for i in format_list for j in + keepdim_list + ] + self.min_result_other(shape_format) + + def test_min_other_shape_format_fp32_3d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256, 64]], np.random.randint(0, 3), j] for i in format_list for j in + keepdim_list + ] + self.min_result_other(shape_format) + + def test_min_other_shape_format_fp16_4d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256, 64, 34]], np.random.randint(0, 4), j] for i in format_list for j + in keepdim_list + ] + self.min_result_other(shape_format) + + def test_min_other_shape_format_fp32_4d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256, 64, 34]], np.random.randint(0, 4), j] for i in format_list for j + in keepdim_list + ] + self.min_result_other(shape_format) + + def test_min_dimname_shape_format(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256, 64, 34], ('N', 'C', 'H', 'W')], + np.random.choice(['N', 'C', 'H', 'W']), j] for i in format_list for j + in + keepdim_list + ] + self.min_name_result_other(shape_format) + + def test_min_dimname_shape_format_fp16(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256, 64, 34], ('N', 'C', 'H', 'W')], + np.random.choice(['N', 'C', 'H', 'W']), j] for i in format_list for j + in + keepdim_list + ] + self.min_name_result_other(shape_format) + + def test_min_dimname_out_shape_format(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256, 64, 34], ('N', 'C', 'H', 'W')], + np.random.choice(['N', 'C', 'H', 'W']), j] for i in format_list for j + in + keepdim_list + ] + self.min_name_out_result_other(shape_format) + + def test_min_dimname_out_shape_format_fp16(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256, 64, 34], 
('N', 'C', 'H', 'W')], + np.random.choice(['N', 'C', 'H', 'W']), j] for i in format_list for j + in + keepdim_list + ] + self.min_name_out_result_other(shape_format) + + def test_amin_shape_format_fp16_1d(self, device): + format_list = [0, 3, 4] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list + ] + self.amin_result(shape_format) + + def test_amin_shape_format_fp32_1d(self, device): + format_list = [0, 3, 4] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in + keepdim_list + ] + self.amin_result(shape_format) + + def test_amin_shape_format_fp16_2d(self, device): + format_list = [0, 3, 4] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256]], np.random.randint(0, 2), j] for i in format_list for j in + keepdim_list + ] + self.amin_result(shape_format) + + def test_amin_shape_format_fp32_2d(self, device): + format_list = [0, 3, 4] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256]], np.random.randint(0, 2), j] for i in format_list for j in + keepdim_list + ] + self.amin_result(shape_format) + + def test_amin_shape_format_fp16_3d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256, 64]], np.random.randint(0, 3), j] for i in format_list for j in + keepdim_list + ] + self.amin_result(shape_format) + + def test_amin_shape_format_fp32_3d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256, 64]], np.random.randint(0, 3), j] for i in format_list for j in + keepdim_list + ] + self.amin_result(shape_format) + + def test_amin_shape_format_fp16_4d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float16, i, [18, 256, 64, 34]], np.random.randint(0, 4), j] for i in format_list for j + in keepdim_list + ] + self.amin_result(shape_format) + + def test_amin_shape_format_fp32_4d(self, device): + format_list = [0, 3, 4, 29] + keepdim_list = [True, False] + shape_format = [[[np.float32, i, [18, 256, 64, 34]], np.random.randint(0, 4), j] for i in format_list for j + in keepdim_list + ] + self.amin_result(shape_format) + + +instantiate_device_type_tests(TestMin, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/ops/MinKernelNpu.cpp b/torch_npu/csrc/aten/ops/MinKernelNpu.cpp new file mode 100644 index 0000000000..1940e03bea --- /dev/null +++ b/torch_npu/csrc/aten/ops/MinKernelNpu.cpp @@ -0,0 +1,185 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" + +namespace at_npu { +namespace native { + +tuple min_out_npu_nocheck( + const at::Tensor& self, + int64_t dim, + bool keepdim, + at::Tensor& output, + at::Tensor& indices) { + OpCommand cmd; + cmd.Name("ArgMinWithValue") + .Input(self) + .Output(indices) + .Output(output) + .Attr("dimension", dim) + .Attr("keep_dims", keepdim) + .Run(); + return std::tie(output, indices); +} + +tuple NPUNativeFunctions::min_out( + const at::Tensor& self, + int64_t dim, + bool keepdim, + at::Tensor& output, + at::Tensor& indices) { + c10::SmallVector dims = {dim}; + auto outputSize = reduce_ops_npu_output_size(self, dims, keepdim); + c10::SmallVector indicesSize = outputSize; + auto func = [&self, dim, keepdim](at::Tensor& output, at::Tensor& indices) { + min_out_npu_nocheck(self, dim, keepdim, output, indices); + }; + + at::Tensor indices_tmp; + OpPipeWithMultiOut pipe(output, indices_tmp); + return pipe.FixOutputSizeAndFormat<0>({self}, self, ACL_FORMAT_ND, outputSize) + .ApplyOutputWithSpecailParams<1>(indicesSize, self.options().dtype(at::ScalarType::Int), ACL_FORMAT_ND) + .Call(func) + .ReflushOutputDtype<1>(at::ScalarType::Long) + .FixOutputExceptDtype<1>({self}, ACL_FORMAT_ND, at::ScalarType::Long, indicesSize) + .FixOutputWithReplace<1>(indices) + .ReturnRef(); +} + +tuple NPUNativeFunctions::min(const at::Tensor& self, int64_t dim, bool keepdim) { + at::Tensor selfCast = self; + if(self.dtype() == at::ScalarType::Bool){ + selfCast = self.to(at::ScalarType::Float); + } + c10::SmallVector dims = {dim}; + auto outputSize = reduce_ops_npu_output_size(selfCast, dims, keepdim); + c10::SmallVector indicesSize = outputSize; + auto func = [&selfCast, dim, keepdim](at::Tensor outputs, at::Tensor indices) { + min_out_npu_nocheck(selfCast, dim, keepdim, outputs, indices); + }; + + at::Tensor outputs, indices; + OpPipeWithDefinedMultiOut pipe(outputs, indices); + std::tie(outputs, indices) = pipe.ApplyOutputWithSpecailParams<0>(outputSize, selfCast.options(), ACL_FORMAT_ND) + .ApplyOutputWithSpecailParams<1>(indicesSize, selfCast.options().dtype(at::ScalarType::Int), ACL_FORMAT_NCHW) + .Call(func) + .ReflushOutputDtype<1>(at::ScalarType::Long) + .Return(); + + if(self.dtype() == at::ScalarType::Bool){ + outputs = outputs.to(at::ScalarType::Bool); + } + return std::tie(outputs, indices); +} + +tuple NPUNativeFunctions::min_out( + const at::Tensor& self, + at::Dimname dim, + bool keepdim, + at::Tensor& output, + at::Tensor& indices) { + return min_out(self, dimname_to_position(self, dim), keepdim, output, indices); +} + +tuple NPUNativeFunctions::min(const at::Tensor& self, at::Dimname dim, bool keepdim) { + return min(self, dimname_to_position(self, dim), keepdim); +} + +at::Tensor& min_out_npu_nocheck( + const at::Tensor& self, + const at::Tensor& other, + at::Tensor& result) { + OpCommand cmd; + cmd.Name("Minimum") + .Input(self) + .Input(other) + .Output(result) + .Run(); + return result; +} + +at::Tensor& NPUNativeFunctions::min_out( + const at::Tensor& self, + const at::Tensor& other, + at::Tensor& result) { + OpPreparation::CheckOut( + {self}, + result, + ACL_FORMAT_ND, + self.scalar_type(), + self.sizes()); + min_out_npu_nocheck(self, other, result); + return result; +} + +at::Tensor NPUNativeFunctions::min(const at::Tensor& self, const at::Tensor& other) { + auto outputSize = broadcast_ops_npu_output_size(self, other); + at::Tensor result 
= OpPreparation::ApplyTensor(self, outputSize); + min_out_npu_nocheck(self, other, result); + return result; +} + +at::Tensor& min_out_npu_nocheck( + const at::Tensor& self, + at::IntArrayRef dims, + bool keepdim, + at::Tensor& result) { + OpCommand cmd; + cmd.Name("ReduceMin") + .Input(self) + .Input(dims) + .Output(result) + .Attr("keep_dims", keepdim) + .Run(); + return result; +} + +at::Tensor NPUNativeFunctions::amin(const at::Tensor& self, at::IntArrayRef dims, bool keepdim) { + auto outputSize = reduce_ops_npu_output_size(self, dims, keepdim); + int64_t npu_format = CalcuOpUtil::get_tensor_npu_format(self); + if (outputSize.empty()) { + npu_format = ACL_FORMAT_NCHW; + } + at::Tensor result = OpPreparation::ApplyTensorWithFormat(self, outputSize, npu_format); + min_out_npu_nocheck(self, dims, keepdim, result); + return result; +} + +at::Tensor NPUNativeFunctions::min(const at::Tensor& self) { + c10::SmallVector dims = CalcuOpUtil::get_dimlist_for_tensor(self); + return amin(self, dims, false); +} + +at::Tensor& NPUNativeFunctions::amin_out( + const at::Tensor& self, + at::IntArrayRef dims, + bool keepdim, + at::Tensor& result) { + auto outputSize = reduce_ops_npu_output_size(self, dims, keepdim); + OpPreparation::CheckOut( + {self}, + result, + ACL_FORMAT_ND, + self.scalar_type(), + outputSize); + min_out_npu_nocheck(self, dims, keepdim, result); + return result; +} +} // namespace native +} // namespace at_npu -- Gitee From 0471f6c186036ed59508ae6b9c71494802b4b988 Mon Sep 17 00:00:00 2001 From: pipihugh <3213866847@qq.com> Date: Sat, 29 Jan 2022 09:26:06 +0800 Subject: [PATCH 6/6] =?UTF-8?q?=E5=87=8F=E5=B0=8FUT=E4=B8=AD=E7=9A=84size?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_min.py | 72 +++++++++++++++---------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/test/test_network_ops/test_min.py b/test/test_network_ops/test_min.py index bc6554df4a..c70c365bc3 100644 --- a/test/test_network_ops/test_min.py +++ b/test/test_network_ops/test_min.py @@ -55,7 +55,7 @@ class TestMin(TestCase): def cpu_op_dim_exec(self, input1, dim, keepdim): output1, output2 = torch.min(input1, dim, keepdim) output1 = output1.numpy() - output2 = output2.int().numpy() # 这里需要将索引从64位转32位 便于拿去与npu的对比 + output2 = output2.int().numpy() return output1, output2 def npu_op_dim_exec(self, input1, dim, keepdim): @@ -210,14 +210,14 @@ class TestMin(TestCase): def test_min_out_result(self, device): shape_format = [ - [[np.float16, 0, [128, 116, 14, 14]], [np.float16, 0, [256, 116, 1, 1]]], - [[np.float16, 0, [128, 58, 28, 28]], [np.float16, 0, [58, 58, 1, 1]]], - [[np.float16, 0, [128, 3, 224, 224]], [np.float16, 0, [3, 3, 3]]], - [[np.float16, 0, [128, 116, 14, 14]], [np.float16, 0, [128, 116, 14, 14]]], - [[np.float32, 0, [256, 128, 7, 7]], [np.float32, 0, [128, 256, 3, 3]]], - [[np.float32, 0, [256, 3, 224, 224]], [np.float32, 0, [3, 3, 7, 7]]], + [[np.float16, 0, [19, 16, 14, 14]], [np.float16, 0, [25, 16, 1, 1]]], + [[np.float16, 0, [19, 17, 28, 28]], [np.float16, 0, [17, 17, 1, 1]]], + [[np.float16, 0, [19, 3, 22, 22]], [np.float16, 0, [3, 3, 3]]], + [[np.float16, 0, [19, 16, 14, 14]], [np.float16, 0, [19, 16, 14, 14]]], + [[np.float32, 0, [25, 19, 7, 7]], [np.float32, 0, [19, 25, 3, 3]]], + [[np.float32, 0, [25, 3, 22, 22]], [np.float32, 0, [3, 3, 7, 7]]], [[np.float32, 0, [2, 3, 3, 3]], [np.float32, 0, [3, 1, 3]]], - [[np.float32, 0, [128, 232, 7, 7]], [np.float32, 0, [128, 232, 7, 7]]], + [[np.float32, 0, [19, 23, 
7, 7]], [np.float32, 0, [19, 23, 7, 7]]], ] self.min_out_result_other(shape_format) @@ -239,7 +239,7 @@ class TestMin(TestCase): def test_min_shape_format_fp16_2d(self, device): format_list = [0, 3] keepdim_list = [True, False] - shape_format = [[[np.float16, i, [18, 256]], np.random.randint(0, 1), j] for i in format_list for j in + shape_format = [[[np.float16, i, [18, 25]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list ] self.min_result(shape_format) @@ -247,7 +247,7 @@ class TestMin(TestCase): def test_min_shape_format_fp32_2d(self, device): format_list = [0, 3] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256]], np.random.randint(0, 1), j] for i in format_list for j in + shape_format = [[[np.float32, i, [18, 25]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list ] self.min_result(shape_format) @@ -255,7 +255,7 @@ class TestMin(TestCase): def test_min_shape_format_fp16_3d(self, device): format_list = [0, 3, 29] keepdim_list = [True, False] - shape_format = [[[np.float16, i, [18, 256, 64]], np.random.randint(0, 1), j] for i in format_list for j in + shape_format = [[[np.float16, i, [18, 25, 15]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list ] self.min_result(shape_format) @@ -263,7 +263,7 @@ class TestMin(TestCase): def test_min_shape_format_fp32_3d(self, device): format_list = [0, 3, 29] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256, 64]], np.random.randint(0, 1), j] for i in format_list for j in + shape_format = [[[np.float32, i, [18, 25, 15]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list ] self.min_result(shape_format) @@ -271,7 +271,7 @@ class TestMin(TestCase): def test_min_shape_format_fp16_4d(self, device): format_list = [0, 4, 3, 29] keepdim_list = [True, False] - shape_format = [[[np.float16, i, [18, 256, 64, 34]], np.random.randint(0, 1), j] for i in format_list for j in + shape_format = [[[np.float16, i, [18, 25, 15, 16]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list ] self.min_result(shape_format) @@ -279,7 +279,7 @@ class TestMin(TestCase): def test_min_shape_format_fp32_4d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256, 64, 34]], np.random.randint(0, 1), j] for i in format_list for j in + shape_format = [[[np.float32, i, [18, 25, 15, 16]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list ] self.min_result(shape_format) @@ -302,7 +302,7 @@ class TestMin(TestCase): def test_min_dim_shape_format_fp16_2d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float16, i, [18, 256]], np.random.randint(0, 2), j] for i in format_list for j in + shape_format = [[[np.float16, i, [18, 25]], np.random.randint(0, 2), j] for i in format_list for j in keepdim_list ] self.min_result_dim(shape_format) @@ -310,7 +310,7 @@ class TestMin(TestCase): def test_min_dim_shape_format_fp32_2d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256]], np.random.randint(0, 2), j] for i in format_list for j in + shape_format = [[[np.float32, i, [18, 25]], np.random.randint(0, 2), j] for i in format_list for j in keepdim_list ] self.min_result_dim(shape_format) @@ -318,7 +318,7 @@ class TestMin(TestCase): def test_min_dim_shape_format_fp16_3d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float16, 
i, [18, 256, 64]], np.random.randint(0, 3), j] for i in format_list for j in + shape_format = [[[np.float16, i, [18, 25, 15]], np.random.randint(0, 3), j] for i in format_list for j in keepdim_list ] self.min_result_dim(shape_format) @@ -326,7 +326,7 @@ class TestMin(TestCase): def test_min_dim_shape_format_fp32_3d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256, 64]], np.random.randint(0, 3), j] for i in format_list for j in + shape_format = [[[np.float32, i, [18, 25, 15]], np.random.randint(0, 3), j] for i in format_list for j in keepdim_list ] self.min_result_dim(shape_format) @@ -334,7 +334,7 @@ class TestMin(TestCase): def test_min_dim_shape_format_fp16_4d(self, device): format_list = [0, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float16, i, [18, 256, 64, 34]], np.random.randint(0, 4), j] for i in format_list for j + shape_format = [[[np.float16, i, [18, 25, 15, 16]], np.random.randint(0, 4), j] for i in format_list for j in keepdim_list ] self.min_result_dim(shape_format) @@ -342,7 +342,7 @@ class TestMin(TestCase): def test_min_dim_shape_format_fp32_4d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256, 64, 34]], np.random.randint(0, 4), j] for i in format_list for j + shape_format = [[[np.float32, i, [18, 25, 15, 16]], np.random.randint(0, 4), j] for i in format_list for j in keepdim_list ] self.min_result_dim(shape_format) @@ -365,7 +365,7 @@ class TestMin(TestCase): def test_min_other_shape_format_fp16_2d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float16, i, [18, 256]], np.random.randint(0, 2), j] for i in format_list for j in + shape_format = [[[np.float16, i, [18, 25]], np.random.randint(0, 2), j] for i in format_list for j in keepdim_list ] self.min_result_other(shape_format) @@ -373,7 +373,7 @@ class TestMin(TestCase): def test_min_other_shape_format_fp32_2d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256]], np.random.randint(0, 2), j] for i in format_list for j in + shape_format = [[[np.float32, i, [18, 25]], np.random.randint(0, 2), j] for i in format_list for j in keepdim_list ] self.min_result_other(shape_format) @@ -381,7 +381,7 @@ class TestMin(TestCase): def test_min_other_shape_format_fp16_3d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float16, i, [18, 256, 64]], np.random.randint(0, 3), j] for i in format_list for j in + shape_format = [[[np.float16, i, [18, 25, 15]], np.random.randint(0, 3), j] for i in format_list for j in keepdim_list ] self.min_result_other(shape_format) @@ -389,7 +389,7 @@ class TestMin(TestCase): def test_min_other_shape_format_fp32_3d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256, 64]], np.random.randint(0, 3), j] for i in format_list for j in + shape_format = [[[np.float32, i, [18, 25, 15]], np.random.randint(0, 3), j] for i in format_list for j in keepdim_list ] self.min_result_other(shape_format) @@ -397,7 +397,7 @@ class TestMin(TestCase): def test_min_other_shape_format_fp16_4d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float16, i, [18, 256, 64, 34]], np.random.randint(0, 4), j] for i in format_list for j + shape_format = [[[np.float16, i, [18, 25, 15, 16]], np.random.randint(0, 4), j] for i in 
format_list for j in keepdim_list ] self.min_result_other(shape_format) @@ -405,7 +405,7 @@ class TestMin(TestCase): def test_min_other_shape_format_fp32_4d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256, 64, 34]], np.random.randint(0, 4), j] for i in format_list for j + shape_format = [[[np.float32, i, [18, 25, 15, 16]], np.random.randint(0, 4), j] for i in format_list for j in keepdim_list ] self.min_result_other(shape_format) @@ -413,7 +413,7 @@ class TestMin(TestCase): def test_min_dimname_shape_format(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256, 64, 34], ('N', 'C', 'H', 'W')], + shape_format = [[[np.float32, i, [18, 25, 15, 16], ('N', 'C', 'H', 'W')], np.random.choice(['N', 'C', 'H', 'W']), j] for i in format_list for j in keepdim_list @@ -423,7 +423,7 @@ class TestMin(TestCase): def test_min_dimname_shape_format_fp16(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float16, i, [18, 256, 64, 34], ('N', 'C', 'H', 'W')], + shape_format = [[[np.float16, i, [18, 25, 15, 16], ('N', 'C', 'H', 'W')], np.random.choice(['N', 'C', 'H', 'W']), j] for i in format_list for j in keepdim_list @@ -433,7 +433,7 @@ class TestMin(TestCase): def test_min_dimname_out_shape_format(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256, 64, 34], ('N', 'C', 'H', 'W')], + shape_format = [[[np.float32, i, [18, 25, 15, 16], ('N', 'C', 'H', 'W')], np.random.choice(['N', 'C', 'H', 'W']), j] for i in format_list for j in keepdim_list @@ -443,7 +443,7 @@ class TestMin(TestCase): def test_min_dimname_out_shape_format_fp16(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float16, i, [18, 256, 64, 34], ('N', 'C', 'H', 'W')], + shape_format = [[[np.float16, i, [18, 25, 15, 16], ('N', 'C', 'H', 'W')], np.random.choice(['N', 'C', 'H', 'W']), j] for i in format_list for j in keepdim_list @@ -468,7 +468,7 @@ class TestMin(TestCase): def test_amin_shape_format_fp16_2d(self, device): format_list = [0, 3, 4] keepdim_list = [True, False] - shape_format = [[[np.float16, i, [18, 256]], np.random.randint(0, 2), j] for i in format_list for j in + shape_format = [[[np.float16, i, [18, 25]], np.random.randint(0, 2), j] for i in format_list for j in keepdim_list ] self.amin_result(shape_format) @@ -476,7 +476,7 @@ class TestMin(TestCase): def test_amin_shape_format_fp32_2d(self, device): format_list = [0, 3, 4] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256]], np.random.randint(0, 2), j] for i in format_list for j in + shape_format = [[[np.float32, i, [18, 25]], np.random.randint(0, 2), j] for i in format_list for j in keepdim_list ] self.amin_result(shape_format) @@ -484,7 +484,7 @@ class TestMin(TestCase): def test_amin_shape_format_fp16_3d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float16, i, [18, 256, 64]], np.random.randint(0, 3), j] for i in format_list for j in + shape_format = [[[np.float16, i, [18, 25, 15]], np.random.randint(0, 3), j] for i in format_list for j in keepdim_list ] self.amin_result(shape_format) @@ -492,7 +492,7 @@ class TestMin(TestCase): def test_amin_shape_format_fp32_3d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256, 64]], np.random.randint(0, 3), j] for i in format_list 
for j in + shape_format = [[[np.float32, i, [18, 25, 15]], np.random.randint(0, 3), j] for i in format_list for j in keepdim_list ] self.amin_result(shape_format) @@ -500,7 +500,7 @@ class TestMin(TestCase): def test_amin_shape_format_fp16_4d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float16, i, [18, 256, 64, 34]], np.random.randint(0, 4), j] for i in format_list for j + shape_format = [[[np.float16, i, [18, 25, 15, 16]], np.random.randint(0, 4), j] for i in format_list for j in keepdim_list ] self.amin_result(shape_format) @@ -508,7 +508,7 @@ class TestMin(TestCase): def test_amin_shape_format_fp32_4d(self, device): format_list = [0, 3, 4, 29] keepdim_list = [True, False] - shape_format = [[[np.float32, i, [18, 256, 64, 34]], np.random.randint(0, 4), j] for i in format_list for j + shape_format = [[[np.float32, i, [18, 25, 15, 16]], np.random.randint(0, 4), j] for i in format_list for j in keepdim_list ] self.amin_result(shape_format) -- Gitee