From 6fedcb6b7880f3adc6017a3189cd762cbc6c9e31 Mon Sep 17 00:00:00 2001 From: yskhhh Date: Tue, 8 Mar 2022 10:18:43 +0800 Subject: [PATCH] refactor(NpuAny Type): reimplement Npu Any type replace c10::any with c10::Any, any_cast with CastAs and bad_any_cast with AnyCastException BREAKING CHANGE: Methods of c10::Any class have changed Before: c10::any; any_cast; bad_any_cast After: c10::Any; CastAs; AnyCastException --- .../csrc/framework/graph/util/ATenGeBridge.cpp | 10 +++++----- torch_npu/csrc/framework/graph/util/ATenGeBridge.h | 14 +++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/torch_npu/csrc/framework/graph/util/ATenGeBridge.cpp b/torch_npu/csrc/framework/graph/util/ATenGeBridge.cpp index 13b09dc738a..15935d061a0 100644 --- a/torch_npu/csrc/framework/graph/util/ATenGeBridge.cpp +++ b/torch_npu/csrc/framework/graph/util/ATenGeBridge.cpp @@ -58,7 +58,7 @@ at::Tensor ConstructCpuTenosr( template <> void ATenGeBridge::SetGeOpAttr> - (const c10::any& attr_val, ge::OperatorPtr ge_op) { + (const c10::Any& attr_val, ge::OperatorPtr ge_op) { auto attr = TryToGetAnyValue>(attr_val); ge_op->SetAttr(attr.first.c_str(), ge::AscendString(attr.second.c_str())); } @@ -142,7 +142,7 @@ ge::TensorDesc ATenGeBridge::InferGeTenosrDesc( template void ATenGeBridge::SetGeOpConstInput( - const c10::any& const_input, + const c10::Any& const_input, ge::OperatorPtr ge_op) { auto const_input_tuple = ATenGeBridge::TryToGetAnyValue(const_input); @@ -162,7 +162,7 @@ void ATenGeBridge::SetGeOpConstInput( } void ATenGeBridge::SetSensitiveFormat( - const c10::any& sensitive_format, + const c10::Any& sensitive_format, ge::OperatorPtr ge_op, NodeExtInfoType ext_type) { auto sensitive_format_pair = @@ -183,7 +183,7 @@ void ATenGeBridge::SetSensitiveFormat( } void ATenGeBridge::AddNodeExtInfoIntoGeOp( - c10::ArrayRef> ext_info, + c10::ArrayRef> ext_info, ge::OperatorPtr ge_op) { for (const auto& info : ext_info) { switch (info.first) { @@ -238,7 +238,7 @@ void ATenGeBridge::PorcessDynamicInputReg( auto it = std::find_if( ext_info.begin(), ext_info.end(), - [](const std::pair& item) { + [](const std::pair& item) { return item.first == NodeExtInfoType::DYNAMIC_INPUT_FUNC; }); if (it != ext_info.end()) { diff --git a/torch_npu/csrc/framework/graph/util/ATenGeBridge.h b/torch_npu/csrc/framework/graph/util/ATenGeBridge.h index 4c1bda361f7..eed7b984153 100644 --- a/torch_npu/csrc/framework/graph/util/ATenGeBridge.h +++ b/torch_npu/csrc/framework/graph/util/ATenGeBridge.h @@ -49,11 +49,11 @@ public: private: template - static T TryToGetAnyValue(const c10::any& any_val) { + static T TryToGetAnyValue(const c10::Any& any_val) { T val; try { - val = c10::any_cast(any_val); - } catch (c10::bad_any_cast &bd) { + val = c10::CastAs(any_val); + } catch (c10::AnyCastException& bd) { AT_ERROR(bd.what(), typeid(T).name()); } return val; @@ -61,11 +61,11 @@ private: template static void SetGeOpConstInput( - const c10::any& const_input, + const c10::Any& const_input, ge::OperatorPtr ge_op); static void SetSensitiveFormat( - const c10::any& sensitive_format, + const c10::Any& sensitive_format, ge::OperatorPtr ge_op, NodeExtInfoType ext_type); @@ -75,13 +75,13 @@ private: std::string op_name); template - static void SetGeOpAttr(const c10::any& attr_val, ge::OperatorPtr ge_op) { + static void SetGeOpAttr(const c10::Any& attr_val, ge::OperatorPtr ge_op) { AttrType attr = TryToGetAnyValue(attr_val); ge_op->SetAttr(attr.first.c_str(), attr.second); } static void AddNodeExtInfoIntoGeOp( - c10::ArrayRef> ext_info, + c10::ArrayRef> ext_info, ge::OperatorPtr ge_op); }; } // namespace native -- Gitee